深度学习杂记
关于深度学习的一些杂记
Eloik
人工智能、大数据
展开
-
Pytorch—保存检查点代码
class ModelCheckpoint(object): def __init__(self, filepath: str = 'checkpoint.pth', monitor: str = 'val_loss', mode: str = 'min', save_best_only: bool = False, save_freq: int = 1): """ :param filepath: 文件名或文件夹名,需要保存的位置,原创 2021-09-07 14:24:27 · 629 阅读 · 1 评论 -
Pytorch—提前终止代码
前言以前使用keras的时候有一个很方便的提前终止类,而pytorch每次都要自己写一次,因此我整理了一个简单通用的代码,需要提前终止功能时,只需cv一下,避免了每次重复写的麻烦。代码class EarlyStopping(object): def __init__(self, monitor: str = 'val_loss', mode: str = 'min', patience: int = 1): """ :param monitor: 要监测的指标,原创 2021-09-06 14:39:05 · 3588 阅读 · 1 评论 -
Pytorch—如何保存训练好的模型
Pytorch 保存和加载训练好的模型保存整个模型torch.save(model,'model.pkl')加载模型(不需要再次定义模型架构)model = torch.load('model.pkl')只保存参数torch.save(model.state_dict(),'model_param.pkl')加载参数(必须保持模型架构不变)# 加载参数model_param = torch.load('model_param.pkl')# 为模型设置参数model原创 2021-07-08 23:53:15 · 2481 阅读 · 1 评论 -
Pytorch—如何进行网络参数初始化
Pytorch网络参数初始化的方法常用的参数初始化方法方法(均省略前缀 torch.nn.init.)功能uniform_(tensor, a=0.0, b=1.0)从均匀分布 U(a,b) 中生成值,填充输入的张量normal_(tensor, mean=0.0, std=1.0)从给定均值 mean 和标准差 std 的正态分布中生成值,填充输入的张量constant_(tensor, val)用 val 的值填充输入的张量ones_(tensor)原创 2021-07-08 23:27:32 · 8778 阅读 · 7 评论