Custom Dataset
# -*- I Love Python!!! And You? -*-
# @Time : 2022/3/28 17:08
# @Author : sunao
# @Email : 939419697@qq.com
# @File : hymData.py
# @Software: PyCharm
import os

import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class hymData(Dataset):
    def __init__(self, img_w=128, img_h=128, path="./data/hymenoptera_data",
                 ants_file="ants", bees_file="bees", train=True, preprocess=True):
        """
        Dataset initialization.
        :param img_w: width to resize images to
        :param img_h: height to resize images to
        :param path: root directory of the hymenoptera data
        :param ants_file: subdirectory of ant images (label 0)
        :param bees_file: subdirectory of bee images (label 1)
        :param train: use the train split if True, otherwise the val split
        :param preprocess: apply the transform pipeline in __getitem__
        """
        super(hymData, self).__init__()
        self.img_w = img_w
        self.img_h = img_h
        self.path = path
        if train:
            self.path = self.path + "/train/"
            # Training pipeline: random augmentation on the PIL image,
            # then conversion to a tensor and normalization.
            self.tran_x = transforms.Compose([
                transforms.Resize([self.img_h, self.img_w]),  # Resize takes (h, w)
                transforms.RandomRotation(10),
                transforms.RandomCrop(self.img_w, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ])
        else:
            self.path = self.path + "/val/"
            # Validation pipeline: deterministic, no augmentation.
            self.tran_x = transforms.Compose([
                transforms.Resize([self.img_h, self.img_w]),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ])
        ants = os.listdir(self.path + ants_file)
        bees = os.listdir(self.path + bees_file)
        ants_len = len(ants)
        # Map a global index to [file path, label]: ants are label 0, bees
        # are label 1; bee indices continue after the last ant index.
        self.ants_file_list = {index: [self.path + ants_file + "/" + name, 0]
                               for index, name in enumerate(ants)}
        self.bees_file_list = {index + ants_len: [self.path + bees_file + "/" + name, 1]
                               for index, name in enumerate(bees)}
        self.train = train
        self.preprocess = preprocess
        # Merge the ants and bees entries into a single dict.
        self.file_list = {**self.ants_file_list, **self.bees_file_list}
    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        x, y = self.file_list[index]
        # Convert to RGB so grayscale/RGBA files still give 3 channels.
        x = Image.open(x).convert("RGB")
        if self.preprocess:
            x = self.tran_x(x)
        return x, y
if __name__ == '__main__':
    data = hymData(256, 256, train=True, preprocess=True)
    # print(data)        # <__main__.hymData object at 0x000002A937DDAD68>
    # print(len(data))   # 244 training images
    # x, y = data[243]   # the last image
    # print(x.shape)
    # plt.imshow(x.numpy().transpose([1, 2, 0]))
    # plt.show()
    data_loader = DataLoader(dataset=data, shuffle=True, batch_size=5)
    bx, by = next(iter(data_loader))
    print(bx.shape)
    print(by.shape, by)
    # Show each image in the batch; channels-first tensors must be
    # transposed to (H, W, C) for matplotlib.
    for i in range(5):
        plt.imshow(bx[i].numpy().transpose([1, 2, 0]))
        plt.show()
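Because the tensors are still normalized when they are plotted, matplotlib clips the out-of-range values and the colors look washed out. Below is a minimal sketch of undoing the Normalize step before display, assuming the mean/std of 0.5 used above; denormalize is a hypothetical helper, not part of hymData.py:

def denormalize(img_tensor, mean=0.5, std=0.5):
    # Invert Normalize: x_norm = (x - mean) / std, so x = x_norm * std + mean.
    return (img_tensor * std + mean).clamp(0, 1)

# e.g. plt.imshow(denormalize(bx[0]).numpy().transpose([1, 2, 0]))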
Transfer Learning
# -*- I Love Python!!! And You? -*-
# @Time : 2022/3/28 20:39
# @Author : sunao
# @Email : 939419697@qq.com
# @File : transferTrainer.py
# @Software: PyCharm
import os

import matplotlib.pyplot as plt
import torch
import torch.utils.data as Data
from torch.optim import lr_scheduler
from torchvision import models

from hymData import hymData
class Trainer(object):
    def __init__(self, lr=0.005, batch_size=32,
                 num_epoch=120, train_data=None,
                 test_data=None, mode="finetune"):
        self.lr = lr
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.train_data_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size,
                                                 shuffle=True)
        self.test_data_loader = Data.DataLoader(dataset=test_data, batch_size=batch_size,
                                                shuffle=False)
        self.mode = mode
        self.model_path = "./model"
        # Build the model.
        self.loss = torch.nn.CrossEntropyLoss()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if mode == "finetune":  # adjust all pretrained weights slightly; the output layer is task-specific
            print("Fine-tuning transfer learning")
            self.model = models.resnet18(pretrained=True)
        elif mode == "fixed":  # pretrained features are generic (e.g. contours, textures), so freeze them
            print("Fixed-feature transfer learning")
            self.model = models.resnet18(pretrained=True)
            for parm in self.model.parameters():
                parm.requires_grad = False
        else:
            print("Training from random initialization")
            self.model = models.resnet18(pretrained=False)
        # Replace the output layer with a 2-class head for ants vs. bees.
        # The new Linear layer is freshly created, so it stays trainable
        # even in "fixed" mode.
        num_fc = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_fc, 2)
        self.model = self.model.to(self.device)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.5, 0.99))
        self.exp_lr_sche = lr_scheduler.StepLR(self.optim, step_size=20, gamma=0.1)  # learning-rate decay
    def train(self):
        # if os.path.exists(self.model_path + "/transfer.pkl"):
        #     self.model.load_state_dict(torch.load(self.model_path + "/transfer.pkl"))
        #     print("Model loaded from", self.model_path)
        best_acc = 0
        acc_list = []
        for epoch in range(self.num_epoch):
            self.model.train()
            epoch_loss = 0
            for i, (bx, by) in enumerate(self.train_data_loader):
                bx = bx.to(self.device)
                by = by.to(self.device)
                pre_logits = self.model(bx)
                # CrossEntropyLoss expects raw logits: it applies
                # log-softmax internally, so no softmax beforehand.
                loss = self.loss(pre_logits, by)
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                epoch_loss += loss.item()
            self.exp_lr_sche.step()
            curr_acc = self.test()
            acc_list.append(curr_acc)
            print("| epoch %d/%d | loss %f | current accuracy %f%%" % (
                epoch + 1, self.num_epoch, epoch_loss, curr_acc
            ))
            if curr_acc > best_acc:
                best_acc = curr_acc
                print("Best accuracy so far:", best_acc)
                if not os.path.exists(self.model_path):
                    os.makedirs(self.model_path)
                torch.save(self.model.state_dict(), self.model_path + "/transfer.pkl")
        return acc_list
    def test(self):
        self.model.eval()
        acc = 0
        with torch.no_grad():
            for i, (bx, by) in enumerate(self.test_data_loader):
                bx = bx.to(self.device)
                by = by.to(self.device)
                pre_logits = self.model(bx)
                # The argmax over the logits is the predicted class;
                # softmax is monotonic, so it is not needed here.
                _, pre_y = torch.max(pre_logits, 1)
                acc += torch.sum(pre_y == by.data)
        acc = acc.double() / len(self.test_data_loader.dataset) * 100
        return acc.item()
if __name__ == '__main__':
    train_data = hymData(128, 128, train=True)
    test_data = hymData(128, 128, train=False)
    num_epoch = 50
    batch_size = 64
    lr = 0.00001
    torch.cuda.empty_cache()  # release cached GPU memory between runs
    trainer = Trainer(lr=lr,
                      batch_size=batch_size,
                      num_epoch=num_epoch, train_data=train_data,
                      test_data=test_data, mode="finetune")
    acc_list_finetune = trainer.train()
    torch.cuda.empty_cache()
    trainer = Trainer(lr=lr,
                      batch_size=batch_size,
                      num_epoch=num_epoch, train_data=train_data,
                      test_data=test_data, mode="fixed")
    acc_list_fixed = trainer.train()
    torch.cuda.empty_cache()
    trainer = Trainer(lr=lr,
                      batch_size=batch_size,
                      num_epoch=num_epoch, train_data=train_data,
                      test_data=test_data, mode="random")  # any other string selects random initialization
    acc_list_random = trainer.train()
    x = range(num_epoch)
    plt.figure()
    plt.plot(x, acc_list_finetune, label="finetune")
    plt.plot(x, acc_list_fixed, label="fixed")
    plt.plot(x, acc_list_random, label="random")
    plt.title("transfer: fixed vs finetune vs random, lr=" + str(lr))
    plt.legend()
    os.makedirs("./saved", exist_ok=True)  # make sure the output directory exists
    plt.savefig("./saved/transfer_acc.jpg")
    plt.show()
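After training, the best weights sit in ./model/transfer.pkl. Below is a minimal inference sketch, assuming the same 2-class ResNet-18 head and the validation preprocessing from hymData; predict_image is a hypothetical helper, not part of transferTrainer.py:

import torch
from PIL import Image
from torchvision import models, transforms

def predict_image(img_path, weights="./model/transfer.pkl"):
    # Rebuild the exact architecture used in Trainer before loading weights.
    model = models.resnet18(pretrained=False)
    model.fc = torch.nn.Linear(model.fc.in_features, 2)
    model.load_state_dict(torch.load(weights, map_location="cpu"))
    model.eval()
    # Deterministic pipeline matching the val split of hymData.
    tran = transforms.Compose([
        transforms.Resize([128, 128]),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    x = tran(Image.open(img_path).convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        pred = model(x).argmax(dim=1).item()
    return ["ants", "bees"][pred]  # label 0 = ants, 1 = bees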
Summary
A smaller learning rate works better for finetuning, since the pretrained weights only need small adjustments.
A larger learning rate works better for the fixed (frozen) setup, where only the new output layer is trained.
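One way to reconcile the two observations is to give the pretrained backbone and the fresh head different learning rates through optimizer parameter groups. A sketch of that idea; the concrete rates and the 10x ratio are illustrative assumptions, not values used in the experiments above:

import torch
from torchvision import models

model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# Separate the freshly initialized head from the pretrained backbone.
head_ids = {id(p) for p in model.fc.parameters()}
backbone_params = [p for p in model.parameters() if id(p) not in head_ids]

# Small lr for pretrained weights, a larger lr for the random head.
optim = torch.optim.Adam([
    {"params": backbone_params, "lr": 1e-5},
    {"params": model.fc.parameters(), "lr": 1e-4},  # 10x larger (illustrative)
])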