基于Pytorch的ResNet垃圾图片分类
1. 数据集预处理
1.1 画图片的宽高分布散点图
import os
import matplotlib.pyplot as plt
import PIL.Image as Image
def plot_resolution(dataset_root_path):
    """Scatter-plot the (width, height) distribution of every image under dataset_root_path."""
    image_size_list = []  # collected (width, height) tuples
    for root, dirs, files in os.walk(dataset_root_path):
        for file in files:
            image_full_path = os.path.join(root, file)
            # Context manager closes the file handle; the original leaked one handle per image.
            with Image.open(image_full_path) as image:
                image_size_list.append(image.size)
    print(image_size_list)
    image_width_list = [size[0] for size in image_size_list]   # widths
    image_height_list = [size[1] for size in image_size_list]  # heights
    plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese labels
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with this font
    plt.scatter(image_width_list, image_height_list, s=1)
    plt.xlabel('宽')
    plt.ylabel('高')
    plt.title('图像宽高分布散点图')
    plt.show()
if __name__ == '__main__':
    # Raw string so backslashes in the Windows path can never be read as escapes.
    dataset_root_path = r"F:\数据与代码\dataset"
    plot_resolution(dataset_root_path)
运行结果:
1.2 画出数据集的各个类别图片数量的条形图
文件组织结构:
def plot_bar(dataset_root_path):
    """Bar-plot the number of images in each class directory, with the mean as a red line."""
    import numpy as np  # this snippet never imported numpy at the top; bind it locally

    file_name_list = []  # class (sub-directory) names, taken from the dataset root
    file_num_list = []   # file count for every directory os.walk visits
    for root, dirs, files in os.walk(dataset_root_path):
        if len(dirs) != 0:
            for dir in dirs:
                file_name_list.append(dir)
        file_num_list.append(len(files))
    # Drop the dataset root's own file count (0) so counts align with class names.
    file_num_list = file_num_list[1:]
    mean = np.mean(file_num_list)
    print("mean= ", mean)
    bar_positions = np.arange(len(file_name_list))
    fig, ax = plt.subplots()
    ax.bar(bar_positions, file_num_list, 0.5)  # x positions, heights, bar width
    ax.plot(bar_positions, [mean for i in bar_positions], color="red")  # mean line
    plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese labels
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(file_name_list, rotation=98)
    ax.set_ylabel("类别数量")
    # Fixed title wording: this is a bar chart, not a scatter plot.
    ax.set_title("各个类别数量分布条形图")
    plt.show()
运行结果:
1.3 删除宽高有问题的图片
import os
import PIL.Image as Image
MIN = 200    # minimum accepted width/height in pixels
MAX = 2000   # maximum accepted width/height in pixels
ratio = 0.5  # minimum accepted short-side / long-side aspect ratio

def delete_img(dataset_root_path):
    """Delete images whose size falls outside [MIN, MAX] or whose aspect ratio is too narrow."""
    delete_img_list = []  # full paths of the images to remove
    for root, dirs, files in os.walk(dataset_root_path):
        for file in files:
            img_full_path = os.path.join(root, file)
            # Close the handle before deleting; on Windows os.remove fails on an open file.
            with Image.open(img_full_path) as img:
                img_size = img.size
            max_l = max(img_size)  # longer side
            min_l = min(img_size)  # shorter side
            # The elif chain guarantees each image is added at most once.
            if img_size[0] < MIN or img_size[1] < MIN:
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)
            elif img_size[0] > MAX or img_size[1] > MAX:
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)
            elif min_l / max_l < ratio:  # reject overly narrow/tall images
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)
    for img in delete_img_list:
        print("正在删除", img)
        os.remove(img)
if __name__ == '__main__':
    # Raw string so backslashes in the Windows path can never be read as escapes.
    dataset_root_img = r'F:\数据与代码\dataset'
    delete_img(dataset_root_img)
再次运行1.1 和1.2的代码得到处理后的数据集宽高分布和类别数量
1.4 数据增强
import os
import cv2
#水平翻转
import numpy as np
def Horizontal(image):
    """Return a horizontally mirrored copy of image (flip around the vertical axis)."""
    mirrored = cv2.flip(image, 1)  # flipCode=1 selects the horizontal flip; dst defaults to None
    return mirrored
