import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
    """LeNet-style convolutional classifier.

    Two stages of (5x5 conv -> ReLU -> 2x2 max-pool) followed by three fully
    connected layers. ``fc1`` expects a flattened 16x5x5 feature map; a
    3-channel 32x32 input produces exactly that
    (32 -> conv 28 -> pool 14 -> conv 10 -> pool 5).

    Args:
        classes: Number of output classes (output size of ``fc3``).
    """

    def __init__(self, classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)   # 3 -> 6 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        """Run the network and return raw class scores.

        Args:
            x: Batch of 3-channel images, e.g. ``(batch, 3, 32, 32)``; the
                spatial size must reduce to 5x5 after the conv/pool stack so
                that the flattened features match ``fc1``.

        Returns:
            Tensor of shape ``(batch, classes)`` containing unnormalized
            scores (logits). These are NOT probabilities: apply softmax, or
            feed them to ``CrossEntropyLoss``. With ``classes=2`` this is a
            binary classifier.
        """
        out = F.max_pool2d(F.relu(self.conv1(x)), 2)
        out = F.max_pool2d(F.relu(self.conv2(out)), 2)
        # Keep the batch dimension, flatten channels and spatial dims.
        out = out.flatten(1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)

    def initialize_weights(self):
        """Re-initialize the parameters of every submodule in place.

        - ``nn.Conv2d``: Xavier-normal weight, zero bias.
        - ``nn.BatchNorm2d``: weight 1, bias 0.
        - ``nn.Linear``: weight ~ Normal(mean=0, std=0.1), zero bias.

        Parameters that a module does not have (``bias=False`` layers,
        ``BatchNorm2d(affine=False)``) are skipped instead of raising.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False batch norms have weight/bias set to None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.1)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)