使用self.training,这样就可以让forward函数采用两种执行方式,然后就可以做一些骚操作了
import torch.nn as nn
class myNet(nn.Module):
def __init__(self):
super(myNet, self).__init__()
def forward(self):
if self.training:
print('training')
else:
print('not training')
model = myNet()
model.train()
result = model()
model.eval()
result = model()