#垂直翻转
def Vertical(image):
    """Return a vertically mirrored copy of image (flip around the horizontal axis)."""
    mirrored = cv2.flip(image, 0)  # flipCode=0 selects the vertical flip; dst defaults to None
    return mirrored
threshold = 200  # classes with fewer than this many images get flipped copies added

# Data augmentation
def data_augmentation(from_root_path, save_root_path):
    """Copy every image into save_root_path/<class>/ and, for classes with fewer than
    `threshold` images, also save horizontally and vertically flipped variants."""
    for root, dirs, files in os.walk(from_root_path):
        for file in files:
            img_full_path = os.path.join(root, file)
            # Class directory name = last path component of the image's parent dir.
            class_dir = os.path.split(os.path.split(img_full_path)[0])[1]
            save_path = os.path.join(save_root_path, class_dir)
            print(save_path)
            os.makedirs(save_path, exist_ok=True)  # race-free "create if missing"
            # np.fromfile + imdecode handles paths containing Chinese characters.
            img = cv2.imdecode(np.fromfile(img_full_path, dtype=np.uint8), -1)
            if img is None:  # unreadable/corrupt file — skip instead of crashing
                print("无法读取", img_full_path)
                continue
            # splitext strips the real extension; the original `file[:-5]` only worked
            # for 5-character extensions like ".jpeg" and mangled ".jpg"/".png" names.
            stem = os.path.splitext(file)[0]
            cv2.imencode('.jpg', img)[1].tofile(os.path.join(save_path, stem + "_original.jpg"))
            if 0 < len(files) < threshold:  # small class: add augmented copies
                cv2.imencode('.jpg', Horizontal(img))[1].tofile(os.path.join(save_path, stem + "_horizontal.jpg"))
                cv2.imencode('.jpg', Vertical(img))[1].tofile(os.path.join(save_path, stem + "_vertical.jpg"))
if __name__ == '__main__':
    # Raw strings so backslashes in the Windows paths can never be read as escapes.
    from_root_path = r'F:\数据与代码\dataset'
    save_root_path = r'F:\数据与代码\enhance_dataset'
    data_augmentation(from_root_path, save_root_path)
1.5 数据集平衡处理
将图片数量超过阈值的类别删除一部分图片
import os
import random
threshold = 300  # maximum number of images kept per directory

def dataset_balance(dataset_root_path):
    """Randomly delete images so no directory under dataset_root_path keeps more than `threshold` files."""
    for root, dirs, files in os.walk(dataset_root_path):
        if len(files) > threshold:
            # Build the full path of every file in this oversized directory.
            candidates = [os.path.join(root, name) for name in files]
            random.shuffle(candidates)
            # Keep the first `threshold` shuffled paths; delete the rest.
            for img in candidates[threshold:]:
                os.remove(img)
                print("成功删除", img)
if __name__ == '__main__':
    # Raw string so backslashes in the Windows path can never be read as escapes.
    dataset_root_path = r'F:\数据与代码\enhance_dataset'
    dataset_balance(dataset_root_path)
1.6 求图像的均值和方差
from torchvision import transforms as T
import torch
from torchvision.datasets import ImageFolder
from tqdm import tqdm
# Per-image transform used only for the mean/std statistics pass below.
transform = T.Compose([
    T.RandomResizedCrop(224),  # random crop, then resize to 224x224
    T.ToTensor(),  # HWC uint8 [0,255] -> CHW float [0,1]
])
def getStat(train_data):
    """Estimate per-channel mean and std for the dataset.

    Averages each image's per-channel mean/std over all images — the usual quick
    recipe for computing torchvision Normalize statistics.
    """
    loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True
    )
    mean = torch.zeros(3)  # one accumulator slot per RGB channel
    std = torch.zeros(3)
    for X, _ in tqdm(loader):  # tqdm draws a progress bar
        for channel in range(3):
            mean[channel] += X[:, channel, :, :].mean()
            std[channel] += X[:, channel, :, :].std()
    n = len(train_data)
    mean.div_(n)
    std.div_(n)
    return list(mean.numpy()), list(std.numpy())
if __name__ == '__main__':
    # ImageFolder treats each sub-directory as one class.
    train_dataset = ImageFolder(root='F:/数据与代码/enhance_dataset', transform=transform)
    print(getStat(train_dataset))
