PyTorch常用代码段

这篇博客总结了PyTorch中的实用代码段,包括导入包与版本查询、确保可复现性、管理GPU、将模型从GPU转移到CPU、调整模型结构、加载模型参数、训练与测试模型的流程,以及保存和加载模型的断点。内容涵盖了深度学习模型的基本操作和优化技巧。
摘要由CSDN通过智能技术生成

[深度学习框架]PyTorch常用代码段
文中有很多其他的trips,这里记录一下我需要的

1、导入包和版本查询

导入包和版本查询

  • torch版本
  • CUDA版本
  • cuDNN版本
  • CUDA设备名字

note:

print("Python Version:", sys.version)	# python版本
print("pytorch version:", torch.__version__)	# pytorch版本
print("CUDA Version:", torch.version.cuda)	# cuda版本
print("cuDnn Version:", torch.backends.cudnn.version())	# cudnn版本
print("Devices:", device)	# 使用的设备cuda or cpu

2、可复现性

可复现性

note:

np.random.seed(0)
torch.manual_seed(0)	 # 设置 (CPU) 生成随机数的种子
torch.cuda.manual_seed_all(0)	# 为所有GPU设置种子

# cudnn卷积操作进行了优化,牺牲了精度(小数点后几位)来换取计算效率
# 令benchmark=False就没有优化,保证了精度,但是牺牲了计算效率
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# torch.cuda.manual_seed(0)	# 为当前GPU设置种子(单GPU)

部分解释来自:seed() 和torch中manual_seed的作用

3、显卡设置

显卡设置

  • 单卡
  • 多卡
  • 清除显存

4、将在 GPU 保存的模型加载到 CPU

model.load_state_dict(torch.load('model.pth', map_location='cpu'))

改实例化模型后,如何修改某些层?

num_ftrs=model_ft.head.in_features
model_ft.head=nn.Linear(num_ftrs,7,bias=True)#修改最后的线性层

即实例化一个新的module组件去调换原组件,不能通过model_ft.head.out_features = 7去指定,不起作用。因为该属性是为实例化提供参数,现在实例化完成了,再去修改该属性并不影响已经实例化完成的网络结构。所以必须直接用另一个重新实例化后的结构去替换原模型结构。

children和modules的区别

5、加载其他模型相同的层参数

导入另一个模型的相同部分到新的模型

pretrain_state_dict = torch.load(path)
resnet18_statd_dict = net.state_dict()

# 忽略key不同名的层:strice指明忽略不同名的层
net.load_state_dict(update_dict, strict=False)

# 忽略层同名但不同形状:比如修改了最后一个全接层的数
update_dict = {k: v for k, v in pretrain_state_dict.items() if (k in resnet18_statd_dict.keys() and resnet18_statd_dict[k].shape==v.shape)}
net.load_state_dict(update_dict, strict=False)

6、一个分类模型的训练和测试代码,可以学习下一些细节的写法

分类模型

7、保存和加载断点

保存与加载断点

  • 模型的保存和加载
  • 优化器状态的保存和加载
  • epoch的保存和加载

note:

def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimize_state_dict": optimizer.state_dict(),
}, path)

def load_checkpoint(path):
    checkpoint = torch.load(path)
    return checkpoint
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我是一个对称矩阵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值