# 网络:加入了残差连接 (Network definition, with residual connections added)
from torch import nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Two-convolution residual block that preserves the input shape.

    Computes ``relu(x + conv2(relu(conv1(x))))``. Both convolutions use a
    3x3 kernel with stride 1 and padding 1 and keep the channel count, so
    the output has the same shape as the input and the skip connection
    can be a plain element-wise addition.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        # kernel_size=3, stride=1, padding=1 -> spatial size is unchanged.
        self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)

    def forward(self, x):
        """Apply the residual mapping to ``x`` and return the activated sum."""
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        # The skip connection is added before the final activation.
        return F.relu(x + y)
class RNet(nn.Module):
    """Small convolutional classifier with two residual blocks.

    The network is three conv/ReLU/max-pool stages (with a ResidualBlock
    after the first two), followed by a two-layer linear head. Each stage
    halves the spatial size, and the head expects a 64 x 4 x 4 feature map,
    so inputs must be 32 x 32 (e.g. CIFAR-10 images).

    Args:
        num_classes: Size of the output logit vector. Defaults to 10.
        in_channels: Number of channels of the input images. Defaults to 3.
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3):
        super().__init__()
        # NOTE: the order and number of layers is kept exactly as before so
        # that nn.Sequential indices (and therefore state_dict keys) do not
        # change.
        self.model = nn.Sequential(
            # Stage 1: for a 32x32 input -> 32 channels, 16x16.
            nn.Conv2d(in_channels, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            ResidualBlock(32),
            # Stage 2: 32 channels, 8x8.
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            ResidualBlock(32),
            nn.Dropout(0.25),
            # Stage 3: 64 channels, 4x4.
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.5),
            # Classifier head. There is no non-linearity between the two
            # linear layers; this is unchanged from the original design.
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return the class logits produced by the sequential model for ``x``."""
        return self.model(x)
# 训练 (Training script)
import torchvision
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from ResNetMo import *
# ---- Data -------------------------------------------------------------------
# CIFAR-10 train / test splits, stored under ./data (downloaded if missing)
# and converted to tensors by ToTensor().
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
# Reshuffle the training set every epoch so SGD does not see the batches in
# the same fixed order each time; the test set order does not matter.
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32)

# ---- Model, loss, optimizer -------------------------------------------------
rnet = RNet()
# Use torch.nn explicitly rather than relying on `nn` leaking in through the
# star import of the model module.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(rnet.parameters(), lr=0.01)

# ---- Bookkeeping ------------------------------------------------------------
total_train_step = 0   # number of optimizer steps taken so far
total_test_step = 0    # number of completed evaluation passes
epoch = 200            # number of passes over the training set
min_loss = float("inf")  # best (lowest) summed test loss seen so far

for i in range(epoch):
    print(f'--------第{i+1}轮训练开始--------')

    # Training phase: enable Dropout. This must be re-enabled every epoch
    # because the evaluation phase below switches the model to eval mode.
    rnet.train()
    for imgs, targets in train_dataloader:
        outputs = rnet(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print(f"训练次数:{total_train_step},Loss:{loss.item()}")

    # Evaluation phase: disable Dropout so the test loss is deterministic and
    # reflects the model that would actually be used for inference.
    rnet.eval()
    total_test_loss = 0
    with torch.no_grad():
        for imgs, targets in test_dataloader:
            outputs = rnet(imgs)
            loss = loss_fn(outputs, targets)
            # Sum of per-batch mean losses over the whole test set.
            total_test_loss = total_test_loss + loss.item()
    print("整体测试集上的Loss:",total_test_loss)

    # Keep only the checkpoint with the lowest test loss. The whole module is
    # pickled (not just the state_dict), as before, so existing loading code
    # that expects a model object keeps working.
    if total_test_loss < min_loss:
        torch.save(rnet, "resNet.pth")
        min_loss = total_test_loss
    total_test_step += 1

# Best summed test loss reached over all epochs.
print(min_loss)
# NOTE: SummaryWriter is imported but no writer is created in this script, so
# the command below only shows data once logging to ./logs is added.
# tensorboard --logdir logs