2. 生成数据集与数据加载器
2.1 生成数据集
import os
import random
# --- Split the dataset into train/test lists and write them to txt files. ---
train_ratio = 0.9
test_ratio = 1 - train_ratio
root_data = 'F:\数据与代码\enhance_dataset'
train_list, test_list = [], []
# Starts at -1 so the first os.walk yield (the dataset root, which holds no
# files) bumps it to 0 for the first real class directory.
# NOTE(review): labels depend on os.walk's directory order being stable and
# matching the order used by the prediction scripts — confirm on the target OS.
class_flag = -1
for root, dirs, files in os.walk(root_data):
    # First train_ratio of this directory's files -> training set.
    for i in range(0, int(len(files)*train_ratio)):
        train_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        train_list.append(train_data)
    # Remaining files -> test set.
    for i in range(int(len(files)*train_ratio), len(files)):
        test_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        test_list.append(test_data)
    class_flag += 1
random.shuffle(train_list)
random.shuffle(test_list)
# Each output line: "<image path>\t<class index>\n"
with open('train.txt', 'w', encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))
with open('test.txt', 'w', encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(str(test_img))
2.2 生成数据加载器
import torch
from PIL import Image
import torchvision.transforms as transforms
# Tolerate truncated/corrupted image payloads instead of raising on load.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.utils.data import Dataset
#数据归一化与标准化
# Channel-wise normalization using the dataset statistics computed in step 1.6.
transform_BZ = transforms.Normalize(
    mean = [0.64148515, 0.57362735, 0.5084857],
    std = [0.21153161, 0.21981773, 0.22988321]
)
class LoadData(Dataset):
    """Dataset that reads "<path>\t<label>" lines from a txt file produced in step 2.1."""

    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag
        self.img_size = 512
        # Training pipeline: random flips for augmentation, then normalize.
        self.train_tf = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transform_BZ,
        ])
        # Validation pipeline: deterministic, no augmentation.
        self.val_tf = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transform_BZ,
        ])

    def get_images(self, txt_path):
        """Parse the txt file into a list of [path, label] pairs."""
        with open(txt_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        return [line.strip().split('\t') for line in lines]

    def padding_black(self, img):
        """Scale img so its longer side equals img_size, then center it on a black square canvas."""
        w, h = img.size
        scale = self.img_size / max(w, h)
        img_fg = img.resize([int(x) for x in [w * scale, h * scale]])
        size_fg = img_fg.size
        size_bg = self.img_size
        canvas = Image.new("RGB", (size_bg, size_bg))
        # Paste centered; the uncovered border stays black (the canvas default).
        canvas.paste(img_fg, ((size_bg - size_fg[0]) // 2,
                              (size_bg - size_fg[1]) // 2))
        return canvas

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path).convert('RGB')
        img = self.padding_black(img)
        tf = self.train_tf if self.train_flag else self.val_tf
        return tf(img), int(label)

    def __len__(self):
        return len(self.imgs_info)
if __name__ == '__main__':
    # Smoke test for the dataset and loader.
    train_dataset = LoadData('train.txt', True)
    print("数据个数", len(train_dataset))
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=5,
        shuffle=True
    )
    for image, label in train_loader:
        # (batch, channels, H, W); here (5, 3, 512, 512) given img_size=512.
        print("image.shape", image.shape)
        # print(image)
        print(label)
3. 模型搭建与训练
# -*- coding = utf-8 -*-
# @Time : 2024-02-28 15:15
# @Author : 宋俊霖
# @File : 搭建模型与训练函数.py
# @Software : PyCharm
import time
from tqdm import tqdm
import torch
from torchvision.models import resnet18
from 生成数据加载器 import LoadData
# Model: ResNet-18 trained from scratch with a 55-way classification head.
model = resnet18(num_classes=55)#55种分类
#训练函数
def train(dataloader, model, loss_fn, optimizer, device):
size = len(dataloader.dataset) #样本数
avg_loss = 0 #初始化平均损失
for batch, (X, y) in tqdm(enumerate(dataloader)): #batch: 序号,代表第几个batch X:图片 y:标签
X, y = X.to(device), y.to(device)
pred = model(X) #预测值
loss = loss_fn(pred, y)#计算每一个batch的 真实标签 和 预测标签 之间的损失
avg_loss += loss #avg_loss将每一个batch的loss累加起来
optimizer.zero_grad() #优化器清零
loss.backward() #反向传播更新模型参数
optimizer.step() #优化器更新参数
#每10个batch输出一次
if batch % 10 == 0:
loss, current = loss.item(), batch * len(X) # loss: 当前的这个batch的loss current:已经处理了多少张图片
print(f"loss:{loss:>7f} [{current:>5d} / {size:>5d}]")
avg_loss /= size #得到每张图片的平均损失
avg_loss = avg_loss.detach().cpu().numpy() # detach():去除梯度信息 cpu():把数据从显卡传回cpu
return avg_loss
#验证函数
def validate(dataloader, model, loss_fn, device):
size = len(dataloader.dataset)
model.eval() #把模型转变为验证模式,不用反向传播
avg_loss, correct = 0, 0 #corrct:正确预测的图片数量
with torch.no_grad(): #在进行模型参数计算时,不求梯度值
for X, y in tqdm(dataloader):
X, y = X.to(device), y.to(device)
pred = model(X)
avg_loss += loss_fn(pred, y).item() #item():提取数值
correct += (pred.argmax(1) == y).type(torch.float).sum().item() #argmax(1):求每一行最大值的索引 True:1 False:0
avg_loss /= size
acc = correct / size #正确率
print(f"correct={correct}, error={(size - correct)}, Accuracy:{(100 * acc):>0.2f}%, Val_loss:{avg_loss:>8f} \n")
return acc, avg_loss
# Data loaders
batch_size = 32
train_data = LoadData("train.txt", True)    # training split (augmented transforms)
val_data = LoadData("test.txt", False)      # validation split (deterministic transforms)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    num_workers=4,       # background worker processes for loading
    pin_memory=True,     # page-locked host memory speeds up GPU transfers
    batch_size=batch_size,
    shuffle=True         # reshuffle the training data each epoch
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_data,
    num_workers=4,
    pin_memory=True,
    batch_size=batch_size,  # no shuffling for validation
)
# Cross-entropy loss for multi-class classification (expects raw logits).
loss_fn = torch.nn.CrossEntropyLoss()
# Plain SGD optimizer.
learning_rate = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
def WriteData(fname, *args):
    """Append one line to fname: str() of each argument followed by a tab, then a newline."""
    line = "".join(str(arg) + "\t" for arg in args)
    with open(fname, 'a+') as f:
        f.write(line)
        f.write("\n")
