B站项目-基于Pytorch的ResNet垃圾图片分类

基于Pytorch的ResNet垃圾图片分类

项目链接
数据集下载链接

1. 数据集预处理

1.1 画图片的宽高分布散点图
import os

import matplotlib.pyplot as plt
import PIL.Image as Image


def plot_resolution(dataset_root_path):
    """Walk dataset_root_path and scatter-plot image width vs height."""
    image_size_list = []  # (width, height) of every image found

    for root, dirs, files in os.walk(dataset_root_path):
        for file in files:
            image_full_path = os.path.join(root, file)
            # Context manager releases the file handle right away;
            # a bare Image.open() leaks handles across a large dataset.
            with Image.open(image_full_path) as image:
                image_size_list.append(image.size)

    print(image_size_list)

    image_width_list = [size[0] for size in image_size_list]   # widths
    image_height_list = [size[1] for size in image_size_list]  # heights

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly

    plt.scatter(image_width_list, image_height_list, s=1)
    plt.xlabel('宽')
    plt.ylabel('高')
    plt.title('图像宽高分布散点图')
    plt.show()



if __name__ == '__main__':
    # Raw string: "\d" in a normal string literal is an invalid escape
    # (SyntaxWarning on modern Python); r"" keeps the path byte-identical.
    dataset_root_path = r"F:\数据与代码\dataset"
    plot_resolution(dataset_root_path)

运行结果:
运行结果

注: os.walk详细解释参考

1.2 画出数据集的各个类别图片数量的条形图

文件组织结构:
在这里插入图片描述

def plot_bar(dataset_root_path):
    """Bar-plot the number of images per class (one sub-directory per class)."""
    file_name_list = []  # class (sub-directory) names
    file_num_list = []   # file count of every directory os.walk visits

    for root, dirs, files in os.walk(dataset_root_path):
        if len(dirs) != 0:
            # NOTE(review): assumes only the dataset root contains
            # sub-directories; nested directories would also be appended
            # here — confirm the dataset layout.
            for dir in dirs:
                file_name_list.append(dir)
        file_num_list.append(len(files))

    # Drop the first entry: it is the file count of the root directory itself.
    file_num_list = file_num_list[1:]

    mean = np.mean(file_num_list)
    print("mean= ", mean)

    bar_positions = np.arange(len(file_name_list))
    fig, ax = plt.subplots()
    ax.bar(bar_positions, file_num_list, 0.5)  # x positions, heights, bar width
    ax.plot(bar_positions, [mean for i in bar_positions], color="red")  # mean line

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly

    ax.set_xticks(bar_positions)
    # rotation was 98 in the original — 90 is the intended upright label angle.
    ax.set_xticklabels(file_name_list, rotation=90)
    ax.set_ylabel("类别数量")
    # This is a bar chart, not a scatter plot — title corrected accordingly.
    ax.set_title("各个类别数量分布条形图")
    plt.show()

运行结果
运行结果

1.3 删除宽高有问题的图片
import os
import PIL.Image as Image


MIN = 200  # minimum allowed image side length (pixels)
MAX = 2000  # maximum allowed image side length (pixels)
ratio = 0.5  # minimum allowed short-side / long-side aspect ratio

def delete_img(dataset_root_path):
    """Delete images that are too small, too large, or too elongated.

    Keeps only images whose width and height fall within [MIN, MAX] and whose
    short-side/long-side ratio is at least `ratio`.
    """
    delete_img_list = []  # full paths of images scheduled for deletion

    for root, dirs, files in os.walk(dataset_root_path):
        for file in files:
            img_full_path = os.path.join(root, file)
            # Close the file handle immediately: on Windows an open handle
            # would make the os.remove() below fail.
            with Image.open(img_full_path) as img:
                img_size = img.size
            max_l = max(img_size)
            min_l = min(img_size)

            # The elif chain guarantees each path is appended at most once.
            if img_size[0] < MIN or img_size[1] < MIN:
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)
            elif img_size[0] > MAX or img_size[1] > MAX:
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)
            elif min_l / max_l < ratio:  # reject overly narrow/elongated images
                delete_img_list.append(img_full_path)
                print("不满足要求", img_full_path, img_size)

    for img in delete_img_list:
        print("正在删除", img)
        os.remove(img)



