1、自定义数据集实战
step1: load data
继承 torch.utils.data.Dataset
__len__:返回数据集的样本数量
__getitem__:按索引返回一个样本
数据预处理
- image resize
- data augmentation:rotate, crop
- normalize :mean,std
- totensor
将名称存入字典
加载每张图片的地址
load_csv
将images 和label分别存入数据集和标签集
step2:build model
step 3:train and test
step 4:transfer learning: trainer的初始化
train_scratch 代码:
import torch
from torch import optim, nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
from resnet18 import ResNet18
# Training hyper-parameters.
batchsize = 64
lr = 1e-3
epochs = 10

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seed so repeated runs start from the same random state.
torch.manual_seed(1234)

# One dataset object per split of the "pokemon" data.
# NOTE(review): the second argument (64) is presumably the image resize
# target -- confirm against pokemon.Pokemon.
train_db = Pokemon("pokemon", 64, mode="train")
val_db = Pokemon("pokemon", 64, mode="validation")
test_db = Pokemon("pokemon", 64, mode="test")

# Only the training loader shuffles. Accuracy does not depend on sample
# order, so the evaluation loaders iterate in a fixed order.
train_loader = DataLoader(train_db, batch_size=batchsize, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsize, shuffle=False, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsize, shuffle=False, num_workers=2)

# Visdom client used for the live loss / accuracy plots.
viz = visdom.Visdom()
def evaluate(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` gets right.

    The model is switched to eval mode, every batch is moved to the
    module-level ``device``, and the predicted class is the argmax of the
    logits along dimension 1.

    Args:
        model: Network called as ``model(x)`` to produce logits.
        loader: Iterable of ``(x, y)`` batches exposing a ``dataset``
            attribute whose length is the total sample count.

    Returns:
        Number of correct predictions divided by ``len(loader.dataset)``.
    """
    model.eval()
    n_samples = len(loader.dataset)
    n_hits = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        # Inference only: no gradient tracking is needed.
        with torch.no_grad():
            predictions = model(images).argmax(dim=1)
        n_hits += torch.eq(predictions, labels).sum().float().item()
    return n_hits / n_samples
def main():
    """Train ResNet18 from scratch, keep the best checkpoint, then test it.

    Each epoch runs one pass over ``train_loader`` with Adam and
    cross-entropy loss, plots the per-step loss to Visdom, and then
    evaluates on ``val_loader``. Whenever validation accuracy improves, the
    weights are saved to ``best.mdl``. After training, the best weights
    saved during this run are restored before measuring accuracy on
    ``test_loader``, so the reported test accuracy belongs to the best
    model rather than to whatever the last epoch produced.
    """
    # NOTE(review): 5 is presumably the number of output classes -- confirm
    # against resnet18.ResNet18.
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss().to(device)

    best_acc, best_epoch = 0, 0
    # Tracks whether this run wrote best.mdl, so a stale file left over from
    # an earlier run is never loaded by mistake.
    checkpoint_saved = False
    global_step = 0

    # Seed both plot windows with a dummy point so later appends have a target.
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        # evaluate() leaves the model in eval mode, so switch back to
        # training mode once at the start of every epoch.
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate after every epoch and checkpoint on improvement.
        val_acc = evaluate(model, val_loader)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), "best.mdl")
            checkpoint_saved = True
        # Plotted after every validation pass so the curve also shows the
        # epochs that did not improve.
        viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    if checkpoint_saved:
        # map_location keeps the load working when the current device
        # differs from the one the checkpoint was written on.
        model.load_state_dict(torch.load('best.mdl', map_location=device))
        print('loaded from ckpt!')

    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main()
ResNet代码:
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
def __init__(self,ch_in,ch_out,stride=1):
super(ResBlk, self).__init__()
self.conv1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
self.bn1=nn.BatchNorm2d(ch_out)
self.conv2=nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
self.bn2=nn.BatchNorm2d(ch_out)
self.extra=nn.Sequential()
self.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
nn.BatchNorm2d(ch_out))
if ch_in!=ch_out:
self.extra=nn.Sequential(nn.