# This script fine-tunes a pretrained resnet18 on a small Pokemon dataset
# (transfer learning), using visdom to visualize both the dataset and the
# training process. The code has two parts: 1) loading a custom dataset
# (preprocessing, assigning a label to each class directory); 2) the
# transfer-learning training itself.
import torch
import torch.nn as nn
from torch.nn import functional as F
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
class Pokemon(Dataset):
    """Pokemon image dataset.

    Scans `root` for one sub-directory per class, assigns each class an
    integer label (sorted directory order, so labels are deterministic),
    caches the shuffled (image path, label) index in a csv file, and
    exposes one of three splits:
        train: first 60%    val: 60%-80%    test (anything else): last 20%

    Args:
        root:   dataset root directory, one sub-directory per class.
        resize: target square side length of the returned image tensors.
        mode:   'train', 'val', or anything else (treated as test).
    """

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        # Map class (directory) name -> integer label.
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)
        # Full shuffled index: parallel lists of image paths and labels.
        self.images, self.labels = self.load_csv('images.csv')
        # Split the index. BUG FIX: the original used two independent `if`
        # statements, so for mode == 'train' the second if/else also ran
        # and re-sliced the train set down to its last 20%. `elif` makes
        # the three branches mutually exclusive.
        if mode == 'train':  # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 60%-80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # test: last 20%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        """Return (image paths, labels); build and cache the csv index on first use."""
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.name2label.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            print(len(images), images)
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # e.g. pokemon\bulbasaur\00000000.png
                    # The class name is the image's parent directory.
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img, label])
                print('writen into csv file:', filename)
        # Read the cached index back from disk.
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):
        """Invert the ImageNet normalization so the image can be displayed.

        x_hat = (x - mean) / std  =>  x = x_hat * std + mean
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x: [c,h,w]; mean/std: [3] -> [3,1,1] so broadcasting lines up
        # with the channel dimension.
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        """Load, augment and normalize the image at `idx`; return (img_tensor, label_tensor)."""
        img, label = self.images[idx], self.labels[idx]
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # path string -> PIL image
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            # ImageNet statistics, matching the pretrained resnet18.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label
def main():
    """Smoke-test the Pokemon dataset: show one sample and then whole batches in visdom."""
    import visdom
    import time

    viz = visdom.Visdom()
    db = Pokemon('pokemon', 224, 'train')
    # Pull a single sample straight from the dataset and display it
    # (denormalized back to viewable pixel values).
    x, y = next(iter(db))
    print('sample', x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    # Now stream full batches through a DataLoader and preview each one.
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=0)
    for batch_x, batch_y in loader:
        viz.images(db.denormalize(batch_x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(batch_y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)  # pause so each batch stays visible before the next overwrites it


if __name__ == '__main__':
    main()
# PyTorch (at the time this was written) had no built-in Flatten layer,
# so we hand-roll one here.
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension: (b, *dims) -> (b, prod(dims))."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
# Transfer learning: fine-tune a pretrained resnet18 on the Pokemon dataset.
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from torchvision.models import resnet18
# Hyperparameters.
batchsz = 32
lr = 1e-3
epochs = 10
# Training runs on CPU here; change to torch.device('cuda') if a GPU is available.
device = torch.device('cpu')
# Fixed seed for reproducible shuffling/initialization.
torch.manual_seed(1234)
# Build the three dataset splits (train/val/test) with 224x224 images and
# wrap each in a DataLoader; only the training loader shuffles.
train_db = Pokemon('pokemon',224,'train')
val_db = Pokemon('pokemon',224,'val')
test_db = Pokemon('pokemon',224,'test')
train_loader = DataLoader(train_db,batch_size=batchsz,shuffle=True)
val_loader = DataLoader(val_db,batch_size=batchsz)
test_loader = DataLoader(test_db,batch_size=batchsz)
# Requires a running visdom server (python -m visdom.server).
viz = visdom.Visdom()
def evaluate(model, loader):
    """Return classification accuracy (correct / total, in [0, 1]) of `model` on `loader`.

    FIX: the original evaluated in whatever mode the model was in; called
    mid-training, resnet18's BatchNorm layers would use per-batch statistics
    instead of running averages. Switch to eval mode for the measurement and
    restore the previous mode afterwards so the caller's training loop is
    unaffected.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    if was_training:
        model.train()
    return correct / total
def main():
    """Fine-tune a pretrained resnet18 on the Pokemon splits.

    Trains for `epochs` epochs, logs the loss and validation accuracy to
    visdom, checkpoints the best-on-validation weights to 'best.mdl', and
    finally reports accuracy on the test split with the best checkpoint.
    """
    #model = ResNet18(5).to(device)
    # Keep every pretrained layer except the final fc, then add our own
    # Flatten + 5-way classification head.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # -> [b, 512, 1, 1]
                          Flatten(),                              # -> [b, 512]
                          nn.Linear(512, 5)
                          ).to(device)
    #x = torch.randn(2,3,224,224)
    #print(model(x).shape)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # CrossEntropyLoss takes raw logits and integer class indices
    # (it applies log-softmax internally; no manual one-hot needed).
    criterion = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            # x: [b,3,224,224], y: [b]
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        if epoch % 1 == 0:  # validate every epoch
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                # Checkpoint the best-on-validation weights.
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')
    print('best acc:', best_acc, 'best epoch:', best_epoch)
    # Reload the best checkpoint before measuring test accuracy.
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')
    # BUG FIX: the original called `evalute(...)` (typo, undefined name),
    # so every run crashed with NameError right before reporting test accuracy.
    test_acc = evaluate(model, test_loader)
    print('test_acc:', test_acc)


if __name__ == '__main__':
    main()