# Registering sub-modules explicitly with `self.add_module(name, module)`
# (explanation to be added)
import torch
from torch import nn
class TestNet(nn.Module):
    """Toy network that runs several parallel ``Linear(2, 2)`` heads.

    Demonstrates ``nn.Module.add_module``: the heads are kept in a plain
    Python list, which ``nn.Module`` does not track on its own. Each head
    is therefore registered explicitly under the name ``layer_<i>``. After
    that it appears in ``repr``, ``named_children()``, ``parameters()`` and
    ``state_dict()`` exactly like an attribute-assigned sub-module such as
    ``layer_2``.

    Args:
        nheads: number of parallel ``Linear(2, 2)`` heads.
    """

    def __init__(self, nheads):
        super().__init__()
        # A plain list is invisible to nn.Module's sub-module tracking;
        # the add_module() calls below are what actually register the heads.
        self.layers = [nn.Linear(2, 2) for _ in range(nheads)]
        for i, layer in enumerate(self.layers):
            self.add_module(f"layer_{i}", layer)
        # NOTE(review): this name collides with head #2 when nheads >= 3.
        # The assignment replaces the module registered as "layer_2", so
        # that head stays in self.layers (and is still used by forward) but
        # is no longer tracked. Its parameters are then skipped by
        # .parameters(), .to() and state_dict().
        self.layer_2 = nn.Linear(5, 2)

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the outputs on dim 1.

        The ``layer_2`` attribute (the ``Linear(5, 2)``) is not used here.

        Args:
            x: input tensor; presumably shaped ``(batch, 2)`` to match the
                heads' ``in_features`` -- confirm with callers.

        Returns:
            For a ``(batch, 2)`` input, a tensor of shape
            ``(batch, 2 * nheads)``.
        """
        return torch.cat([head(x) for head in self.layers], dim=1)
# Build a two-head network and show its registered sub-modules. The result
# is printed explicitly because a bare expression only echoes in a REPL.
mul_Linear = TestNet(2)
print(mul_Linear)
# Output:
# TestNet(
#   (layer_0): Linear(in_features=2, out_features=2, bias=True)
#   (layer_1): Linear(in_features=2, out_features=2, bias=True)
#   (layer_2): Linear(in_features=5, out_features=2, bias=True)
# )