# Uses the ResNet-18 architecture.
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
'''
resnet block
'''
def __init__(self,ch_in,ch_out,stride=1):
'''
:param ch_in:
:param ch_out:
'''
super(ResBlk,self).__init__()
self.conv1=nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
self.bn1=nn.BatchNorm2d(ch_out) # ResNet,一般都会加BatchNorm
self.conv2=nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
self.bn2=nn.BatchNorm2d(ch_out)
self.extra=nn.Sequential() # 先设为空的extra,如果有的话,下面的就把这个覆盖掉
if ch_out != ch_in:
# Project the shortcut from [b, ch_in, h, w] to [b, ch_out, h, w]
self.extra =nn.Sequential(
nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride), # 保持size不变