1 模型类
1.1 init()
1.2 forward()
2 怎么使用
2.1 模型类初始化model = Model_Class(args...),这里的参数是模型参数,不是具体的数据
2.2 model.train()
model.train():
在使用pytorch构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是启用batch normalization和drop out。
model.eval():
测试过程中会使用model.eval(),这时神经网络会沿用batch normalization的值,并不使用drop out。
torch.no_grad():
而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用,具体行为就是停止gradient计算,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为。
2.3 model(具体数据),参数是按照forward()里面参数来传的,这个时候就会调用model.forward()