if __name__ == '__main__':
    # Raw string avoids invalid-escape warnings in the Windows path.
    dataset_root_img = r'F:\数据与代码\dataset'
    delete_img(dataset_root_img)

再次运行1.1 和1.2的代码得到处理后的数据集宽高分布和类别数量
处理后的宽高分布

处理后的类别数量

1.4 数据增强
import os

import cv2

#水平翻转
import numpy as np


def Horizontal(image):
    """Return a copy of `image` mirrored left-to-right (flip around y-axis)."""
    flipped = cv2.flip(image, 1, dst=None)
    return flipped

#垂直翻转
def Vertical(image):
    """Return a copy of `image` mirrored top-to-bottom (flip around x-axis)."""
    flipped = cv2.flip(image, 0, dst=None)
    return flipped

threshold = 200 #阈值

#数据增强
def data_augmentation(from_root_path, save_root_path):
    """Copy every image into save_root_path, augmenting small classes.

    Each image is re-saved as <name>_original.jpg; for classes with fewer than
    `threshold` images, horizontally and vertically flipped copies are also
    written.
    """
    for root, dirs, files in os.walk(from_root_path):
        for file in files:
            img_full_path = os.path.join(root, file)
            # Mirror the class sub-directory under save_root_path.
            class_dir = os.path.split(os.path.split(img_full_path)[0])[1]
            save_path = os.path.join(save_root_path, class_dir)
            print(save_path)
            if not os.path.isdir(save_path):
                os.makedirs(save_path)

            # imdecode via np.fromfile handles non-ASCII (Chinese) paths,
            # which plain cv2.imread cannot read on Windows.
            img = cv2.imdecode(np.fromfile(img_full_path, dtype=np.uint8), -1)

            # BUG FIX: the original used file[:-5], which only strips a
            # 5-character extension like ".jpeg" and corrupts names ending in
            # ".jpg"/".png". splitext strips any extension correctly.
            stem = os.path.splitext(file)[0]
            cv2.imencode('.jpg', img)[1].tofile(
                os.path.join(save_path, stem + "_original.jpg"))

            # Augment only under-represented classes.
            if 0 < len(files) < threshold:
                img_horizontal = Horizontal(img)
                cv2.imencode('.jpg', img_horizontal)[1].tofile(
                    os.path.join(save_path, stem + "_horizontal.jpg"))
                img_vertical = Vertical(img)
                cv2.imencode('.jpg', img_vertical)[1].tofile(
                    os.path.join(save_path, stem + "_vertical.jpg"))

if __name__ == '__main__':
    # Raw strings avoid invalid-escape warnings ("\e" etc.) in Windows paths.
    from_root_path = r'F:\数据与代码\dataset'
    save_root_path = r'F:\数据与代码\enhance_dataset'
    data_augmentation(from_root_path, save_root_path)


进行数据增强

1.5 数据集平衡处理

将图片数量超过阈值的类别删除一部分图片

import os
import random

threshold = 300  # default maximum number of images to keep per class directory


def dataset_balance(dataset_root_path, keep=None):
    """Randomly delete images so no directory holds more than `keep` files.

    Args:
        dataset_root_path: dataset root; each class lives in a sub-directory.
        keep: per-directory file cap; defaults to the module-level `threshold`
            (backward-compatible generalization of the original hard-coded cap).
    """
    if keep is None:
        keep = threshold

    for root, dirs, files in os.walk(dataset_root_path):
        if len(files) > keep:
            img_paths = [os.path.join(root, file) for file in files]
            # Shuffle, keep the first `keep` images, delete the rest.
            random.shuffle(img_paths)
            for img in img_paths[keep:]:
                os.remove(img)
                print("成功删除", img)

if __name__ == '__main__':
    # Raw string avoids invalid-escape warnings in the Windows path.
    dataset_root_path = r'F:\数据与代码\enhance_dataset'
    dataset_balance(dataset_root_path)

数据集平衡处理

1.6 求图像的均值和方差
from torchvision import transforms as T
import torch
from torchvision.datasets import ImageFolder
from tqdm import tqdm

# Preprocessing used while estimating the per-channel dataset statistics.
transform = T.Compose([
    T.RandomResizedCrop(224),  # random crop, rescaled to 224x224
    T.ToTensor(),  # PIL image -> float tensor in [0, 1], CHW layout
])


def getStat(train_data):
    """Compute per-channel mean and std of a dataset, one image at a time.

    Returns a pair of 3-element lists: (mean per channel, std per channel),
    each averaged over the individual images.
    """
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0, pin_memory=True
    )

    # Running per-channel sums of each image's mean / std.
    mean = torch.zeros(3)
    std = torch.zeros(3)

    for X, _ in tqdm(train_loader):  # tqdm just adds a progress bar
        for channel in range(3):
            mean[channel] += X[:, channel, :, :].mean()
            std[channel] += X[:, channel, :, :].std()

    n = len(train_data)
    mean.div_(n)
    std.div_(n)
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    # Estimate normalisation statistics over the enhanced dataset.
    train_dataset = ImageFolder(root='F:/数据与代码/enhance_dataset', transform=transform)
    stats = getStat(train_dataset)
    print(stats)

2. 生成数据集与数据加载器

2.1 生成数据集
import os
import random

train_ratio = 0.9  # fraction of each class used for training
test_ratio = 1 - train_ratio

# Raw string avoids invalid-escape warnings in the Windows path.
root_data = r'F:\数据与代码\enhance_dataset'

train_list, test_list = [], []  # lines "path\tlabel\n" for each split

# class_flag starts at -1 because the first os.walk() iteration is the dataset
# root itself, which is assumed to contain only class sub-directories.
class_flag = -1
for root, dirs, files in os.walk(root_data):
    # Sort in place so the class-index assignment is deterministic across runs
    # and matches torchvision ImageFolder's alphabetical class ordering.
    dirs.sort()
    files.sort()
    for i in range(0, int(len(files)*train_ratio)):
        train_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        train_list.append(train_data)

    for i in range(int(len(files)*train_ratio), len(files)):
        test_data = os.path.join(root, files[i]) + '\t' + str(class_flag) + '\n'
        test_list.append(test_data)

    class_flag += 1

random.shuffle(train_list)
random.shuffle(test_list)

with open('train.txt', 'w', encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))