if __name__ == '__main__':
    import os  # local import: this snippet never imported os at the top

    device = "cuda:2" if torch.cuda.is_available() else "cpu"
    print(f"正在使用 {device} device")
    model = model.to(device)
    epochs = 50
    best_loss = 10  # lowest training loss seen so far (for "best" checkpointing)
    save_root = "output/"
    os.makedirs(save_root, exist_ok=True)  # original crashed at torch.save if output/ was missing
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}---------------------\n")
        time_start = time.time()
        print("开始训练")
        avg_loss = train(train_dataloader, model, loss_fn, optimizer, device)
        time_end = time.time()
        print(f"train time: {(time_end - time_start)}")
        print("开始验证")
        val_acc, val_loss = validate(val_dataloader, model, loss_fn, device)
        WriteData(
            save_root + "resnet18_no_pretrain.txt",
            "epoch", epoch,
            "train_loss", avg_loss,
            "val_loss", val_loss,
            "val_acc", val_acc
        )
        if epoch % 5 == 0:  # periodic checkpoint
            torch.save(model.state_dict(), save_root +
                       "resnet18_no_pretrain_epoch" + str(epoch) + "_train_loss_" + str(avg_loss) + ".pth")
        torch.save(model.state_dict(), save_root + "resnet18_no_pretrain_last.pth")
        if avg_loss < best_loss:  # keep the checkpoint with the lowest training loss
            best_loss = avg_loss
            torch.save(model.state_dict(), save_root + "resnet18_no_pretrain_best.pth")
