迁移学习-freeze和finetune

自定义数据集

# -*- I Love Python!!! And You? -*-
# @Time    : 2022/3/28 17:08
# @Author  : sunao
# @Email   : 939419697@qq.com
# @File    : hymData.py
# @Software: PyCharm

import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os


class hymData(Dataset):
	"""Hymenoptera (ants vs. bees) image dataset.

	Reads images from ``<path>/train/<ants_file>/`` and ``<path>/train/<bees_file>/``
	(or ``/val/`` when ``train`` is False). Images in the ants folder get label 0,
	images in the bees folder get label 1.
	"""

	def __init__(self,img_w=128,img_h=128,path="./data/hymenoptera_data",ants_file="ants",bees_file="bees",train=True,preprocess=True):
		"""
		Index the image files of one split and build its preprocessing pipeline.

		:param img_w: target image width after resizing / cropping
		:param img_h: target image height after resizing / cropping
		:param path: dataset root directory
		:param ants_file: sub-directory holding the label-0 images
		:param bees_file: sub-directory holding the label-1 images
		:param train: use the "train" split with augmentation if True, else the "val" split
		:param preprocess: if True, __getitem__ returns a normalized tensor;
			otherwise it returns the RGB PIL image
		"""
		super(hymData, self).__init__()
		self.img_w = img_w
		self.img_h = img_h
		self.train = train
		self.preprocess = preprocess
		self.path = path + ("/train/" if train else "/val/")

		self.ants_file_list = self._index_folder(self.path, ants_file, label=0, start=0)
		self.bees_file_list = self._index_folder(self.path, bees_file, label=1, start=len(self.ants_file_list))
		# Merge ants and bees into one index -> [file path, label] mapping; keys are disjoint by construction.
		self.file_list = {**self.ants_file_list, **self.bees_file_list}

		# torchvision's Resize / RandomCrop take the size as (height, width).
		size = (self.img_h, self.img_w)
		if train:
			self.tran_x = transforms.Compose([
				transforms.Resize(size),
				transforms.ToTensor(),
				transforms.RandomRotation(10),
				transforms.RandomCrop(size, padding=4),
				transforms.RandomHorizontalFlip(),
				transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
			])
		else:
			self.tran_x = transforms.Compose([
				transforms.Resize(size),
				transforms.ToTensor(),
				transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
			])

	@staticmethod
	def _index_folder(root, folder, label, start=0):
		"""Map consecutive indices (from ``start``) to ``[file path, label]`` for every
		entry of ``root + folder``, in sorted filename order so indexing is reproducible."""
		names = sorted(os.listdir(root + folder))
		return {start + i: [root + folder + "/" + name, label] for i, name in enumerate(names)}

	def __len__(self):
		"""Number of indexed images (both classes)."""
		return len(self.file_list)

	def __getitem__(self, index):
		"""Return ``(image, label)`` for ``index``; image is a tensor if ``preprocess`` else a PIL image."""
		file_path, label = self.file_list[index]
		# Close the file handle promptly, and force RGB so grayscale/RGBA files
		# still match the 3-channel Normalize.
		with Image.open(file_path) as img:
			x = img.convert("RGB")
		if self.preprocess:
			x = self.tran_x(x)
		return x, label
		
	

if __name__ == '__main__':
	# Smoke test: load one shuffled training batch and display every image in it.
	data = hymData(256,256,train=True,preprocess=True)
	print(len(data))
	data_loader = DataLoader(dataset=data,shuffle=True,batch_size=5)
	bx,by = next(iter(data_loader))
	print(bx.shape)
	print(by.shape,by)
	for img in bx:
		# Undo Normalize(mean=0.5, std=0.5) so pixel values are back in [0, 1] for imshow.
		img = (img * 0.5 + 0.5).clamp(0, 1)
		plt.imshow(img.numpy().transpose([1, 2, 0]))
		plt.show()

迁移学习

# -*- I Love Python!!! And You? -*-
# @Time    : 2022/3/28 20:39
# @Author  : sunao
# @Email   : 939419697@qq.com
# @File    : transferTrainer.py
# @Software: PyCharm

import torch
import torch.utils.data as Data
import numpy as np
import matplotlib.pyplot as plt
import os
from hymData import hymData
from torchvision import models
from torch.optim import lr_scheduler

