import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary as summary
class LeNet(nn.Module):
    """LeNet-5-style CNN for 3-channel 32x32 inputs (e.g. CIFAR-sized images).

    Architecture: conv(3->6, k5) -> pool -> conv(6->16, k5) -> pool ->
    fc(400->120) -> fc(120->84) -> fc(84->classes).
    """

    def __init__(self, classes):
        """Build the layers.

        Args:
            classes (int): number of output classes (size of the final layer).
        """
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels * 5x5 spatial size after two conv+pool stages on 32x32 input.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        """Run a forward pass.

        Args:
            x (Tensor): input batch of shape (N, 3, 32, 32).

        Returns:
            Tensor: raw class logits of shape (N, classes).
        """
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)  # flatten all but the batch dim
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)  # no softmax: return logits
        return out

    def initialize_weights(self):
        """Re-initialize all submodule parameters in place.

        Conv layers get Xavier-normal weights, BatchNorm gets weight=1/bias=0,
        Linear layers get N(0, 0.1) weights; all biases are zeroed.
        Uses `nn.init` functions (which run under no_grad) instead of the
        deprecated direct `.data` mutation of the original version.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.1)
                # Guard added: nn.Linear may be constructed with bias=False.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
def lenet():
    """Factory helper: build a LeNet configured for binary classification."""
    num_classes = 2
    return LeNet(classes=num_classes)
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = lenet().to(device)
net.initialize_weights()
print(net)
summary(net, input_size=(3, 32, 32))
# Solution: `import torchsummary as summary` binds the module, not the function,
# so `summary(net, ...)` fails with "TypeError: 'module' object is not callable".
# Call `summary.summary(net, ...)` (or import via `from torchsummary import summary`):
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = lenet().to(device)
net.initialize_weights()
print(net)
summary.summary(net, input_size=(3, 32, 32))