废话不多说，直接上代码（统计一个小型卷积网络的参数量，并与手算结果对比）：
import torch
import torch.nn as nn
class model(nn.Module):
    """Small fully convolutional network used to demonstrate parameter counting.

    Layout (in channels -> out channels):
        conv1: 3x3 Conv2d 3 -> 32 (no bias) + BatchNorm2d(32) + ReLU
        conv2: 3x3 Conv2d 32 -> 16 (no bias) + BatchNorm2d(16) + ReLU
        conv3: 1x1 Conv2d 16 -> 1 (no bias)

    The 3x3 convolutions use padding=1 with the default stride of 1, so the
    spatial size of the input is preserved through the whole network.
    """

    def __init__(self):
        super().__init__()
        # bias=False on conv1/conv2: the BatchNorm that follows subtracts the
        # per-channel mean and adds its own learnable shift, so a conv bias
        # would be redundant.
        # ReLU is built with inplace=True so it overwrites its input instead
        # of allocating a new tensor (saves activation memory).
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        # 1x1 projection to a single channel; no bias, so this layer holds
        # only 16 * 1 * 1 * 1 weights.
        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 1, 1, bias=False),
        )

    def forward(self, x):
        """Run conv1 -> conv2 -> conv3 and return the single-channel map.

        Args:
            x: input tensor with 3 channels (presumably N x 3 x H x W --
                confirm against callers).

        Returns:
            Tensor with 1 channel and the same spatial size as ``x``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        return self.conv3(x)
if __name__ == '__main__':
    m = model()
    # Number of learnable parameters as reported by PyTorch.
    params = sum(p.numel() for p in m.parameters())
    print(params)
    # Manually calculate the same total, layer by layer.
    # BatchNorm contributes 2 learnable values per channel (weight and bias);
    # its running statistics are buffers, not parameters, so they are not
    # counted by m.parameters() and are not counted here either.
    manual = (
        3 * 32 * 3 * 3      # conv1 weight: in 3 * out 32 * 3x3 kernel, no bias
        + 32 * 2            # bn1 weight + bias
        + 32 * 16 * 3 * 3   # conv2 weight: in 32 * out 16 * 3x3 kernel, no bias
        + 16 * 2            # bn2 weight + bias
        + 16 * 1 * 1 * 1    # conv3 weight: in 16 * out 1 * 1x1 kernel, no bias
    )
    print(manual)
结果（第一行为 PyTorch 统计的参数量，第二行为手算结果，两者一致）：
5584
5584