目录
一、项目界面
二、代码实现
1、数据集结构
每一个文件夹对应一个类别的数据
2、设置需要模型的训练参数和指定数据集路径
# 数据名字标签
label_names = {0:"daisy",
1:"dandelion",
2:"rose",
3:"sunflower",
4:"tulip",
}
# 类别数量,根据label_names标签名自动得出
num_classes = len(label_names)
# 重采样大小。如果无则填None
re_size = (28,28)
# 训练集地址,默认即可
train_path = r"./data/train"
# 验证集地址,默认即可
val_path = r"./data/val"
# 测试集地址,默认即可
test_path = r"./data/test"
# 图像后缀
img_ = "jpg"
# 批量大小
batch_size = 64
# 结果保存地址
save_results = r"./results"
# 学习率
lr = 0.001
# 迭代次数
epochs = 20
# ----------划分数据集参数-----------
# 确定将数据集划分为训练集,验证集,测试集的比例
train_pct = 0.5
valid_pct = 0.1
test_pct = 0.4
# 确定原图像数据集路径。默认即可
dataset_dir = r"./data/data" # 原始数据集路径
# 确定数据集划分后保存的路径
split_dir = r"./data" # 划分后保存路径
3、网络代码
该网络基于残差模型修改
import torch
import torch.nn as nn
import torchvision.models as models
class resnet18(nn.Module):
def __init__(self, num_classes=5, pretrained=False):
super(resnet18, self).__init__()
# 加载ResNet-18模型
self.model = models.resnet18(pretrained=pretrained)
# print(self.model)
# 更改全连接