[PyTorch][chapter 50][自定义网络 ResNet18]

前言:

        这里结合一个ResNet-18 网络,讲解一下自己定义一个深度学习网络的完整流程。

经过 20 轮训练,模型在测试集上的精度约为 85%。

一   残差块定义

针对图像处理,残差块有两种结构,下面的代码实现的是左边的结构(两个 3x3 卷积组成的基本残差块)。

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 12:00:57 2023

@author: chengxf2
"""

import torch 
from torch import nn
from torch.nn import functional as F

class ResBlk(nn.Module):
    """
    Basic residual block: two 3x3 conv + BatchNorm layers plus a shortcut branch.

        out = relu( shortcut(x) + bn2(conv2(relu(bn1(conv1(x))))) )

    The shortcut is the identity when neither the channel count nor the spatial
    size changes. Otherwise a 1x1 conv + BatchNorm projects x to out_ch channels
    with the same stride, so both branches have identical shapes before the add.
    """

    def __init__(self, in_ch, out_ch, step):
        """
        param in_ch:  number of input channels
        param out_ch: number of output channels
        param step:   stride of the first 3x3 conv (and of the projection shortcut)
        """
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_ch,
                               out_channels=out_ch,
                               kernel_size=3,
                               stride=step,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(in_channels=out_ch,
                               out_channels=out_ch,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # Shortcut branch. An empty Sequential is the identity. A stride > 1 also
        # changes the spatial size, so it needs a projection even when
        # in_ch == out_ch; otherwise the addition in forward() fails on shape.
        self.extra = nn.Sequential()
        if in_ch != out_ch or step != 1:
            # [b, in_ch, h, w] => [b, out_ch, h', w'] (same h', w' as conv1 with padding=1)
            self.extra = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=step),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        """
        param x: [b, in_ch, h, w]
        return:  [b, out_ch, h', w'], spatially downsampled by the stride `step`
        """
        out = F.relu(self.bn1(self.conv1(x)))
        # No ReLU here: the basic block adds the shortcut to the raw BN output and
        # applies the non-linearity only after the addition.
        out = self.bn2(self.conv2(out))
        out = self.extra(x) + out
        return F.relu(out)
    


    
    
        
        

      


二 定义网络

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 14:22:34 2023

@author: chengxf2
"""

import torch 
from torch import nn
from torch.nn import functional as F
from ResBlock import ResBlk


class ResNet18(nn.Module):
    """
    Small ResNet-style classifier: a strided stem convolution followed by four
    ResBlk stages and a linear classifier.

    Shape trace for a [b, 3, 224, 224] input:
        stem (3x3, stride 2, padding 0): [b, 16, 111, 111]
        blk1 (stride 3):                 [b, 32, 37, 37]
        blk2 (stride 3):                 [b, 64, 13, 13]
        blk3 (stride 3):                 [b, 128, 5, 5]
        blk4 (stride 3):                 [b, 256, 2, 2]
    The feature map is then adaptively average-pooled to 2x2. For 224x224
    inputs this is an exact no-op; for other resolutions it keeps the fc input
    at 256*2*2 instead of failing with a size mismatch.
    """

    def __init__(self, num_class):
        """
        param num_class: number of output classes (size of the logits)
        """
        super(ResNet18, self).__init__()

        conv = nn.Conv2d(in_channels=3,
                         out_channels=16,
                         kernel_size=3,
                         stride=2,
                         padding=0)
        bn = nn.BatchNorm2d(16)
        # Stem: [b, 3, h, w] => [b, 16, ~h/2, ~w/2]; its ReLU is applied in forward().
        self.conv1 = nn.Sequential(conv, bn)

        # Four residual stages: each doubles the channels and downsamples by stride 3.
        self.blk1 = ResBlk(16, 32, 3)    # [b, 16, h, w]  => [b, 32, ~h/3, ~w/3]
        self.blk2 = ResBlk(32, 64, 3)    # [b, 32, h, w]  => [b, 64, ~h/3, ~w/3]
        self.blk3 = ResBlk(64, 128, 3)   # [b, 64, h, w]  => [b, 128, ~h/3, ~w/3]
        self.blk4 = ResBlk(128, 256, 3)  # [b, 128, h, w] => [b, 256, ~h/3, ~w/3]

        # Parameter-free, so state_dict keys (and saved checkpoints) are unaffected.
        self.pool = nn.AdaptiveAvgPool2d((2, 2))
        self.fc = nn.Linear(256 * 2 * 2, num_class)

    def forward(self, x):
        """
        param x: [b, 3, h, w] image batch
        return:  logits [b, num_class]
        """
        a = F.relu(self.conv1(x))
        a = self.blk1(a)
        a = self.blk2(a)
        a = self.blk3(a)
        a = self.blk4(a)
        a = self.pool(a)              # [b, 256, 2, 2]
        a = a.view(a.size(0), -1)     # flatten => [b, 256*2*2]
        return self.fc(a)
    

