Summary of Debugging Experience
1. A standard ResBlock must output a tensor of the same size as its input; if the sizes differ, the shortcut needs its own dimension transform (see the sketch after this list).
2. When the dataset is large, setting num_workers on the DataLoader enables multi-process data loading and can improve training efficiency considerably.
3. Before building a more complex network, sketch out the tensor size at every stage on paper; this makes debugging much easier (a worked calculation follows this list).
4. When using the ReLU activation, lower the learning rate appropriately, otherwise the loss may get stuck and fail to decrease.
5. Compare the accuracy on the training set with the accuracy on the test set to judge whether the model is overfitting.
6. A ResBlock adds the input to the output of the main branch before the activation; putting the addition anywhere else breaks the residual structure (see the sketch after this list).
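For tip 3, the spatial size after a convolution or pooling layer follows out = (in + 2*padding - kernel) // stride + 1. The helper below is written only for illustration (it is not part of the project code) and walks through the network shown later:

def conv_out(size, kernel, padding=0, stride=1):
    # Output size along one spatial dimension for Conv2d / MaxPool2d
    return (size + 2 * padding - kernel) // stride + 1

# 28x28 MNIST input:
# conv1 (5x5, no padding): 28 -> 24, maxpool(2): 24 -> 12  => 20 x 12 x 12
# conv2 (3x3, no padding): 12 -> 10, maxpool(2): 10 -> 5   => 15 x 5 x 5
# flatten: 15 * 5 * 5 = 375, hence the final Linear(375, 10)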
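For tips 1 and 6: when a block changes the number of channels (or the spatial size), the plain out + x addition no longer works, and the shortcut needs its own transform. The sketch below shows one common way to do this, using a 1x1 convolution on the shortcut; it is only an illustration and is not used in the code that follows.

import torch
import torch.nn.functional as F

class ProjectionResBlock(torch.nn.Module):
    # Residual block whose output channel count differs from its input.
    # The 1x1 convolution reshapes x so it can still be added before the activation.
    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(channels_in, channels_out, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(channels_out, channels_out, 3, padding=1)
        self.shortcut = torch.nn.Conv2d(channels_in, channels_out, 1)  # dimension transform

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        # add the (projected) input before the activation, as in tip 6
        return F.relu(out + self.shortcut(x))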
Code
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
# Hyperparameters
batch = 50        # batch size
iteration = 1     # number of epochs
# Data loading
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.15,), (0.30,))])
train_set = datasets.MNIST(r"D:\桌面\ResNet", train=True, download=True, transform=trans)
train_loader = DataLoader(train_set, batch_size=batch, shuffle=True, num_workers=16)
test_set = datasets.MNIST(r"D:\桌面\ResNet", train=False, download=True, transform=trans)
test_loader = DataLoader(test_set, batch_size=batch, num_workers=16)
# Residual block
class ResBlock(torch.nn.Module):
    def __init__(self, channels_in):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(channels_in, 30, 5, padding=2)
        self.conv2 = torch.nn.Conv2d(30, channels_in, 3, padding=1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        # add the input before the activation (residual connection, tip 6)
        return F.relu(out + x)
# Network
class ResNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5)
        self.conv2 = torch.nn.Conv2d(20, 15, 3)
        self.maxpool = torch.nn.MaxPool2d(2)
        self.resblock1 = ResBlock(channels_in=20)
        self.resblock2 = ResBlock(channels_in=15)
        self.full_c = torch.nn.Linear(375, 10)

    def forward(self, x):
        size = x.shape[0]
        x = F.relu(self.maxpool(self.conv1(x)))  # 1x28x28 -> 20x24x24 -> 20x12x12
        x = self.resblock1(x)                    # size unchanged
        x = F.relu(self.maxpool(self.conv2(x)))  # 20x12x12 -> 15x10x10 -> 15x5x5
        x = self.resblock2(x)                    # size unchanged
        x = x.view(size, -1)                     # flatten: 15*5*5 = 375
        x = self.full_c(x)
        return x
# Loss function, optimizer, learning-rate decay
model = ResNet()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
# Training function
def train():
    for epoch in range(iteration):
        for batch_index, data in enumerate(train_loader, 0):
            train_data, train_labels = data
            optimizer.zero_grad()
            pred_data = model(train_data)
            loss = criterion(pred_data, train_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()  # decay the learning rate after every batch
            if batch_index % 50 == 0:
                print("epoch:", epoch, "batch_index:", batch_index, "loss:", loss.item())
# Test function
def test():
    with torch.no_grad():
        correct = 0.0
        total = 0.0
        for batch_index, data in enumerate(test_loader, 0):
            test_data, test_labels = data
            pred_data = model(test_data)
            _, pred_labels = torch.max(pred_data, dim=1)
            total += test_labels.shape[0]
            correct += (pred_labels == test_labels).sum().item()
            if batch_index % 20 == 0:
                print("test progress:", 100.0 * batch_index / len(test_loader), "%")
        print("accuracy:", correct * 100.0 / total, "%")
# Main: the guard is required because the DataLoaders use num_workers > 0
if __name__ == '__main__':
    train()
    test()