4. 模型测试
4.1 单张图片模型预测
# -*- coding = utf-8 -*-
# @Time : 2024-02-28 19:25
# @Author : 宋俊霖
# @File : 单张图片模型预测.py
# @Software : PyCharm
import os
import torchvision.transforms as transforms
from PIL import Image
# Tolerate truncated/corrupted image payloads instead of raising on load.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
from torchvision.models import resnet18
#数据归一化与标准化
# Channel-wise normalization using the dataset statistics computed in step 1.6.
transform_BZ = transforms.Normalize(
    mean = [0.64148515, 0.57362735, 0.5084857],
    std = [0.21153161, 0.21981773, 0.22988321]
)
def padding_black(img, img_size = 512):
    """Scale img so its longer side equals img_size, then center it on a black img_size x img_size canvas."""
    w, h = img.size
    scale = img_size / max(w, h)
    resized = img.resize([int(x) for x in [w * scale, h * scale]])
    new_w, new_h = resized.size
    canvas = Image.new("RGB", (img_size, img_size))
    # Paste centered; the uncovered border stays black (the canvas default).
    canvas.paste(resized, ((img_size - new_w) // 2, (img_size - new_h) // 2))
    return canvas
if __name__ == '__main__':
    # img_path = 'test_dataset/img_骨肉相连_8.jpeg'
    # img_path = 'test_dataset/img_电池_20.jpeg'
    # img_path = 'test_dataset/img_火龙果_5.jpeg'
    # img_path = 'test_dataset/img_口罩_10.jpeg'
    img_path = 'test_dataset/草莓.png'
    img_size = 512
    test_tf = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transform_BZ  # normalize with the dataset statistics
    ])
    device = "cuda:2" if torch.cuda.is_available() else "cpu"
    print(f"正在使用 {device} device")
    model = resnet18(num_classes=55).to(device)
    # map_location lets the checkpoint load on CPU-only machines too;
    # the original torch.load crashed when CUDA was unavailable.
    state_dict = torch.load("output/resnet18_no_pretrain_best.pth", map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    with torch.no_grad():
        img = Image.open(img_path).convert('RGB')
        img = padding_black(img)
        img = test_tf(img)
        img_tensor = torch.unsqueeze(img, 0)  # C,H,W -> 1,C,H,W
        img_tensor = img_tensor.to(device)
        res = model(img_tensor)
        id = res.argmax(1).item()  # index of the most probable class
    # The first os.walk yield lists the class directories; their order defines the labels.
    for root, dirs, files in os.walk("enhance_dataset"):
        if len(dirs) != 0:
            print("预测结果是: ", dirs[id])
            break  # only the top-level directory listing is meaningful
4.2 在测试集上预测
import os
import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from tqdm import tqdm
import pandas as pd
from 生成数据加载器 import LoadData
def test(dataloader, model, device):
    """Return softmax probability rows (one list per sample) for the whole dataloader.

    Uses extend() with the full batch, so it works for any batch size; the original
    appended only row 0 of each batch, silently dropping samples when batch_size > 1.
    """
    pred_list = []
    model.eval()  # deterministic inference mode
    with torch.no_grad():
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            pred_softmax = torch.softmax(pred, 1).cpu().numpy()  # logits -> probabilities
            pred_list.extend(pred_softmax.tolist())  # one row per sample in the batch
    return pred_list
def WriteData(fname, *args):
    """Append one line to fname: str() of each argument followed by a tab, then a newline."""
    with open(fname, 'a+') as out:
        for value in args:
            out.write(str(value) + "\t")
        out.write("\n")
if __name__ == '__main__':
    batch_size = 1
    test_data = LoadData("test.txt", False)
    test_dataloader = DataLoader(
        dataset=test_data,
        num_workers=4,
        pin_memory=True,
        batch_size=batch_size
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The f-prefix was missing in the original, which printed the literal "device:{device}".
    print(f"device:{device}")
    model = resnet18(num_classes=55)
    # map_location lets the checkpoint load on CPU-only machines too.
    model.load_state_dict(torch.load("output/resnet18_pretrain_best.pth", map_location=device))
    model.to(device)
    pred_list = test(test_dataloader, model, device)
    print("pred_list", pred_list)
    file_name_list = []
    data_root = "enhance_dataset"
    # The first os.walk yield contains the class directory names (CSV column headers).
    for root, dirs, files in os.walk(data_root):
        if len(dirs) != 0:
            file_name_list = dirs
            break  # deeper levels must not overwrite the class list
    df_pred = pd.DataFrame(data=pred_list, columns=file_name_list)
    df_pred.to_csv('pred_result.csv', encoding='gbk', index=False)
4.3 计算精度、查准率、召回率、F1-score并绘制混淆矩阵
import os
import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from tqdm import tqdm
import pandas as pd
from 生成数据加载器 import LoadData
from sklearn.metrics import * #pip install scikit-learn
import matplotlib.pyplot as plt
# Ground-truth labels: test.txt lines are "<path>\t<label>".
target_loc = "test.txt"
traget_data = pd.read_csv(target_loc, sep="\t", names=["loc", "type"])
true_label = [i for i in traget_data["type"]]
# Predicted probabilities written by step 4.2 (one column per class, gbk-encoded).
predict_loc = "pred_result.csv"
predict_data = pd.read_csv(predict_loc, encoding="gbk")
predict_label = predict_data.to_numpy().argmax(axis=1)  # most probable class index per row
predict_score = predict_data.to_numpy().max(axis=1)     # probability of that class
# Accuracy (fraction of correct predictions); accuracy_score comes from sklearn.
accuracy = accuracy_score(true_label, predict_label)
print(f"精度: {accuracy}")
# Precision, macro-averaged: unweighted mean over all classes.
precision = precision_score(true_label, predict_label, labels=None, pos_label=1, average='macro')
print(f"查准率:{precision}")
# Recall, macro-averaged.
recall = recall_score(true_label, predict_label, average='macro')
print(f"召回率:{recall}")
# F1-score, macro-averaged.
f1 = f1_score(true_label, predict_label, average='macro')
print(f"F1-score:{f1}")
# Confusion matrix plot.
label_names = []
data_root = "enhance_dataset"
# The first os.walk yield contains the class directory names.
for root, dirs, files in os.walk(data_root):
    if len(dirs) != 0:
        label_names = dirs
        break  # deeper levels must not overwrite the class list
confusion = confusion_matrix(true_label, predict_label, labels=[i for i in range(len(label_names))])
# Font settings must be applied before drawing for them to take effect.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams["font.size"] = 8
plt.rcParams["axes.unicode_minus"] = False  # render minus signs correctly
# Create the sized figure first and tell matshow to draw into it; the original
# called plt.figure() AFTER matshow, so annotations and labels landed on a
# second, blank figure.
fig = plt.figure(figsize=(10, 10), dpi=120)
plt.matshow(confusion, cmap=plt.cm.Oranges, fignum=fig.number)  # Greens, Blues, Oranges, Reds
plt.colorbar()
for i in range(len(confusion)):
    for j in range(len(confusion)):
        # Annotate cell (row j, column i) with its count.
        plt.annotate(confusion[j, i], xy=(i, j), horizontalalignment='center', verticalalignment='center')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
5. 模型优化
5.1 迁移学习与多层学习率
# -*- coding = utf-8 -*-
# @Time : 2024-02-28 15:15
# @Author : 宋俊霖
# @File : 迁移学习.py
# @Software : PyCharm
import time
from torch import nn
from tqdm import tqdm
import torch
from torchvision.models import resnet18
from 生成数据加载器 import LoadData
from torch.utils.tensorboard import SummaryWriter
#训练函数
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one training epoch; return the average loss per sample (float).

    Fixes vs. the original:
    - model.train() is called explicitly; validate() switches the model to eval
      mode, so without this every epoch after the first ran with BN/dropout frozen.
    - loss.item() is accumulated instead of the loss tensor, so autograd graphs
      from earlier batches are not kept alive for the whole epoch.
    - the sum is weighted by batch size, making the result a true per-sample mean
      (the original summed per-batch means but divided by the sample count).
    """
    size = len(dataloader.dataset)  # total number of samples
    model.train()  # re-enable training-mode layers after any previous validate()
    total_loss = 0.0
    for batch, (X, y) in tqdm(enumerate(dataloader)):  # batch index, images, labels
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)  # mean loss over this batch
        total_loss += loss.item() * X.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 10 == 0:  # progress report every 10 batches
            current = batch * len(X)  # samples processed so far
            print(f"loss:{loss.item():>7f} [{current:>5d} / {size:>5d}]")
    return total_loss / size
#验证函数
def validate(dataloader, model, loss_fn, device):
    """Evaluate the model; return (accuracy, average per-sample loss).

    The loss sum is weighted by batch size so dividing by the dataset size yields
    a true per-sample mean (the original summed per-batch means, skewing the value).
    """
    size = len(dataloader.dataset)
    model.eval()  # evaluation mode: no dropout, frozen batch-norm statistics
    total_loss, correct = 0.0, 0
    with torch.no_grad():  # no gradients needed during evaluation
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            total_loss += loss_fn(pred, y).item() * X.size(0)
            # argmax(1): predicted class index per row; True/False sums as 1/0.
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    avg_loss = total_loss / size
    acc = correct / size
    print(f"correct={correct}, error={(size - correct)}, Accuracy:{(100 * acc):>0.2f}%, Val_loss:{avg_loss:>8f} \n")
    return acc, avg_loss
# Data loaders
batch_size = 32
train_data = LoadData("train.txt", True)    # training split (augmented transforms)
val_data = LoadData("test.txt", False)      # validation split (deterministic transforms)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    num_workers=4,       # background worker processes for loading
    pin_memory=True,     # page-locked host memory speeds up GPU transfers
    batch_size=batch_size,
    shuffle=True         # reshuffle the training data each epoch
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_data,
    num_workers=4,
    pin_memory=True,
    batch_size=batch_size,  # no shuffling for validation
)
# Model: transfer learning — start from the ImageNet-pretrained ResNet-18 weights.
model = resnet18(pretrained=True)
# Replace the 1000-way ImageNet head with a 55-way classifier for this dataset.
model.fc = nn.Linear(model.fc.in_features, 55)
# xavier_normal_ is the in-place initializer; the underscore-less alias used in
# the original is deprecated in torch.nn.init.
nn.init.xavier_normal_(model.fc.weight)
# Parameter groups for discriminative (layered) learning rates.
parms_1x = [value for name, value in model.named_parameters()
            if name not in ['fc.weight', 'fc.bias']]  # pretrained backbone
parms_10x = [value for name, value in model.named_parameters()
             if name in ['fc.weight', 'fc.bias']]     # freshly initialized head
# Cross-entropy loss for multi-class classification (expects raw logits).
loss_fn = torch.nn.CrossEntropyLoss()
learning_rate = 1e-4  # base learning rate for the pretrained layers
# The new fc head trains at 10x the backbone's learning rate.
optimizer = torch.optim.Adam([
    {
        'params': parms_1x
    },
    {
        'params': parms_10x,
        'lr': learning_rate * 10
    }
], lr=learning_rate)
def WriteData(fname, *args):
    """Append one line to fname: str() of each argument followed by a tab, then a newline."""
    fields = [str(value) + "\t" for value in args]
    with open(fname, 'a+') as out:
        out.writelines(fields)
        out.write("\n")
if __name__ == '__main__':
    import os  # local import: this snippet never imported os at the top

    device = "cuda:2" if torch.cuda.is_available() else "cpu"
    print(f"正在使用 {device} device")
    model = model.to(device)
    epochs = 50
    best_loss = 10  # lowest training loss seen so far (for "best" checkpointing)
    save_root = "output/"
    os.makedirs(save_root, exist_ok=True)  # original crashed at torch.save if output/ was missing
    writer = SummaryWriter(log_dir='log')  # TensorBoard logger
    try:
        for epoch in range(epochs):
            print(f"Epoch {epoch + 1}---------------------\n")
            time_start = time.time()
            print("开始训练")
            avg_loss = train(train_dataloader, model, loss_fn, optimizer, device)
            time_end = time.time()
            print(f"train time: {(time_end - time_start)}")
            print("开始验证")
            val_acc, val_loss = validate(val_dataloader, model, loss_fn, device)
            writer.add_scalar(tag="准确率",           # curve name in TensorBoard
                              scalar_value=val_acc,   # y value
                              global_step=epoch + 1   # x value (iteration number)
                              )
            WriteData(
                save_root + "resnet18_pretrain.txt",
                "epoch", epoch,
                "train_loss", avg_loss,
                "val_loss", val_loss,
                "val_acc", val_acc
            )
            if epoch % 5 == 0:  # periodic checkpoint
                torch.save(model.state_dict(), save_root +
                           "resnet18_pretrain_epoch" + str(epoch) + "_train_loss_" + str(avg_loss) + ".pth")
            torch.save(model.state_dict(), save_root + "resnet18_pretrain_last.pth")
            if avg_loss < best_loss:  # keep the checkpoint with the lowest training loss
                best_loss = avg_loss
                torch.save(model.state_dict(), save_root + "resnet18_pretrain_best.pth")
    finally:
        writer.close()  # flush pending TensorBoard events even if training aborts