项目场景:
通过一个项目学会 PyTorch
动物二分类
网络搭建和迁移学习
动物二分类
1.数据预处理
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import os
import random
from PIL import Image
class hym_data(Dataset):
    """Ants-vs-bees image dataset read from ``<path>/train`` or ``<path>/val``.

    The split directory must contain ``ants/`` and ``bees/`` sub-folders.
    Images under ``ants/`` get label 1; every other image (``bees/``) gets
    label 0.

    Args:
        img_h: target image height after resizing (and cropping in train mode).
        img_w: target image width after resizing (and cropping in train mode).
        path: dataset root directory holding the ``train`` / ``val`` splits.
        mode: ``'train'`` selects the ``train`` split with augmentation;
            any other value selects the ``val`` split without augmentation.
        preprocess: if True, ``__getitem__`` returns the transformed tensor;
            if False, it returns the RGB ``PIL.Image`` as loaded.
    """

    def __init__(self, img_h=256, img_w=256, path="./data/hyma_data",
                 mode='train', preprocess=True):
        self.mode = mode
        # Honour the requested size (previously hard-coded to 256, silently
        # ignoring the img_h / img_w arguments).
        self.img_h = img_h
        self.img_w = img_w
        self.path = path
        self.preprocess = preprocess
        # Compare strings by value: ``is`` tests object identity and only
        # works by accident of string interning.
        if self.mode == 'train':
            self.path = os.path.join(self.path, 'train')
            self.transform = transforms.Compose([
                transforms.Resize(size=(self.img_h, self.img_w)),
                transforms.RandomRotation(15),
                # padding must be a number (a str is rejected with TypeError
                # when the transform runs); crop to the full (h, w) so
                # non-square target sizes also work.
                transforms.RandomCrop((self.img_h, self.img_w), padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        else:
            self.path = os.path.join(self.path, 'val')
            self.transform = transforms.Compose([
                transforms.Resize(size=(self.img_h, self.img_w)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        self.file_list = self.get_filename_list()
        random.shuffle(self.file_list)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        """Return ``(image, label)`` for the ``item``-th (shuffled) file."""
        img_name = self.file_list[item]
        label = self._label_of(img_name)
        # Close the file handle promptly and force 3 channels: grayscale or
        # RGBA images would otherwise break the 3-channel Normalize.
        with Image.open(os.path.join(self.path, img_name)) as raw:
            img = raw.convert('RGB')
        if self.preprocess:
            img = self.transform(img)
        return img, label

    @staticmethod
    def _label_of(img_name):
        """Return 1 for an ``'ants/...'`` entry and 0 for anything else."""
        # Match the directory prefix, not a substring: a bee image whose
        # file name happens to contain "ants" (e.g. "bees/plants.jpg") must
        # not be labelled as an ant.
        return 1 if img_name.startswith('ants/') else 0

    def get_filename_list(self):
        """List images as ``'ants/<file>'`` / ``'bees/<file>'`` (ants first)."""
        return [category + '/' + name
                for category in ('ants', 'bees')
                for name in os.listdir(os.path.join(self.path, category))]
if __name__ == '__main__':
    # Quick smoke test: build the validation split with preprocessing and
    # show the first (randomly shuffled) sample and its label.
    dataset = hym_data(mode='val', preprocess=True)
    sample_img, sample_label = dataset[0]
    print(sample_img)
    print(sample_label)
2.网络搭建和迁移学习
import torch
from torchvision import models
import os
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as Data
from torch.optim import lr_scheduler
from hym_data import *
class Trainer(object):
    """Train a 2-class ResNet-18 on ``train_data`` and evaluate on ``test_data``.

    Args:
        lr: Adam learning rate.
        batch_size: mini-batch size used by both data loaders.
        num_epoch: number of training epochs.
        train_data: training dataset yielding ``(image, label)`` pairs.
        test_data: evaluation dataset yielding ``(image, label)`` pairs.
        mode: ``"finetune"`` - start from torchvision's pretrained weights and
            train every layer; ``"fixed"`` - start from pretrained weights,
            freeze the backbone and train only the new 2-way ``fc`` head;
            any other value - train from random initialization.
    """

    def __init__(self, lr=0.005, batch_size=32, num_epoch=64,
                 train_data=None, test_data=None, mode="finetune"):
        # Use the caller's learning rate (previously hard-coded to 0.005).
        self.lr = lr
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.mode_path = "./mode1"
        self.data_loader = Data.DataLoader(dataset=train_data,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=0)
        # Evaluation order does not affect accuracy, so don't shuffle.
        self.test_loader = Data.DataLoader(dataset=test_data,
                                           batch_size=self.batch_size,
                                           shuffle=False,
                                           num_workers=0)
        self.loss = torch.nn.CrossEntropyLoss()
        self.device = torch.device("cuda" if torch.cuda.is_available()
                                   else "cpu")
        self.model = self._build_model(mode).to(self.device)
        # Give the optimizer only trainable parameters; in "fixed" mode that
        # is just the new fc head.
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        self.optim = torch.optim.Adam(trainable,
                                      lr=self.lr,
                                      betas=(0.5, 0.99))
        # Multiply lr by 0.1 every 20 scheduler steps (stepped once per
        # epoch in train()).
        self.lr_sche = lr_scheduler.StepLR(self.optim,
                                           step_size=20,
                                           gamma=0.1)

    @staticmethod
    def _build_model(mode):
        """Create a ResNet-18 for ``mode`` with a fresh 2-output ``fc`` layer."""
        # NOTE(review): ``pretrained=`` is deprecated in newer torchvision in
        # favour of ``weights=``; kept for compatibility with older versions.
        if mode == "finetune":
            print("wei tiao xue xi")
            model = models.resnet18(pretrained=True)
        elif mode == "fixed":
            print("gu ding biao xue xi")
            model = models.resnet18(pretrained=True)
            # Freeze the pretrained backbone (the original iterated
            # ``self.mode``, which raised AttributeError).
            for parm in model.parameters():
                parm.requires_grad = False
        else:
            model = models.resnet18(pretrained=False)
        # The replacement head is created after freezing, so it stays
        # trainable in every mode.
        model.fc = torch.nn.Linear(model.fc.in_features, 2)
        return model

    def train(self):
        """Run the training loop and return the per-epoch test accuracies.

        After each epoch the model is evaluated with ``test()``; whenever the
        accuracy improves on the best so far, the weights are saved to
        ``<mode_path>/transfer.pkl``.
        """
        # Start below any achievable accuracy. The original started from
        # ``[]``, and ``float > list`` raises TypeError on the first epoch.
        best_acc = float("-inf")
        acc_list = []
        for epoch in range(self.num_epoch):
            self.model.train()
            epoch_loss = 0.0
            for bx, by in self.data_loader:
                bx = bx.to(self.device)
                by = by.to(self.device)
                pre_y = self.model(bx)
                loss = self.loss(input=pre_y, target=by)
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                epoch_loss += loss.item()
            # StepLR's step_size is meant to count epochs here: step once
            # per epoch, after the optimizer steps.
            self.lr_sche.step()
            curr_acc = self.test()
            acc_list.append(curr_acc)
            print("epoch :", epoch, "sun shi zhi :", epoch_loss,
                  "ce shi zheng que lv :", curr_acc)
            if curr_acc > best_acc:
                best_acc = curr_acc
                print("zheng que lv :", best_acc)
                os.makedirs(self.mode_path, exist_ok=True)
                torch.save(self.model.state_dict(),
                           os.path.join(self.mode_path, "transfer.pkl"))
        return acc_list

    def test(self):
        """Return the classification accuracy over ``test_loader`` as a float.

        The model runs in eval mode, so BatchNorm uses its running statistics
        and does not update them. Autograd is disabled, and the model's
        previous train/eval mode is restored afterwards. Returns 0.0 when the
        loader yields no samples.
        """
        was_training = self.model.training
        self.model.eval()
        correct = 0
        total = 0
        try:
            with torch.no_grad():
                for bx, by in self.test_loader:
                    bx = bx.to(self.device)
                    by = by.to(self.device)
                    _, preds = torch.max(self.model(bx), 1)
                    correct += (preds == by).sum().item()
                    total += by.size(0)
        finally:
            self.model.train(was_training)
        return correct / total if total else 0.0
3.模型训练和绘图
if __name__ == '__main__':
    # Compare three strategies on the ants/bees data: fine-tuning a
    # pretrained ResNet-18, training only a new head on a frozen backbone,
    # and training from scratch. Then plot their per-epoch test accuracy.
    print("1========================")
    train_data = hym_data(img_h=128, img_w=128, mode="train", preprocess=True)
    test_data = hym_data(img_h=128, img_w=128, mode="val", preprocess=True)
    torch.cuda.empty_cache()
    print("2========================")
    lr = 0.005
    batch_size = 64
    num_epoch = 32
    print("3=========================")
    # The mode names must match Trainer exactly: the original "fiexd" typo
    # fell through to the from-scratch branch instead of "fixed".
    acc_lists = {}
    for mode in ("finetune", "fixed", "other"):
        trainer = Trainer(lr=lr, batch_size=batch_size, num_epoch=num_epoch,
                          train_data=train_data, test_data=test_data,
                          mode=mode)
        acc_lists[mode] = trainer.train()
        # Drop the finished model first so empty_cache can release its
        # GPU memory before the next run.
        del trainer
        torch.cuda.empty_cache()
    x = range(num_epoch)
    plt.figure()
    # matplotlib's keyword is ``label`` (``lable`` raises an error). Each
    # strategy gets its own curve; the original plotted "finetune" twice
    # and never plotted "other".
    for mode, acc_list in acc_lists.items():
        plt.plot(x, acc_list, label=mode)
    plt.title("acc======,lr" + str(lr))
    plt.xticks(x)
    plt.legend()
    # savefig does not create missing directories.
    os.makedirs("./saved", exist_ok=True)
    plt.savefig("./saved/transfer_acc.jpg")
    plt.show()