一、model.train()和model.eval()用法和区别
例如:定义一个网络:
# 线性网络
class Net(nn.Module):
def __init__(self):
super(generator, self).__init__()
self.gen = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(True),
nn.Linear(256, 256),
nn.ReLU(True),
nn.Linear(256, 784),
nn.Tanh()) # Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间。
def forward(self, x):
x = self.gen