1、数据集的准备
由于是分类模型,所以数据集需要准备至少两个类别,这里假设为:A、B两类。

源码中是假设数据集分为两部分:train、val,分别对应训练、验证集;在train或者val数据集中,每类样本分别存放于一个独立文件夹,也即:A类的样本存放于文件夹A中,B类的样本存放于文件夹B中;因此,总的数据集为:train、val中分别有A、B两个文件夹。

如果你的数据集是按照这种格式存储的,就可以直接使用源码进行训练了,因为源码中的dataset是使用的torch中的自带数据集定义:

import torchvision.datasets as datasets    
train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

2、模型的训练
数据准备完毕并修改对应dataset对象之后,需要对训练代码微调下。

需要注意的一点是,源码是使用ImageNet数据集进行训练的,有个全局变量IMAGENET_TRAINSET_SIZE代表的是数据集大小,需要改为你的数据集大小,不然影响学习率调度函数的使用,可以将其定义放到DataLoader之后,然后令其等于len(train_loader)即可。

其他的,就是那些超参了,修改一下即可训练。

这一步没啥好说的,轻微修改一下,训练即可

github

图片说明