ptan实战5 || TargetNet类

ptan实战5 || TargetNet类

TargetNet类的成员包含.model.target_model,前者是NN的一个引用,共享空间,后者是将NN复制过来,是一个独立的网络

方法:

sync()用于将NN的参数同步给target_net

alpha_sync()源NN的参数以alpha的权重赋值给target_model,w=aw + (1-a)s ,w是target_model的权重,s是源NN的权重

import ptan
import torch.nn as nn

class DQNNet(nn.Module):
    def __init__(self):
        super(DQNNet, self).__init__()
        self.ff = nn.Linear(5, 3)

    def forward(self, x):
        return self.ff(x)

if __name__ == "__main__":
    net = DQNNet()
    print(net)
    tgt_net = ptan.agent.TargetNet(net)     # 初始化ptan.agent.TargetNet类,需要传入模型NN
    print("Main net:", net.ff.weight)
    print("Target net:", tgt_net.target_model.ff.weight)
    net.ff.weight.data += 1.0
    print("After update")
    print("Main net:", net.ff.weight)
    print("Target net:", tgt_net.target_model.ff.weight)
    tgt_net.model.ff.weight.data += 1
    print("After update2")
    print("Main net:", net.ff.weight)
    print("Target net:", tgt_net.target_model.ff.weight)
    tgt_net.sync()                          # 将net的权重复制给target_model,model是net的一个引用,共享地址
    print("After sync")
    print("Main net:", net.ff.weight)
    print("Target net:", tgt_net.target_model.ff.weight)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值