with open('test.txt', 'w', encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(str(test_img))

在这里插入图片描述

2.2 生成数据加载器
import torch
from PIL import Image
import torchvision.transforms as transforms

#遇到格式损坏的文件就跳过
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True



from torch.utils.data import Dataset



# Channel-wise normalisation; the constants are presumably the mean/std
# produced by getStat() over the enhanced dataset (section 1.6) — verify.
transform_BZ = transforms.Normalize(
    mean = [0.64148515, 0.57362735, 0.5084857],
    std = [0.21153161, 0.21981773, 0.22988321]
)


class LoadData(Dataset):
    """Dataset reading "path<TAB>label" lines from an annotation text file."""

    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag
        self.img_size = 512
        # Training pipeline: random flips for augmentation, then normalise.
        self.train_tf = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transform_BZ,
        ])
        # Validation pipeline: no augmentation, only resize + normalise.
        self.val_tf = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transform_BZ,
        ])

    def get_images(self, txt_path):
        """Parse the annotation file into a list of [path, label] pairs."""
        with open(txt_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        return [line.strip().split('\t') for line in lines]

    def padding_black(self, img):
        """Letterbox `img` onto a black square canvas of side self.img_size."""
        w, h = img.size
        scale = self.img_size / max(w, h)
        resized = img.resize((int(w * scale), int(h * scale)))
        fg_w, fg_h = resized.size
        canvas = Image.new("RGB", (self.img_size, self.img_size))
        canvas.paste(resized, ((self.img_size - fg_w) // 2,
                               (self.img_size - fg_h) // 2))
        return canvas

    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img = Image.open(img_path).convert('RGB')
        img = self.padding_black(img)
        # Pick the pipeline matching the split this dataset represents.
        tf = self.train_tf if self.train_flag else self.val_tf
        return tf(img), int(label)

    def __len__(self):
        return len(self.imgs_info)

if __name__ == '__main__':
    # Smoke-test the dataset and loader.
    train_dataset = LoadData('train.txt', True)
    print("数据个数", len(train_dataset))

    loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                         batch_size=5,
                                         shuffle=True)

    for image, label in loader:
        print("image.shape", image.shape)
        print(label)

3. 模型搭建与训练

# -*- coding = utf-8 -*-
# @Time : 2024-02-28 15:15
# @Author : 宋俊霖
# @File : 搭建模型与训练函数.py
# @Software : PyCharm
import time
from tqdm import tqdm
import torch
from torchvision.models import resnet18
from 生成数据加载器 import LoadData



# Build the model: ResNet-18 trained from scratch with 55 output classes.
model = resnet18(num_classes=55)

# Training function
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one training epoch and return the average loss per sample.

    Args:
        dataloader: training DataLoader yielding (images, labels).
        model: network to optimise.
        loss_fn: criterion comparing predictions with labels.
        optimizer: optimiser updating the model parameters.
        device: device the batches are moved to.
    """
    size = len(dataloader.dataset)  # number of training samples
    # BUG FIX: validate() switches the model to eval mode and nothing switched
    # it back, so every epoch after the first trained in eval mode.
    model.train()
    avg_loss = 0.0
    for batch, (X, y) in tqdm(enumerate(dataloader)):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        # BUG FIX: accumulate the Python float, not the tensor — accumulating
        # the tensor kept every batch's computation graph alive all epoch.
        avg_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Progress report every 10 batches.
        if batch % 10 == 0:
            current = batch * len(X)  # samples processed so far
            print(f"loss:{loss.item():>7f} [{current:>5d} / {size:>5d}]")

    avg_loss /= size  # mean loss per sample
    return avg_loss


# Validation function
def validate(dataloader, model, loss_fn, device):
    """Evaluate the model on `dataloader`; return (accuracy, mean loss)."""
    size = len(dataloader.dataset)
    model.eval()  # evaluation mode: no dropout / batch-norm updates
    avg_loss = 0
    correct = 0  # count of correctly classified samples
    with torch.no_grad():  # gradients are not needed for evaluation
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            avg_loss += loss_fn(pred, y).item()
            # argmax(1) yields the predicted class index for each row.
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    avg_loss /= size
    acc = correct / size
    print(f"correct={correct}, error={(size - correct)}, Accuracy:{(100 * acc):>0.2f}%, Val_loss:{avg_loss:>8f} \n")
    return acc, avg_loss


# Data loaders built on the LoadData dataset from section 2.2.
batch_size = 32
train_data = LoadData("train.txt", True)  # training split (with augmentation)
val_data = LoadData("test.txt", False)  # validation split (no augmentation)

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    num_workers=4,  # parallel worker processes for loading
    pin_memory=True,  # page-locked memory speeds host->GPU copies
    batch_size=batch_size,
    shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_data,
    num_workers=4,
    pin_memory=True,
    batch_size=batch_size,
)


# Loss function: cross entropy for multi-class classification.
loss_fn = torch.nn.CrossEntropyLoss()

# Optimiser
learning_rate = 1e-3  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)


def WriteData(fname, *args):
    """Append one line to `fname`: each arg stringified and tab-terminated."""
    record = "".join(str(arg) + "\t" for arg in args)
    with open(fname, 'a+') as f:
        f.write(record + "\n")


if __name__ == '__main__':

    # Train on GPU 2 when available, otherwise fall back to CPU.
    device = "cuda:2" if torch.cuda.is_available() else "cpu"
    print(f"正在使用 {device} device")

    model = model.to(device)

    epochs = 50
    loss_ = 10  # lowest training loss seen so far; used to pick the "best" checkpoint
    save_root = "output/"

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}---------------------\n")
        time_start = time.time()
        print("开始训练")
        avg_loss = train(train_dataloader, model, loss_fn, optimizer, device)
        time_end = time.time()
        print(f"train time: {(time_end - time_start)}")


        # Validation pass
        print("开始验证")
        val_acc, val_loss = validate(val_dataloader, model, loss_fn, device)

        # Append this epoch's metrics to the log file.
        WriteData(
            save_root + "resnet18_no_pretrain.txt",
            "epoch", epoch,
            "train_loss", avg_loss,
            "val_loss", val_loss,
            "val_acc", val_acc
        )

        # Snapshot every 5 epochs, plus an always-updated "last" checkpoint.
        if epoch % 5 == 0:
            torch.save(model.state_dict(), save_root +
                       "resnet18_no_pretrain_epoch" +str(epoch)+"_train_loss_"+str(avg_loss)+".pth")
        torch.save(model.state_dict(), save_root+"resnet18_no_pretrain_last.pth")

        # Keep the checkpoint with the lowest training loss as "best".
        if avg_loss < loss_:
            loss_ = avg_loss
            torch.save(model.state_dict(), save_root+"resnet18_no_pretrain_best.pth")

在这里插入图片描述

4. 模型测试

4.1 单张图片模型预测
# -*- coding = utf-8 -*-
# @Time : 2024-02-28 19:25
# @Author : 宋俊霖
# @File : 单张图片模型预测.py
# @Software : PyCharm
import os
import torchvision.transforms as transforms
from PIL import Image
#遇到格式损坏的文件就跳过
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
from torchvision.models import resnet18


# Channel-wise normalisation; the constants are presumably the mean/std
# produced by getStat() over the enhanced dataset (section 1.6) — verify.
transform_BZ = transforms.Normalize(
    mean = [0.64148515, 0.57362735, 0.5084857],
    std = [0.21153161, 0.21981773, 0.22988321]
)


def padding_black(img, img_size = 512):
    """Letterbox `img` onto a black square canvas of side `img_size`."""
    w, h = img.size
    scale = img_size / max(w, h)
    resized = img.resize((int(w * scale), int(h * scale)))
    fg_w, fg_h = resized.size
    canvas = Image.new("RGB", (img_size, img_size))
    canvas.paste(resized, ((img_size - fg_w) // 2,
                           (img_size - fg_h) // 2))
    return canvas

if __name__ == '__main__':
    img_path = 'test_dataset/草莓.png'
    img_size = 512
    test_tf = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transform_BZ  # normalise with the dataset statistics
    ])

    device = "cuda:2" if torch.cuda.is_available() else "cpu"
    print(f"正在使用 {device} device")

    model = resnet18(num_classes=55).to(device)

    # BUG FIX: map_location makes the checkpoint loadable on CPU-only
    # machines even though it was saved from a CUDA device.
    state_dict = torch.load("output/resnet18_no_pretrain_best.pth",
                            map_location=device)
    model.load_state_dict(state_dict)

    model.eval()
    with torch.no_grad():
        img = Image.open(img_path).convert('RGB')
        img = padding_black(img)
        img = test_tf(img)
        img_tensor = torch.unsqueeze(img, 0)  # C,H,W -> N,C,H,W

        img_tensor = img_tensor.to(device)
        res = model(img_tensor)

        id = res.argmax(1).item()  # predicted class index

        # The first os.walk() level with sub-directories holds the class
        # names; break so deeper levels don't print again.
        for root, dirs, files in os.walk("enhance_dataset"):
            if len(dirs) != 0:
                print("预测结果是: ", dirs[id])
                break



4.2 在测试集上预测
import os

import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from tqdm import tqdm
import pandas as pd
from 生成数据加载器 import LoadData


def test(dataloader, model, device):
    """Run inference and return a list of per-sample softmax probability rows."""
    pred_list = []
    model.eval()  # evaluation mode

    with torch.no_grad():  # no gradients needed for inference
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            pred_softmax = torch.softmax(pred, 1).cpu().numpy()
            # BUG FIX: the original kept only row [0] of each batch, silently
            # dropping samples whenever batch_size > 1. extend() keeps every
            # row and is identical for batch_size == 1.
            pred_list.extend(pred_softmax.tolist())
    return pred_list


def WriteData(fname, *args):
    """Append one line to `fname`: each arg stringified and tab-terminated."""
    with open(fname, 'a+') as f:
        for value in args:
            f.write(f"{value}\t")
        f.write("\n")


if __name__ == '__main__':
    batch_size = 1

    test_data = LoadData("test.txt", False)

    test_dataloader = DataLoader(
        dataset=test_data,
        num_workers=4,
        pin_memory=True,
        batch_size=batch_size
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # BUG FIX: the original printed the literal "device:{device}" because the
    # f-prefix was missing.
    print(f"device:{device}")

    model = resnet18(num_classes=55)
    # BUG FIX: map_location so a CUDA-trained checkpoint loads on CPU-only
    # machines.
    model.load_state_dict(torch.load("output/resnet18_pretrain_best.pth",
                                     map_location=device))
    model.to(device)

    pred_list = test(test_dataloader, model, device)
    print("pred_list", pred_list)

    # Class names = sub-directories at the first os.walk level of the dataset.
    file_name_list = []
    data_root = "enhance_dataset"
    for root, dirs, files in os.walk(data_root):
        if len(dirs) != 0:
            file_name_list = dirs

    df_pred = pd.DataFrame(data=pred_list, columns=file_name_list)
    # gbk encoding — presumably so the Chinese column names open correctly in
    # Windows tools; confirm before changing.
    df_pred.to_csv('pred_result.csv', encoding='gbk', index=False)
4.3 计算精度、查准率、召回率、F1-score并绘制混淆矩阵
import os

import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from tqdm import tqdm
import pandas as pd
from 生成数据加载器 import LoadData
from sklearn.metrics import * #pip install scikit-learn
import matplotlib.pyplot as plt



target_loc = "test.txt"  # annotation file holding the ground-truth labels
traget_data = pd.read_csv(target_loc, sep="\t", names=["loc", "type"])
true_label = [i for i in traget_data["type"]]  # ground-truth class indices

predict_loc = "pred_result.csv"  # softmax scores written in section 4.2
predict_data = pd.read_csv(predict_loc, encoding="gbk")
predict_label = predict_data.to_numpy().argmax(axis=1)  # predicted class index
predict_score = predict_data.to_numpy().max(axis=1)  # confidence of prediction

# Accuracy (from sklearn.metrics)
accuracy = accuracy_score(true_label, predict_label)
print(f"精度: {accuracy}")

# Precision, macro-averaged over classes
precision = precision_score(true_label, predict_label, labels=None, pos_label=1, average='macro')
print(f"查准率:{precision}")

# Recall, macro-averaged over classes
recall = recall_score(true_label, predict_label, average='macro')
print(f"召回率:{recall}")

# F1-score, macro-averaged over classes
f1 = f1_score(true_label, predict_label, average='macro')
print(f"F1-score:{f1}")


# Confusion matrix: class names are the sub-directories of the dataset root.
label_names = []
data_root = "enhance_dataset"
for root, dirs, files in os.walk(data_root):
    if len(dirs) != 0:
        label_names = dirs

confusion = confusion_matrix(true_label, predict_label, labels=[i for i in range(len(label_names))])

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams["font.size"] = 8
plt.rcParams["axes.unicode_minus"] = False  # render minus signs correctly

# BUG FIX: the original called plt.figure() *after* matshow()/colorbar(), so
# the cell annotations were drawn on a new, blank figure. Create the figure
# first and draw matshow into it (fignum=0 reuses the current figure).
plt.figure(figsize=(10, 10), dpi=120)
plt.matshow(confusion, cmap=plt.cm.Oranges, fignum=0)  # Greens, Blues, Oranges, Reds
plt.colorbar()
for i in range(len(confusion)):
    for j in range(len(confusion)):
        plt.annotate(confusion[j,i], xy=(i, j), horizontalalignment='center', verticalalignment='center')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()

在这里插入图片描述
在这里插入图片描述

5. 模型优化

5.1 迁移学习与多层学习率
# -*- coding = utf-8 -*-
# @Time : 2024-02-28 15:15
# @Author : 宋俊霖
# @File : 迁移学习.py
# @Software : PyCharm
import time

from torch import nn
from tqdm import tqdm
import torch
from torchvision.models import resnet18
from 生成数据加载器 import LoadData


from torch.utils.tensorboard import SummaryWriter





# Training function
def train(dataloader, model, loss_fn, optimizer, device):
    """Run one training epoch and return the average loss per sample.

    Args:
        dataloader: training DataLoader yielding (images, labels).
        model: network to optimise.
        loss_fn: criterion comparing predictions with labels.
        optimizer: optimiser updating the model parameters.
        device: device the batches are moved to.
    """
    size = len(dataloader.dataset)  # number of training samples
    # BUG FIX: validate() switches the model to eval mode and nothing switched
    # it back, so every epoch after the first trained in eval mode.
    model.train()
    avg_loss = 0.0
    for batch, (X, y) in tqdm(enumerate(dataloader)):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        # BUG FIX: accumulate the Python float, not the tensor — accumulating
        # the tensor kept every batch's computation graph alive all epoch.
        avg_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Progress report every 10 batches.
        if batch % 10 == 0:
            current = batch * len(X)  # samples processed so far
            print(f"loss:{loss.item():>7f} [{current:>5d} / {size:>5d}]")

    avg_loss /= size  # mean loss per sample
    return avg_loss


# Validation function
def validate(dataloader, model, loss_fn, device):
    """Evaluate the model on `dataloader`; return (accuracy, mean loss)."""
    size = len(dataloader.dataset)
    model.eval()  # evaluation mode: no dropout / batch-norm updates
    avg_loss = 0
    correct = 0  # count of correctly classified samples
    with torch.no_grad():  # gradients are not needed for evaluation
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            avg_loss += loss_fn(pred, y).item()
            # argmax(1) yields the predicted class index for each row.
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    avg_loss /= size
    acc = correct / size
    print(f"correct={correct}, error={(size - correct)}, Accuracy:{(100 * acc):>0.2f}%, Val_loss:{avg_loss:>8f} \n")
    return acc, avg_loss


# Data loaders built on the LoadData dataset from section 2.2.
batch_size = 32
train_data = LoadData("train.txt", True)  # training split (with augmentation)
val_data = LoadData("test.txt", False)  # validation split (no augmentation)

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    num_workers=4,  # parallel worker processes for loading
    pin_memory=True,  # page-locked memory speeds host->GPU copies
    batch_size=batch_size,
    shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_data,
    num_workers=4,
    pin_memory=True,
    batch_size=batch_size,
)

# Transfer learning: start from ImageNet-pretrained ResNet-18 weights and
# fine-tune on the 55-class task.
model = resnet18(pretrained=True)
# Replace the 1000-way ImageNet classifier head with a 55-way head.
model.fc = nn.Linear(model.fc.in_features, 55)
# BUG FIX: nn.init.xavier_normal is deprecated; xavier_normal_ is the
# supported in-place initialiser that was intended here.
nn.init.xavier_normal_(model.fc.weight)

# Parameter groups for discriminative (layer-wise) learning rates:
parms_1x = [value for name, value in model.named_parameters()
            if name not in ['fc.weight', 'fc.bias']]  # pretrained backbone layers

parms_10x = [value for name, value in model.named_parameters()
            if name in ['fc.weight', 'fc.bias']]  # freshly initialised fc head


# Loss function: cross entropy for multi-class classification.
loss_fn = torch.nn.CrossEntropyLoss()

# Optimiser
learning_rate = 1e-4  # base learning rate for the pretrained backbone
# Discriminative learning rates: the fresh fc head learns 10x faster.
optimizer = torch.optim.Adam([
    {
        'params': parms_1x
    },
    {
        'params': parms_10x,
        'lr': learning_rate * 10
    }
],lr=learning_rate)


def WriteData(fname, *args):
    """Append one line to `fname`: each arg stringified and tab-terminated."""
    parts = [str(data) + "\t" for data in args]
    with open(fname, 'a+') as f:
        f.writelines(parts)
        f.write("\n")


if __name__ == '__main__':

    # Train on GPU 2 when available, otherwise fall back to CPU.
    device = "cuda:2" if torch.cuda.is_available() else "cpu"
    print(f"正在使用 {device} device")

    model = model.to(device)

    epochs = 50
    loss_ = 10  # lowest training loss seen so far; used to pick the "best" checkpoint
    save_root = "output/"

    # TensorBoard writer; logs are written under ./log
    writer = SummaryWriter(log_dir='log')
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}---------------------\n")
        time_start = time.time()
        print("开始训练")
        avg_loss = train(train_dataloader, model, loss_fn, optimizer, device)
        time_end = time.time()
        print(f"train time: {(time_end - time_start)}")


        # Validation pass
        print("开始验证")
        val_acc, val_loss = validate(val_dataloader, model, loss_fn, device)

        writer.add_scalar(tag="准确率",  # chart name shown in TensorBoard
                          scalar_value=val_acc,  # y-axis value
                          global_step=epoch+1  # x-axis value (iteration number)
                          )

        # Append this epoch's metrics to the log file.
        WriteData(
            save_root + "resnet18_pretrain.txt",
            "epoch", epoch,
            "train_loss", avg_loss,
            "val_loss", val_loss,
            "val_acc", val_acc
        )

        # Snapshot every 5 epochs, plus an always-updated "last" checkpoint.
        if epoch % 5 == 0:
            torch.save(model.state_dict(), save_root +
                       "resnet18_pretrain_epoch" +str(epoch)+"_train_loss_"+str(avg_loss)+".pth")
        torch.save(model.state_dict(), save_root+"resnet18_pretrain_last.pth")

        # Keep the checkpoint with the lowest training loss as "best".
        if avg_loss < loss_:
            loss_ = avg_loss
            torch.save(model.state_dict(), save_root+"resnet18_pretrain_best.pth")

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 9
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值