一、nn.Module的使用
在Pytorch的首页查看torch.nn
查看container中的module
在重写__init__( )方法时一定要注意继承父类的__init__方法。
示例:
from torch import nn
class Cow(nn.Module):
# 重写下面两个方法
def __init__(self):
super().__init__()
要么也可以点击pycharm中的code ->generate ->override
结果如下:
class Cow(nn.Module):
def __init__(self) -> None:
super().__init__()
重写完两个方法后,再运行如下代码:
import torch
from torch import nn
class Cow(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
output = input + 1
return output
cow = Cow()
x = torch.tensor(1.0)
output = cow(x)
print(output)
结果如下:
tensor(2.)
再对代码进行调试: