import torch
import torch.nn as nn
classmy_module(nn.Module):def__init__(self,*args):super(my_module, self).__init__()for block in args:
self._modules[block]= block
# self.linear = nn.Linear(2,3)# self.relu = nn.ReLU()defforward(self,x):# h1 = self.linear(x)# h2 = self.relu(h1)for block in self._modules.values():
x = block(x)print((self._modules))return x
x = my_module(nn.Linear(5,10),nn.ReLU(),nn.Linear(10,3))
y = x(torch.rand(1,5))print(y)# print({'k':3}.values())
知识点
继承nn.Module ,属性self._modules是一个字典
在定义函数时*args代表可传入参数的个数是随意的。
实参按逗号分开计数,如果一个列表不加*,按一个元素算;如果加*号,就按列表中的元素为多参数。
#加*号defdemo(*args):for arg in args:print(type(arg))
demo(*[1,2,3])
out[1]:<class'int'><class'int'><class'int'>#不加*号defdemo(*args):for arg in args:print(type(arg))
demo([1,2,3])
out[2]<class'list'>