class ResNetBig(nn.Module):
    """ResNet-style classifier assembled from ``ResCell1``/``ResCell2`` cells.

    The network consists of a stem (``initialize``), a stack of residual
    cells (``body``) and a classification head (``top``).

    Args:
        classes: Number of output units of the final ``nn.Linear`` layer.
        structure: Iterable giving the number of cells in each stage. The
            cells of stage ``i`` are built with a middle ("exp") channel
            argument of ``64 * 2**i`` and an output channel argument of
            ``256 * 2**i``.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Note:
        The head is hard-wired to 2048 input features, which equals the
        output channel count of the fourth stage (``256 * 2**3``).
        NOTE(review): this presumably assumes a four-stage ``structure`` and
        a 7x7 feature map before the final pooling -- confirm with callers.
    """

    def __init__(self, classes, structure, **kwargs):
        super().__init__()

        # Stem. NOTE(review): the ConvBnRelu arguments are presumably
        # (in_channels, out_channels, kernel, stride, padding) -- confirm
        # against its definition.
        self.initialize = nn.Sequential(
            ConvBnRelu(3, 64, 7, 2, 3),
            nn.MaxPool2d(3, 2, padding=1),
        )

        # Channel widths are plain Python ints (not NumPy scalars) so the
        # layer constructors receive native integers.
        cells = []
        for i, layer_num in enumerate(structure):
            exp_channel = 2 ** (i + 6)
            out_channel = 2 ** (i + 8)
            for j in range(layer_num):
                if j > 0:
                    # Remaining cells of a stage keep the stage's width.
                    in_channel = out_channel
                    cell_type = ResCell2
                elif i == 0:
                    # First cell of the network consumes the 64-channel stem.
                    in_channel = 64
                    cell_type = ResCell2
                else:
                    # First cell of a later stage consumes the previous
                    # stage's output width (256 * 2**(i - 1)).
                    # NOTE(review): ResCell1 is presumably the transition
                    # variant of the cell -- confirm against its definition.
                    in_channel = 2 ** (i + 7)
                    cell_type = ResCell1

                # Diagnostic output is only emitted for stages after the first.
                if i > 0:
                    print(in_channel, exp_channel, out_channel)
                    print('___________________________________')

                cells.append(cell_type(in_channel, exp_channel, out_channel))
        self.body = nn.Sequential(*cells)

        # Head. Softmax is applied over dim=1 of the (batch, classes) tensor
        # produced by Flatten + Linear; stating the dim explicitly avoids the
        # deprecated implicit-dimension behaviour and its runtime warning.
        # NOTE: the model therefore returns probabilities, not raw logits.
        self.top = nn.Sequential(
            nn.MaxPool2d(7),
            nn.Flatten(),
            nn.Linear(2048, classes),
            nn.Softmax(dim=1),
        )

    def forward(self, inputs):
        """Run the stem, the residual body and the head on ``inputs``.

        Returns:
            The output of ``self.top``, i.e. the softmax over the ``classes``
            outputs of the final linear layer.
        """
        features = self.initialize(inputs)
        features = self.body(features)
        return self.top(features)
通过 PyTorch 创建 ResNet
最新推荐文章于 2024-10-08 17:00:04 发布