ptan实战5 || TargetNet类
TargetNet 类的成员包含 `.model` 和 `.target_model`：前者是传入网络 NN 的一个引用，与 NN 是同一个对象（共享同一份参数）；后者是 NN 的一份拷贝，是一个独立的网络。
方法：
`sync()`：将 NN 当前的参数同步（复制）给 `target_model`。
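`alpha_sync(alpha)`：将 `target_model` 的参数与源 NN 的参数按 `alpha` 加权混合，结果写回 `target_model`：$w \leftarrow \alpha w + (1-\alpha)\, s$，其中 $w$ 是 `target_model` 的权重，$s$ 是源 NN 的权重。也就是说，`alpha` 是 `target_model` 保留自身旧权重的比例，源 NN 的权重只占 $1-\alpha$。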
import ptan
import torch
import torch.nn as nn
class DQNNet(nn.Module):
    """Minimal demo network consisting of a single fully connected layer.

    Args:
        in_features: size of each input sample (default 5, as in the
            original demo).
        out_features: size of each output sample (default 3, as in the
            original demo).
    """

    def __init__(self, in_features: int = 5, out_features: int = 3):
        super().__init__()
        # The only layer; the demo script inspects/mutates ``self.ff.weight``.
        self.ff = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result.

        ``x`` must have ``in_features`` as its last dimension (standard
        ``nn.Linear`` contract).
        """
        return self.ff(x)
def _print_weights(net, tgt_net, header=None):
    """Print the linear-layer weights of the main net and of the target copy.

    Args:
        net: the source network (a ``DQNNet``).
        tgt_net: the ``ptan.agent.TargetNet`` built from ``net``.
        header: optional line printed before the two weight dumps.
    """
    if header is not None:
        print(header)
    print("Main net:", net.ff.weight)
    print("Target net:", tgt_net.target_model.ff.weight)


if __name__ == "__main__":
    net = DQNNet()
    print(net)
    # TargetNet is constructed from the model; ``tgt_net.model`` refers to
    # ``net`` itself, while ``tgt_net.target_model`` is an independent copy.
    tgt_net = ptan.agent.TargetNet(net)
    _print_weights(net, tgt_net)

    # Changing the main net leaves the independent target copy untouched.
    # Parameters are updated in place under no_grad instead of via ``.data``.
    with torch.no_grad():
        net.ff.weight += 1.0
    _print_weights(net, tgt_net, "After update")

    # ``tgt_net.model`` is the same object as ``net``, so this in-place
    # change is visible through ``net`` as well.
    with torch.no_grad():
        tgt_net.model.ff.weight += 1
    _print_weights(net, tgt_net, "After update2")

    # Copy the current weights of ``net`` (via ``tgt_net.model``) into
    # ``target_model``.
    tgt_net.sync()
    _print_weights(net, tgt_net, "After sync")