torchvision 中提供了 ResNet、MobileNet 等常见模型;也可以直接把所需的残差结构、深度可分离卷积结构的代码粘贴过来使用。下面给出一个样例,以最终输出 9 个类别的分类任务为例(模型文件 model.py):
import torch
import torch.nn as nn
from torchvision import models
# torchvision.models 支持哪些模型,可以查看其源码,
# 或者用 dir(models) 列出。
class resnet18(nn.Module):
    """ResNet-18 convolutional backbone (torchvision) followed by an MLP classifier head.

    Args:
        c: number of output classes (default 9).
        input_size: spatial size of the input images, either an int (square
            images) or an (H, W) pair. It is only used to size the first
            Linear layer of the head; the default 128 reproduces the original
            hard-coded ``4*4*512`` input features, so existing checkpoints
            still load.
    """

    def __init__(self, c=9, input_size=128) -> None:
        super().__init__()
        bkbn = models.resnet18()
        # Drop torchvision's final global avgpool and fc layers and keep only the
        # convolutional feature extractor (512 output channels).
        self.bkbn = torch.nn.Sequential(*list(bkbn.children())[:-2])

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        # Each of ResNet-18's five stride-2 stages maps a size s to ceil(s / 2),
        # so the feature map is ceil(H/32) x ceil(W/32) (4 x 4 for 128 x 128).
        feat_h = -(-height // 32)
        feat_w = -(-width // 32)

        # Layer order is kept identical to the original so state_dict keys
        # (cls.1.*, cls.4.*) are unchanged.
        self.cls = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_h * feat_w * 512, 512),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(512, c),
        )

    def forward(self, x):
        """Map an (N, 3, H, W) image batch to (N, c) unnormalised class logits."""
        feature = self.bkbn(x)
        classes = self.cls(feature)
        return classes
if __name__ == '__main__':
    # Smoke test: push one random 1x3x128x128 tensor through the network
    # and print the shape of the resulting logits (expected: [1, 9]).
    model = resnet18()
    sample = torch.rand((1, 3, 128, 128), dtype=torch.float32)
    logits = model(sample)
    print(logits.size())
训练脚本(依赖上面的 model.py 和下面的 data.py):
import torch
import torch.nn as nn
import torch.optim as optim
from data import PData
from model import resnet18
# Training script: trains the 9-class resnet18 from model.py on the PData
# dataset with SGD + cross-entropy and periodically writes checkpoints to
# ./checkpoint/epoch_<N>.pth, where N is the number of completed epochs.
import os

net = resnet18()
pretrain = None  # set to a .pth path to resume / fine-tune from a checkpoint
if pretrain:
    # map_location='cpu' lets a checkpoint saved on any GPU load on this
    # machine; the model is moved to the target device right after.
    net.load_state_dict(torch.load(pretrain, map_location='cpu'))

# The original run used the second GPU; fall back instead of crashing when
# fewer GPUs (or none) are visible.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
net = net.to(device)

lr = 0.001
momentum = 0.9
batch_size = 64
num_workers = 1
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
dataset = PData()
train_loader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, num_workers=num_workers)

epochs = 600
save_every = 100
checkpoint_dir = 'checkpoint'
# torch.save does not create missing directories; without this the first
# save would crash only after 100 epochs of training.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(epochs):
    net.train()
    for step, (images, targets) in enumerate(train_loader):
        images = images.to(device)
        targets = targets.to(device)

        out = net(images)
        loss = criterion(out, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 3 == 0:
            # .item() logs the plain number instead of the full tensor repr.
            print(f'epoch :{epoch} step:{step} loss:{loss.item()}\n')

    # Save after the epoch has finished so that epoch_<N> always means
    # "N epochs completed" and the final epoch is included. (Saving at the
    # start of an epoch never stored the last 100 epochs of training.)
    completed = epoch + 1
    if completed % save_every == 0 or completed == epochs:
        torch.save(net.state_dict(), os.path.join(checkpoint_dir, 'epoch_' + str(completed) + '.pth'))
数据集定义(data.py):
import os
import os.path
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np
import random
class PData(data.Dataset):
    """Province-classification dataset built from a text list of image paths.

    Each non-blank line of ``platepath`` is an image path whose second
    '/'-separated component is a province code (e.g. ``imgs/ah/001.jpg`` ->
    ``'ah'``). The code is mapped to an integer label through ``label_dict``.
    NOTE(review): images are not resized here, so default DataLoader batching
    presumably relies on all images sharing one size (128x128 for the model);
    confirm against the data.
    """

    def __init__(self, platepath='province_train.txt'):
        # Skip blank lines: an empty path would otherwise fail later in
        # __getitem__ with an opaque IndexError.
        with open(platepath, 'r') as f:
            self.img_paths = [line.strip() for line in f if line.strip()]
        self.label_dict = {'ah': 0, 'fj': 1, 'gz': 2, 'hlj': 3, 'hn': 4, 'js': 5, 'nx': 6, 'xa': 7, 'xz': 8}

    def to_norm(self, image):
        """Convert an HWC uint8 image to a CHW float32 array scaled by 1/256."""
        image = image.astype(np.float32)
        # NOTE(review): dividing by 256 maps 255 to ~0.996 rather than 1.0; it
        # is kept so existing checkpoints see the same input scale.
        image /= 256
        return image.transpose(2, 0, 1)

    def _label_of(self, img_path):
        """Return the integer label encoded in ``img_path``, or raise KeyError with context."""
        parts = img_path.split('/')
        province = parts[1] if len(parts) > 1 else None
        if province not in self.label_dict:
            raise KeyError(
                f'cannot infer a province label from {img_path!r}: expected '
                f'<dir>/<province>/... with province in {sorted(self.label_dict)}'
            )
        return self.label_dict[province]

    def __getitem__(self, index):
        """Return ``(image, label)``: a float32 (3, H, W) tensor and a long scalar tensor."""
        img_path = self.img_paths[index]
        label = self._label_of(img_path)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None
            # rather than raising, which would surface as a confusing
            # AttributeError inside to_norm.
            raise FileNotFoundError(f'cv2.imread could not read image {img_path!r}')
        img = torch.from_numpy(self.to_norm(img))
        return img, torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.img_paths)
if __name__ == '__main__':
    # Smoke test: load sample #190 from the default list file and show the
    # image tensor's shape and its label.
    dataset = PData()
    image, label = dataset[190]
    print(image.size())
    print(label)