用于人脸口罩检测的YOLOv5模型训练
fullstacker 发布于 2021-02-03

在本文中,我们将使用数据集来训练YOLOv5对象检测模型。在这里,我们将训练和测试我们的YOLOv5面具检测模型。

在本系列的上一篇文章中,我们标记了一个面罩数据集。现在是这个项目最激动人心的部分——模特培训的时候了。

介绍

在本系列的上一篇文章中,我们标记了一个面罩数据集。现在是这个项目最激动人心的部分——模特培训的时候了。

准备培训和验证数据

就像任何其他模型一样,YOLOv5需要一些验证数据来确定在训练期间和训练之后推断的好坏。这就是为什么我们需要将图像集拆分为train和val数据集及其相应的.txt文件。通常情况下,培训的分割率为80%,验证的分割率为20%,这些分割率必须按以下方式分配:

 Dataset
|
| -- images
|      | -- train
|      | -- val
|
| -- labels
|      | -- train
|      | -- val
|
| -- dataset.yaml
我手动分割数据;我将前3829个图像放在images/train目录中,其余的955个放在images/val目录中。然后,我将相应的.txt文件复制到labels/train和labels/val。

这个数据集.yaml文件告诉模型数据集是如何分布的、有多少个类以及它们的名称。在我们的案例中,文件是这样的:

                                      

功能包括:

第1行表示列车图像集的相对路径

第2行指示验证集的相对路径

第6行说明数据集包含多少类

第9行定义了每个类的名称

每次训练YOLOv5模型时,都必须手动创建此文件。幸运的是,这不是一项复杂的任务。

在Colab笔记本上训练YOLOv5模型

我们不会将您的本地计算机推向极限,因为有多种云计算选项。让我们使用一个googlecolab笔记本,一个非常强大和易于使用的解决方案。

在继续之前,我有一些重要的事情要提。YOLOv5模型运行在PyTorch之上,PyTorch是一个ML框架,需要太多的计算资源才能在小型设备上运行。有一个解决办法-忍受我。

YOLOv5和其他YOLO版本都是由Ultralytics开发的,Ultralytics维护了一个Git repo,您可以从中获得使用这些模型所需的所有文件。尽管它是官方的存储库,但仍然缺少一些重要的改进。一些重要的功能将包括在2021年初。存储库将包含将PyTorch自定义模型转换为TensorFlow和TensorFlow-Lite兼容版本所需的函数。最后一个解决方案是在CPU/内存受限的设备上讨论ML时的最佳解决方案。

上述改进是由Ultralytics的一个重要贡献者开发的,可以在这个存储库中找到。这是我们将在这个项目中使用的。之所以采用这种方法,是因为在Ultralytics管道中,当前的PyTorch–TensorFlow Lite转变没有明确定义。您需要手动将.pt文件转换为.onnx,然后获取TensorFlow权重,最终将其转换为TensorFlow Lite权重。在过程的中间很容易出问题,从而使生成的模型不稳定。

好吧,我们开始吧!喝杯咖啡,启动你的Google Colab笔记本(我的在这里)。

让我们从设置GPU实例类型开始。

首先,在屏幕顶部的菜单中,选择编辑>笔记本设置>GPU,并确保“硬件加速器”设置为GPU。

接下来,选择Runtime>Change Runtime Type>Hardware accelerator>GPU。必须按以下方式选择两个下拉列表:



你的笔记本将需要一些时间来初始化-给它几分钟

笔记本启动并运行后,设置初始配置:

 import torch # Keep in mind that YOLOv5 runs on top of PyTorch, so we need to import it to the notebook

from IPython.display import Image #this is to render predictions

#!git clone https://github.com/ultralytics/yolov5 # Use this if you want to keep the official Ultralytics scripts.

!git clone https://github.com/zldrobit/yolov5.git
%cd yolov5
!git checkout tf-android
以上几行将导入基本库,克隆包含模型的YOLOv5 repo,以及PyTorch TensorFlow Lite转换到/content目录所需的关键文件。

下一步是安装一些必需的组件并更新当前笔记本的TensorFlow版本:

!pip install -r requirements.txt
!pip install tensorflow==2.3.1 #Keep this version of TF as YOLOv5 works well with it.
print('All set. Using PyTorch version %s with %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))
让我们将我们的面罩数据集导入笔记本。我已经把它上传到github了,所以运行下面几行:

%cd /content
!git clone https://github.com/sergiovirahonda/FaceMaskDataset
克隆完所有文件后,需要移动数据集的数据集.yaml文件到/content/yolov5/data目录。使用左侧的文件浏览器,将文件从/content/FaceMaskDataset手动拖动到/content/yolov5/data。如前所述,此文件包含YOLO在自定义数据上训练模型所需的信息。如果要检查文件,请运行以下操作:

%cd /content
!git clone https://github.com/sergiovirahonda/FaceMaskDataset
现在是在自定义数据上训练模型的时候了。首先,导航到/yolov5目录,其中train.py文件位于该目录下。以下是实现脚本时要记住的一些注意事项:

您必须以像素表示图像大小(在本例中为415)。由于415不是预定义的最大跨距32的倍数,因此模型将每个图像跨距为416像素。

指明批量大小。在这种情况下,我们将保持在16。

你必须指出时代的数目。历元数越大,置信度越高。在某个时刻,过度拟合是意料之中的,而且,如果数据集相对较小,这种情况将很快发生。现在30个时代已经足够了,否则你会注意到一些过度拟合。

通过--data参数,可以引用包含数据集配置的.yaml文件。

最后,-nosave只用于保存最后的检查点,-cache允许管道缓存图像以减少训练时间。我想保留最好的度量,所以我不使用--nosave。

在训练过程中,要想知道比赛进行得有多顺利,就要密切关注比赛地图@.5公制(平均精度)。如果它接近1,那么模型将获得很好的结果!

要训练模型,请运行以下命令:

%cd /yolov5
!python train.py --img 415 --batch 16 --epochs 30 --data dataset.yaml --weights yolov5s.pt --cache

上面的行对数据集执行基本检查,缓存图像,并覆盖其他配置。他们输出一个模型架构概要(检查它以了解模型由什么组成),然后开始培训。

培训结束时,两个文件应保存在/content/yolov5/runs/train/exp/weights中:最后.pt以及最佳.pt. 我们将使用最佳.pt.

如果您想探索培训期间记录的指标,我建议您使用TensorBoard,这是一个非常交互式的探索工具:

 %load_ext tensorboard
%tensorboard --logdir runs
这将加载如下内容:


注意地图和精度如何在25个时代后达到最大值。这就是为什么训练你的模特超过30个时代是没有意义的。

在googlecolab上测试模型

现在让我们看看我们的模型有多自信。我们可以绘制培训期间获得的验证批次,并检查每个标签的置信度得分:

Image(filename='runs/train/exp/test_batch0_pred.jpg', width=1000)
将绘制以下内容:


虽然0.8不是一个很好的分数,但在这类车型中还是不错的。记住,YOLOv5牺牲了检测速度的准确性。

我们现在有一个训练有素的模型。要在新的、不可见的图像上测试它,请手动创建一个新目录(例如,/content/test\u images),上载一些图像,然后运行以下代码:

 %load_ext tensorboard
%tensorboard --logdir runs
你将实施detect.py使用脚本最佳.pt416x416像素的权重和图像尺寸(遵守这一点非常重要)。结果将保存到runs/detect/exp。要显示结果,请运行以下代码:

#Plotting the first image
Image(filename='runs/detect/exp/testimage1.jpg', width=415)
结果:



全栈者
关注 私信
文章
31
关注
0
粉丝
0