PL的流程
很简单,生产流水线,有一个固定的顺序:
初始化 def init(self) -->
训练 def training_step(self, batch, batch_idx) --> training_step_end(self,batch_parts) --> training_epoch_end(self, training_outputs)
校验 def validation_step(self, batch, batch_idx) --> …
测试 def test_step(self, batch, batch_idx) --> …
class MyModule(pl.LightningModule):
def __init__(self):
self.loss = ...
def forward(self, x, y):
# write my model layers...
# 可以加上dropout,残差
...
return out
def configure_optimizers(self):
return torch.optim.Adam(self.parameters()) #我用Adam, 可以替换其他优化器
def training_step(self, batch, batch_idx):
x, y = batch;
y_hat = self(x); #forward
loss = self.loss(y_hat, y) #计算loss