模型训练过程在Ubuntu(22.04.3) Anaconda(23.1.0) Pytorch(1.8.0+cu111)环境下进行。
1.1 环境搭建
- 具体的安装方法如下连接所示,主要部署基于Pytorch(1.8.0+cu111)的虚拟环境用于模型的搭建和训练。Ubuntu22.04上安装Anaconda3-CSDN博客
- 同时,需要安装一下自己常用的库,因为做模型的训练,我这里为了方便起见,还需要安装一下timm库。
pip install timm
1.2 数据集准备
- 使用ImageFolder对数据进行加载
train_dataset = datasets.ImageFolder(root=args.TrainFolder, transform=data_train_transform)
test_dataset = datasets.ImageFolder(root=args.TestFolder, transform=data_test_transform)
print(train_dataset.class_to_idx)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True,num_workers=4,drop_last=True,pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False,num_workers=4,pin_memory=True)
- data_train_transform与data_test_transform数据预处理方法为,此处不写正则化,是想着简化后面在C#中的预处理步骤(统一除以255即可)
data_train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.GaussianBlur(kernel_size=(5,5),sigma=(0.1,0.3)),
transforms.RandomResizedCrop((224,224)),
# transforms.Resize((224,224)),
#Cutout(),
transforms.ToTensor(),
# transforms.Normalize(mean=[0.49404827, 0.49592876, 0.4070973],
# std=[0.2071776, 0.2007362, 0.21388549])
])
data_test_transform = transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor(),
# transforms.Normalize(mean=[0.5006999, 0.50564367, 0.40601832],
# std=[0.21505603, 0.2077407, 0.2244142])
])
1.3 训练过程搭建
- 我们直接使用timm库选择模型结构,因为后面C#选择CPU对模型进行推理,所以选择参数量和计算复杂度低的模型,这块我选择的是mobilevit。
model = timm.create_model('timm/mobilevit_xxs.cvnets_in1k',num_classes=args.num_classes,pretrained=True ,pretrained_cfg_over