问题:
在训练A3C的过程中, 本来想用cuda进行多进程训练, 即
import torch
import torch.multiprocessing as mp
global_net = Net()
global_net.cuda()
global_net.share_memory
class Worker(mp.Process):
def __init__(net):
super().__init__()
self.net = global_net
def run():
......
但将self.net
的模型参数打印出来就会发现全都是0, 不符合预期要求.
解决方法
将global_net
放入cpu中在运行即可.