import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    def __init__(self,ch_in,ch_out,stride=1):
super(ResBlk, self).__init__()
self.conv1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
self.conv2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)
        # shortcut: project the input when the channel count or spatial size
        # changes, otherwise pass it through unchanged
        self.extra = nn.Sequential()
        if ch_out != ch_in or stride != 1:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),
                nn.BatchNorm2d(ch_out)
            )
def forward(self,x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
        # shortcut connection: add the (possibly projected) input
        out = self.extra(x) + out
        # activation after the residual addition, as in the original ResNet
        return F.relu(out)
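# Example (sketch): for ResBlk(64,128,stride=2) an input of [2,64,32,32] becomes
# [2,128,16,16] on the conv path, and the 1x1 shortcut conv maps the input to
# the same [2,128,16,16], so the element-wise addition is valid.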
class ResNet18(nn.Module):
def __init__(self):
super(ResNet18, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),
nn.BatchNorm2d(64)
)
        # followed by 4 residual blocks
self.blk1 = ResBlk(64,128,stride=2)
self.blk2 = ResBlk(128,256,stride=2)
self.blk3 = ResBlk(256,512,stride=2)
self.blk4 = ResBlk(512,512,stride=2)
self.outlayer = nn.Linear(512*1*1,10)
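        # adaptive average pooling in forward() shrinks the feature map to
        # [b,512,1,1], so the classifier input size is 512*1*1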
    def forward(self,x):
        # activate the stem output (conv1 above is conv + bn only)
        x = F.relu(self.conv1(x))
x = self.blk1(x)
x = self.blk2(x)
x = self.blk3(x)
x = self.blk4(x)
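        # shape trace for a 32x32 input: stem -> 10x10, then the blocks give
        # 5x5 -> 3x3 -> 2x2 -> 1x1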
# print("after conv",x.shape)
# [b,512,h,w] ==> [b,512,1,1]
x = F.adaptive_avg_pool2d(x,[1,1])
# print("after pooling",x.shape)
x = x.view(x.size(0),-1)
return self.outlayer(x)
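# Note: despite the name, this is a compact ResNet-style network (a stem plus
# 4 residual blocks); the full ResNet-18 stacks 8 two-conv basic blocks.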
if __name__ == '__main__':
blk = ResBlk(64,128,2)
tmp = torch.randn(2,64,32,32)
out = blk(tmp)
    print(out.shape)  # torch.Size([2, 128, 16, 16])
x = torch.randn(2,3,32,32)
model = ResNet18()
out = model(x)
    print(out.shape)  # torch.Size([2, 10])
Training is done with the code below. It imports the ResNet18 defined above, so the file above is assumed to be saved as ResBlk.py.
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
from ResBlk import ResNet18
class Lenet5(nn.Module):
def __init__(self):
super(Lenet5, self).__init__()
        self.conv_unit = nn.Sequential(
            # x: [b,3,32,32] => [b,6,28,28]
            nn.Conv2d(3,6,kernel_size=5,stride=1,padding=0),
            # [b,6,28,28] => [b,6,14,14]
            nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
            # [b,6,14,14] => [b,16,10,10]
            nn.Conv2d(6,16,kernel_size=5,stride=1,padding=0),
            # [b,16,10,10] => [b,16,5,5]
            nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
        )
# flatten
# fc unit
self.fc_unit = nn.Sequential(
nn.Linear(16*5*5,120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
)
        # cross-entropy loss combines the softmax and the loss computation in
        # one function, so it is created in the training script instead:
        # self.criterion = nn.CrossEntropyLoss()
# x:[b,3,32,32]
# tmp = torch.randn(2,3,32,32)
# out = self.conv_unit(tmp)
# print(out.shape)
def forward(self,x):
        # [b,3,32,32] => [b,16,5,5]
        batchsz = x.size(0)
        x = self.conv_unit(x)
        # flatten: [b,16,5,5] => [b,16*5*5]
        x = x.view(batchsz,-1)
# [b, 16 * 5 * 5] => [b,10]
logits = self.fc_unit(x)
return logits
def main():
batchsz = 32
    cifar_train = datasets.CIFAR10('cifar',train=True,transform=transforms.Compose(
[
transforms.Resize((32,32)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
]
),download=True,)
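    # note: the Normalize mean/std used here (and for the test set below) are
    # the standard ImageNet statistics, commonly reused for CIFAR-10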
cifar_train = DataLoader(cifar_train,batch_size=batchsz,shuffle=True)
    cifar_test = datasets.CIFAR10('cifar',train=False,transform=transforms.Compose(
[
transforms.Resize((32,32)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]
),download=True,)
cifar_test = DataLoader(cifar_test,batch_size=batchsz,shuffle=False)
    x,label = next(iter(cifar_train))
print("x",x.shape," label",label.shape)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# net = Lenet5()
net = ResNet18()
net.to(device)
    # loss function: CrossEntropyLoss combines the softmax and the loss computation
    criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(),lr=1e-3)
for epoch in range(1000):
net.train()
        for batchidx,(x,label) in enumerate(cifar_train):
            x,label = x.to(device),label.to(device)
            # logits: [b,10], label: [b]
            logits = net(x)
            loss = criterion(logits,label)
            # zero the accumulated gradients
            optimizer.zero_grad()
            # backpropagate to compute the gradients
            loss.backward()
            # take one optimizer step to update the weights
            optimizer.step()
        # loss.item() converts the 0-dim loss tensor to a Python float
        print(epoch,loss.item())
# test
net.eval()
with torch.no_grad():
total_correct = 0
total_num = 0
for x,label in cifar_test:
x,label = x.to(device),label.to(device)
# [b,10]
logits = net(x)
# [b]
pred = logits.argmax(dim=1)
total_correct+=torch.eq(pred,label).float().sum().item()
total_num+=x.size(0)
print("epoch ",epoch," acc ",total_correct/total_num)
            # saving net.state_dict() is the more common practice
            torch.save(net,'result.pt')
if __name__ == '__main__':
main()
# net = Lenet5()
# tmp = torch.randn(2,3,32,32)
# out = net(tmp)
# # [2,10]
# print(out.shape)
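# Loading the trained model back for inference (a minimal sketch; assumes
# 'result.pt' was written by the torch.save(net,'result.pt') call above and
# that ResNet18 is importable when unpickling; on PyTorch >= 2.6, whole-model
# pickles additionally need weights_only=False):
# net = torch.load('result.pt',weights_only=False)
# net.eval()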