这里举例简单的卷积网络,代码如下:
import torch.nn as nn
class myCNN(nn.Module):
def __init__(self,numclass =62,charnum =4):
super(myCNN, self).__init__()#继承
self.numclass = numclass
self.charnum = charnum
self.conv = nn.Sequential(
#这里搭建一层为例,input=3*120*40,output =16*60*20
nn.Conv2d(3,16,2,padding=(1,1)),#卷积,输入通道3,输出通道16,卷积核2*2,边界1填充
nn.MaxPool2d(2,2),#池化
nn.BatchNorm2d(16),#归一化,参数为上一步输出通道数
nn.ReLU()#激活函数
)
self.fc =nn.Linear(16*60*20,self.numclass*self.charnum)
def forward(self, x):
x = self.conv(x)
x = x.view(-1,16*60*20)
x = self.fc(x)
return x