刘二大人: ResNet with PyTorch code implementation

ResNet code implementation


Steps

Import libraries

import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

Prepare dataset

batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    # 0.1307 and 0.3081 are the mean and std of the MNIST training pixels
    transforms.Normalize((0.1307,), (0.3081,))
])
 
train_dataset=datasets.MNIST(root="../dataset/mnist",
                             train=True,
                             download=False,
                             transform=transform)
train_loader=DataLoader(dataset=train_dataset,
                        batch_size=batch_size,
                        shuffle=True)
 
test_dataset=datasets.MNIST(root='../dataset/mnist/',
                            train=False,
                            download=False,
                            transform=transform)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)
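
As a quick sanity check (not in the original post; the variable names below are only for illustration), one mini-batch can be pulled from train_loader to confirm the tensor shapes:

# sketch: inspect one mini-batch (illustrative only)
images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([64, 1, 28, 28])
print(labels.shape)  # expected: torch.Size([64])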

Design model using Class

class residual(nn.Module):
    # the residual block must keep the number of output channels equal to the input channels
    def __init__(self, channels):
        super(residual, self).__init__()
        self.channels = channels
        # 3x3 kernels; padding=1 keeps the spatial size unchanged
        # first convolution
        self.conv1 = nn.Conv2d(channels, channels,
                               kernel_size=3, padding=1)
        # second convolution
        self.conv2 = nn.Conv2d(channels, channels,
                               kernel_size=3, padding=1)

    def forward(self, x):
        # activate after the first convolution
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        # add the skip connection first, then activate
        return F.relu(x + y)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=1)
        self.mp = nn.MaxPool2d(2)

        self.residual1 = residual(16)
        self.residual2 = residual(32)
        self.residual3 = residual(64)

        # after the last pooling the feature map is 64 channels of 2x2, i.e. 256 features
        self.fc = nn.Linear(256, 10)

    def forward(self, x):
        batch_size = x.size(0)

        x = self.mp(F.relu(self.conv1(x)))  # 1x28x28 -> 16x24x24 -> 16x12x12
        x = self.residual1(x)
        x = self.mp(F.relu(self.conv2(x)))  # 16x12x12 -> 32x8x8 -> 32x4x4
        x = self.residual2(x)

        x = self.mp(F.relu(self.conv3(x)))  # 32x4x4 -> 64x4x4 -> 64x2x2
        x = self.residual3(x)

        x = x.view(batch_size, -1)  # flatten to (batch_size, 256)
        x = self.fc(x)
        # no softmax here: CrossEntropyLoss expects raw logits
        return x
 
net=Net()
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
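
Before training, it can be worth confirming that the flattened feature size really matches the 256 in_features of self.fc. A minimal sketch, assuming a dummy 1x28x28 input (the dummy tensor is not part of the original code):

# sketch: a throwaway forward pass; it only succeeds if the fc input size (256) is correct
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28).to(device)
    print(net(dummy).shape)  # expected: torch.Size([1, 10])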

Construct loss and optimizer

criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(net.parameters(),lr=0.01)
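
Note that torch.nn.CrossEntropyLoss internally applies LogSoftmax followed by NLLLoss, which is why Net.forward returns raw logits without a final softmax. A minimal sketch with made-up numbers:

# sketch: CrossEntropyLoss takes raw logits and integer class labels (values are illustrative)
logits = torch.tensor([[2.0, 0.5, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
target = torch.tensor([0])
print(criterion(logits, target))  # scalar loss tensor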

Training cycle

def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        # forward
        y_pred = net(inputs)
        loss = criterion(y_pred, targets)
        # backward
        loss.backward()
        # update
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print("[%d,%d] loss:%.3f" % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
 
accuracy=[]
 
def test():
    correct=0
    total=0
 
    with torch.no_grad():
        for data in test_loader:
            images,labels=data
            images,labels = images.to(device), labels.to(device)
            outputs=net(images)
            _,predicted=torch.max(outputs.data,dim=1)
            total+=labels.size(0)
            correct+=(labels==predicted).sum().item()
 
        print("accuracy on test set:%d %% [%d/%d]"%(100*correct/total,correct,total))
        accuracy.append(100 * correct / total)
if __name__ == "__main__":
    for epoch in range(2):
        train(epoch)
        test()
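
matplotlib.pyplot is imported at the top but never used; presumably it is meant for plotting the accuracy list that test() fills. A minimal sketch of that plot, to be run after the training loop above (the axis labels are my own choice):

# sketch: plot test-set accuracy per epoch (assumes `accuracy` has been filled by test())
plt.plot(range(1, len(accuracy) + 1), accuracy)
plt.xlabel("epoch")
plt.ylabel("accuracy on test set (%)")
plt.show()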