def main():
    """Smoke-test ResBlk and ResNet18 on random inputs and print shapes and parameter count."""
    # A single residual block with stride 2: [2, 64, 224, 224] => [2, 128, 112, 112].
    blk = ResBlk(64, 128, 2)
    # tmp: [batch, channel, height, width]
    tmp = torch.randn(2, 64, 224, 224)
    out = blk(tmp)
    print("\n resBlock: ", out.shape)

    model = ResNet18(5)
    tmp = torch.randn(2, 3, 224, 224)
    out = model(tmp)
    print("resnet-18 ", out.shape)

    # numel() is the number of elements in a tensor (not its size in bytes);
    # summing it over model.parameters() gives the total parameter count.
    sz = sum(p.numel() for p in model.parameters())
    print("\n parameters size ", sz)


if __name__ == "__main__":
    main()
        
        
    
        
        
        

三 Train& Test

   逻辑如下:

   1. 使用训练集数据训练模型;

   2. 每隔若干轮用验证集评估,检查是否过拟合,并保存验证精度最高的模型参数;

   3. 训练结束后加载这份最佳模型参数,在测试集上进行测试。

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:28:13 2023

@author: chengxf2
"""



# Outline only: train / evaluate / save_ckpt / out_of_patience / load_ckpt are
# placeholders; the concrete implementation is in section 4.
best_acc = 0
for epoch in range(epochs):

    train(train_db)

    # Validate periodically rather than after every epoch.
    if epoch % 10 == 0:

        val_acc = evaluate(val_db)

        # Save model parameters only when validation accuracy improves, so the
        # checkpoint holds the best model and later overfitted epochs are ignored.
        if val_acc > best_acc:
            best_acc = val_acc
            save_ckpt()

        # Early-stopping hook (placeholder): presumably stops once validation
        # accuracy has not improved for a while.
        if out_of_patience():
            break

# Reload the best checkpoint before the final test.
load_ckpt()

test_acc = evaluate(test_db)

四 训练,验证,测试部分完整代码

  

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023

@author: chengxf2
"""

import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from ResNet_18 import ResNet18
from PokeDataset import Pokemon

batchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

root = 'pokemon'
resize = 224

csvfile = 'data.csv'
train_db = Pokemon(root, resize, 'train', csvfile)
val_db = Pokemon(root, resize, 'val', csvfile)
test_db = Pokemon(root, resize, 'test', csvfile)

# Only training benefits from shuffling. Evaluation just counts correct
# predictions, so val/test use a fixed order to keep those passes reproducible.
train_loader = DataLoader(train_db, batch_size=batchNum, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchNum, shuffle=False, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchNum, shuffle=False, num_workers=2)
viz = visdom.Visdom()

def evalute(model, loader):
    """
    Return the classification accuracy of `model` over every sample in `loader`.

    The model is switched to eval mode for this pass, so BatchNorm uses its
    running statistics instead of per-batch ones and does not update them.
    Afterwards it is restored to its previous mode, so a caller that keeps
    training is unaffected.

    param model:  classifier whose output has the class scores along dim 1
    param loader: DataLoader yielding (x, y) batches; y is compared against the
                  argmax class index
    return: accuracy in [0, 1]; 0.0 if the loader's dataset is empty
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    return correct / total
        
        

def main():
    """
    Train ResNet18 on the training split and evaluate on the validation split
    every 2 epochs. The weights with the best validation accuracy are kept in
    'best.mdl'; they are reloaded at the end to report test accuracy.
    """
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_epoch = 0
    best_acc = 0
    # Initialise the two Visdom windows; points are appended during training.
    viz.line([0], [-1], win='train_loss', opts=dict(title='train loss'))
    viz.line([0], [-1], win='val_loss', opts=dict(title='val_acc'))
    global_step = 0

    for epoch in range(epochs):
        print("\n --main---: ", epoch)
        # Ensure BatchNorm is in training mode for this epoch's updates.
        model.train()
        for x, y in train_loader:
            # x: [b, 3, 224, 224]  y: [b]
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line(Y=[loss.item()], X=[global_step], win='train_loss', update='append')
            global_step += 1

        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)

            # Checkpoint only on improvement, so 'best.mdl' always holds the best
            # validation model seen so far (guards against later overfitting).
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')
            print("\n val_acc ", val_acc)
            viz.line([val_acc], [global_step], win='val_loss', update='append')

    print('\n best acc', best_acc, "best_epoch: ", best_epoch)

    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loaded from ckpt')

    test_acc = evalute(model, test_loader)
    print('\n test acc', test_acc)


if __name__ == "__main__":
    main()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值