class Trainer(object):
	"""Train a ResNet-18 two-class classifier under one of three transfer-learning modes.

	Modes:
	  * "finetune": start from pretrained weights and train every layer.
	  * "fixed": start from pretrained weights, freeze the backbone, train only the new fc head.
	  * anything else: start from random weights and train every layer.
	"""

	def __init__(self,lr=0.005,batch_size=32,
	             num_epoch=120,train_data=None,
	             test_data=None,mode="finetune"):
		self.lr = lr
		self.batch_size = batch_size
		self.num_epoch = num_epoch
		self.train_data_loader = Data.DataLoader(dataset=train_data,batch_size=batch_size,
		                                         shuffle=True)
		# Evaluation order does not affect accuracy, so the test set is not shuffled.
		self.test_data_loader = Data.DataLoader(dataset=test_data,batch_size=batch_size,
		                                        shuffle=False)

		self.mode = mode
		self.model_path = "./model"
		self.loss = torch.nn.CrossEntropyLoss()
		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

		self.model = self._build_model(mode)
		# Replace the original classification head with a 2-class head. A freshly
		# created Linear layer has requires_grad=True, so it trains even in "fixed" mode.
		num_fc = self.model.fc.in_features
		self.model.fc = torch.nn.Linear(num_fc,2)
		self.model = self.model.to(self.device)
		# Only hand trainable parameters to the optimizer (matters for "fixed" mode).
		trainable = [p for p in self.model.parameters() if p.requires_grad]
		self.optim = torch.optim.Adam(trainable,lr=lr,betas=(0.5,0.99))
		# Learning-rate decay: multiply lr by 0.1 every 20 epochs.
		self.exp_lr_sche = lr_scheduler.StepLR(self.optim,step_size=20,gamma=0.1)

	@staticmethod
	def _build_model(mode):
		"""Create the ResNet-18 backbone for ``mode`` (see class docstring)."""
		# Compare strings with ==; `is` tests object identity, which is not guaranteed for equal strings.
		if mode == "finetune":  # small adjustments to all layers; the output layer is task-specific
			print("微调迁移学习")
			return models.resnet18(pretrained=True)
		if mode == "fixed":  # keep pretrained features (e.g. edges, textures) frozen
			print("固定表征迁移学习")
			model = models.resnet18(pretrained=True)
			for parm in model.parameters():
				parm.requires_grad = False
			return model
		print("随机迁移学习")
		return models.resnet18(pretrained=False)

	def train(self):
		"""Train for ``num_epoch`` epochs, evaluating on the test set after each epoch.

		Whenever test accuracy improves, the weights are saved to ``<model_path>/transfer.pkl``.

		:return: list of per-epoch test accuracies, in percent
		"""
		best_acc = 0
		acc_list = []

		for epoch in range(self.num_epoch):
			self.model.train()
			epoch_loss = 0
			for bx,by in self.train_data_loader:
				bx = bx.to(self.device)
				by = by.to(self.device)

				# CrossEntropyLoss applies log-softmax itself, so it must receive raw logits.
				logits = self.model(bx)
				loss = self.loss(logits,by)

				self.optim.zero_grad()
				loss.backward()
				self.optim.step()

				epoch_loss += loss.item()

			self.exp_lr_sche.step()
			curr_acc = self.test()
			acc_list.append(curr_acc)
			print("| epoch %d/%d | loss %f | current accuracy %f%%"%(
				epoch,self.num_epoch,epoch_loss,curr_acc
			))

			if curr_acc > best_acc:
				best_acc = curr_acc
				print("最佳正确率",best_acc)
				os.makedirs(self.model_path, exist_ok=True)
				torch.save(self.model.state_dict(),os.path.join(self.model_path,"transfer.pkl"))
		return acc_list

	def test(self):
		"""Return classification accuracy on the test loader, in percent.

		Runs in eval mode (BatchNorm uses running statistics, Dropout is disabled)
		without building autograd graphs, then restores the previous train/eval mode.
		Returns 0.0 for an empty test set.
		"""
		n_total = len(self.test_data_loader.dataset)
		if n_total == 0:
			return 0.0
		was_training = self.model.training
		self.model.eval()
		correct = 0
		with torch.no_grad():
			for bx,by in self.test_data_loader:
				bx = bx.to(self.device)
				by = by.to(self.device)
				pre_y = self.model(bx).argmax(dim=1)
				correct += (pre_y == by).sum().item()
		self.model.train(was_training)
		return correct / n_total * 100
		

if __name__ == '__main__':
	# Compare three transfer-learning strategies with identical data and hyper-parameters.
	train_data = hymData(128,128,train=True)
	test_data = hymData(128,128,train=False)
	num_epoch = 50
	batch_size = 64
	lr=0.00001
	# Trainer mode -> curve label; "" falls through to the randomly initialized model.
	modes = {"finetune": "finetune", "fixed": "fixed", "": "random"}
	acc_lists = {}
	for mode, label in modes.items():
		torch.cuda.empty_cache()  # release cached GPU memory from the previous run
		trainer = Trainer(lr=lr,
		                  batch_size=batch_size,
		                  num_epoch=num_epoch, train_data=train_data,
		                  test_data=test_data, mode=mode)
		acc_lists[label] = trainer.train()
		# Drop the reference so the next empty_cache() can actually free this model's memory.
		del trainer

	x = range(num_epoch)
	plt.figure()
	for label, acc_list in acc_lists.items():
		plt.plot(x, acc_list, label=label)
	plt.title("transfer:fixed vs finetune vs random , lr"+str(lr))
	plt.xticks(x)
	plt.legend()
	os.makedirs("./saved", exist_ok=True)  # savefig does not create missing directories
	plt.savefig("./saved/transfer_acc.jpg")
	plt.show()

总结

本实验中观察到：学习率较小时，finetune（在预训练权重基础上微调全部层）效果更好；
学习率较大时，freeze（冻结预训练主干、只训练新的全连接层，即代码中的 fixed 模式）效果更好。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

语音不识别

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值