Pytorch构建模型的技巧小结(1)

####1、保存模型参数和加载模型参数

(A)保存参数

# 2 ways to save the net
torch.save(net1, 'net.pkl')  # save entire net
torch.save(net1.state_dict(), 'net_params.pkl')   # save only the parameters

(B)加载参数

# copy net1's parameters into net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)

上面出现的net1和net3都是nn.Module的实例。

####2、模型参数的钳位

# Clip weights of discriminator
for p in discriminator.parameters():
      p.data.clamp_(-opt.clip_value, opt.clip_value)

p是Module(nn.Module)—— discriminator的参数。这段代码是实现WGAN时用到的。钳位不仅可以实现WGAN,而且它可以消除在训练中出现的nan情况,但钳位的大小很关键。

####3、模型的CUDA化
在配有CUDA的训练过程中,模型和数据都需要加载到CUDA中,pytorch的张量有两种类型,以Float为例:用于CPU——torch.FloatTensor、用于CUDA——torch.cuda.FloatTensor,以下是完整列表:
n CPU CUDA Desc. 1 torch.FloatTensor torch.cuda.FloatTensor 32-bit floating point 2 torch.DoubleTensor torch.cuda.DoubleTensor 64-bit floating point 3 N/A torch.cuda.HalfTensor 16-bit floating point 4 torch.ByteTensor torch.cuda.ByteTensor 8-bit integer (unsigned) 5 torch.CharTensor torch.cuda.CharTensor 8-bit integer (signed) 6 torch.ShortTensor torch.cuda.ShortTensor 16-bit integer (signed) 7 torch.IntTensor torch.cuda.IntTensor 32-bit integer (signed) 8 torch.LongTensor torch.cuda.LongTensor 64-bit integer (signed) \begin{array}{c|lc|r} n & \text{CPU} & \text{CUDA} & \text{Desc.}\\ \hline 1 & \text{torch.FloatTensor} & \text{torch.cuda.FloatTensor} & \text{32-bit floating point} \\ 2 & \text{torch.DoubleTensor} & \text{torch.cuda.DoubleTensor} & \text{64-bit floating point} \\ 3 & \text{N/A} & \text{torch.cuda.HalfTensor} & \text{16-bit floating point} \\ 4 & \text{torch.ByteTensor} & \text{torch.cuda.ByteTensor} & \text{8-bit integer (unsigned)} \\ 5 & \text{torch.CharTensor} & \text{torch.cuda.CharTensor} & \text{8-bit integer (s

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值