模型保存和加载
只保存模型的参数（parameter）和缓冲区（buffer），即 state_dict
# Save only the model's state_dict (its parameters and registered buffers),
# then restore it into a freshly constructed Model instance.
path = '/kaggle/working/state_dict_model.pt'
torch.save(model.state_dict(), path)
n1_model = Model()
# A state_dict contains only tensors, so the restricted weights-only
# unpickler is sufficient. Passing it explicitly refuses to run arbitrary
# pickled code and avoids the FutureWarning printed by PyTorch 2.4-2.5.
n1_model.load_state_dict(torch.load(path, weights_only=True))
# Switch to inference mode (affects layers such as dropout / batch norm).
n1_model.eval()
# Notebook output of the cell above:
'''
Model(
(fc): Linear(in_features=768, out_features=2, bias=True)
)
'''
保存整个模型（pickle 整个 Module 对象；加载时必须能找到 Model 类的定义）
# Save the entire Module object (pickled class reference plus its state).
# Loading it back requires the Model class to be importable under the same
# name as when it was saved.
path = '/kaggle/working/entire_model.pt'
torch.save(model, path)
# A full Module cannot be restored by the restricted weights-only unpickler,
# which is the default since PyTorch 2.6, so opt out explicitly.
# NOTE: this unpickles arbitrary objects and can execute code -- only load
# files you created yourself or otherwise trust.
n2_model = torch.load(path, weights_only=False)
# Switch to inference mode (affects layers such as dropout / batch norm).
n2_model.eval()
# Notebook output of the cell above:
'''
Model(
(fc): Linear(in_features=768, out_features=2, bias=True)
)
'''
checkpoint 保存和加载（保存训练进度、模型和优化器状态，用于断点续训）
# Bundle everything needed to resume training into one file: the training
# progress (epoch, loss) plus the model and optimizer state_dicts.
epoch = 5
loss = 0.4
# Encode progress in the filename, e.g. '5_0.4_checkpoint.pt'.
path = f'/kaggle/working/{epoch}_{loss}_checkpoint.pt'
torch.save({
    'epoch': epoch,
    'loss': loss,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, path)

# Rebuild fresh objects, then restore their state from the checkpoint.
n3_model = Model()
# The optimizer must wrap the NEW model's parameters. Building it on
# `model.parameters()` would make n3_optimizer.step() keep updating the old
# model while n3_model stays frozen at the checkpoint weights.
n3_optimizer = torch.optim.AdamW(n3_model.parameters(), lr=5e-4)
# The checkpoint holds only plain Python numbers, containers and tensors,
# so the restricted weights-only unpickler is sufficient.
checkpoint = torch.load(path, weights_only=True)
epoch = checkpoint['epoch']
loss = checkpoint['loss']
n3_model.load_state_dict(checkpoint['model_state_dict'])
# Restores optimizer internals (e.g. AdamW moment estimates) and the saved
# param_group hyperparameters, overriding the lr passed above.
n3_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Choose the mode for what comes next: call n3_model.eval() instead when the
# checkpoint is only used for inference; train() resumes training.
n3_model.train()
import os

# Recursively print every file under /kaggle/, e.g. to confirm the .pt files
# written above landed in /kaggle/working/.
for root, _subdirs, names in os.walk('/kaggle/'):
    for full_path in (os.path.join(root, name) for name in names):
        print(full_path)
# Notebook output of the cell above:
'''
/kaggle/lib/kaggle/gcp.py
/kaggle/input/chnsenticorp/ChnSentiCorp/dataset_info.json
/kaggle/input/chnsenticorp/ChnSentiCorp/ChnSentiCorp.py
/kaggle/input/chnsenticorp/ChnSentiCorp/chn_senti_corp-train.arrow
/kaggle/input/chnsenticorp/ChnSentiCorp/chn_senti_corp-test.arrow
/kaggle/input/chnsenticorp/ChnSentiCorp/chn_senti_corp-validation.arrow
/kaggle/working/state_dict_model.pt
/kaggle/working/__notebook_source__.ipynb
/kaggle/working/5_0.4_checkpoint.pt
/kaggle/working/entire_model.pt
'''