手动指定pytorch模型权重值

def adjust_net(order):
    net = torch.nn.Sequential(
        torch.nn.Linear(371, 512),
        torch.nn.LeakyReLU(0.01),
        torch.nn.Linear(512, 512),
        torch.nn.LeakyReLU(0.01),
        torch.nn.Linear(512, 512),
        torch.nn.LeakyReLU(0.01),
        torch.nn.Linear(512, 512),
        torch.nn.LeakyReLU(0.01),
        torch.nn.Linear(512, 512),
        torch.nn.LeakyReLU(0.01),
        torch.nn.Linear(512, 512),
        torch.nn.LeakyReLU(0.01),
        torch.nn.Linear(512, 512),
        torch.nn.LeakyReLU(0.01),
        torch.nn.Linear(512, 371),
    )

    net.load_state_dict(
        torch.load(f'ugb-embedded-230504/ugb_ckpt/model_bidding_sl/{order:02d}.ckpt', map_location='cpu')
    )

    left = order if order > 0 else 1
    right = 8 / 13

    torch.set_grad_enabled(False)

    for layer in net.modules():
        if isinstance(layer, torch.nn.Linear):
            print(order, layer)
            for weight in layer.parameters():
                for j in range(512):
                    for i in range(371):
                        if i < 319:
                            weight[j, i] /= left
                        else:
                            weight[j, i] /= right
                torch.set_grad_enabled(True)
                torch.save(
                    net.state_dict(), f'ugb-embedded-230504/ugb_ckpt/model_bidding_rl/{order:02d}.ckpt'
                )
                return

for order in range(15):
    adjust_net(order)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值