import torch
import torch.nn as nn
# 定义一个空操作
class Identity(nn.Module):
def __init__(self):
super().__init__()
def forward(self,x):
return x
class Test(nn.Module):
def __init__(self,
parameter1,
parameter2,
parameter3,
parameter4,
parameter5=0.,
parameter6=True):
super().__init__()
# 参数初始化
self.aaa = parameter1
self.bbb = parameter1 / parameter2
# 调用其它类
self.ccc = Identity()
# 定义层
self.fc1 = nn.Linear(parameter3, parameter4)
self.act = nn.ReLU()
# 可以定义多个函数
# 记录每步tensor的变化,并且明白要干嘛
def forward(self, x):
x = self.fc1(x)
# [1,parameter4]
x = self.act(x)
# [1,parameter4]
return x
def main():
t = torch.randint()
# 创建对象
model = Test(parameter1=1,
parameter2=2,
parameter3=3,
parameter4=4,
parameter5=0.,
parameter6=True)
# 打印模型
print(model)
# 将t带入创建的model中输出
out = model(t)
print(out.shape)
if __name__=="__main__":
main()
模型的代码格式
于 2022-03-13 16:39:36 首次发布