# 1. Method 1: define each layer directly as a module attribute and write forward() by hand
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class LeNet(nn.Module):
    """Classic LeNet-5 variant for 3-channel 32x32 inputs.

    Args:
        classes: number of output categories produced by the final layer.
    """

    def __init__(self, classes):
        super(LeNet, self).__init__()
        # Spatial size after conv1: (32 + 2*0 - 5)/1 + 1 = 28
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        # (N, 3, 32, 32) -> conv1+relu -> (N, 6, 28, 28) -> pool -> (N, 6, 14, 14)
        feat = F.max_pool2d(F.relu(self.conv1(x)), 2)
        # -> conv2+relu -> (N, 16, 10, 10) -> pool -> (N, 16, 5, 5)
        feat = F.max_pool2d(F.relu(self.conv2(feat)), 2)
        # Flatten to (N, 400) for the fully-connected head.
        feat = feat.view(feat.size(0), -1)
        feat = F.relu(self.fc1(feat))
        feat = F.relu(self.fc2(feat))
        return self.fc3(feat)

    def initialize_weights(self):
        """Re-initialize every submodule in place.

        Conv layers get Xavier-normal weights with zero bias, BatchNorm gets
        weight=1 / bias=0, and Linear layers get N(0, 0.1) weights with zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 0.1)
                m.bias.data.zero_()
# Quick smoke test: run one random 4-image batch through the untrained net.
net = LeNet(classes=3)
fake_img = torch.randn(4, 3, 32, 32, dtype=torch.float32)
output = net(fake_img)
print('over!')
# 2. Method 2: nn.Sequential
# Sequential runs its layers strictly in registration order; it is commonly used
# to build blocks. Its forward pass comes for free — no per-stage forward() needed.
# ============================ Sequential
class LeNetSequential(nn.Module):
    """LeNet assembled from two nn.Sequential stages.

    The conv feature extractor and the fully-connected classifier are each a
    Sequential, so forward() only has to chain the two stages and flatten in
    between.

    Args:
        classes: number of output categories.
    """

    def __init__(self, classes):
        super(LeNetSequential, self).__init__()
        # Feature extractor: (N, 3, 32, 32) -> (N, 16, 5, 5)
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Classifier head over the flattened 16*5*5 feature map.
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes),
        )

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size()[0], -1)
        return self.classifier(flat)
# 3. Method 3: nn.ModuleList
# Iterative style for building many repeated layers with a for-loop. ModuleList
# takes a plain list of modules (registering their parameters), but has no
# forward of its own — the layers must be applied manually inside forward().
# ============================ ModuleList
class myModuleList(nn.Module):
    """Stack of 20 identical 10->10 linear layers held in an nn.ModuleList.

    nn.ModuleList registers each child layer's parameters with the network,
    but provides no forward pass of its own, so forward() applies the layers
    explicitly in order.
    """

    def __init__(self):
        super(myModuleList, self).__init__()
        # Wrapping the plain Python list in nn.ModuleList is what makes the
        # 20 layers' parameters visible to optimizers / .parameters().
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(20)])

    def forward(self, x):
        # Fix: the original used enumerate() but never used the index —
        # iterate the layers directly.
        for linear in self.linears:
            x = linear(x)
        return x
# Smoke test: push an all-ones 10x10 batch through the 20-layer stack.
net = myModuleList()
fake_data = torch.ones((10, 10))
output = net(fake_data)
print(output)