def _build_net():
    """Return a freshly initialised copy of the MLP stored in the checkpoints.

    Eight Linear layers (371->512, six 512->512, 512->371) with LeakyReLU(0.01)
    between consecutive Linear layers. The resulting module indices
    (0, 2, ..., 14) must match the state-dict keys of the checkpoints loaded
    by ``adjust_net``.
    """
    widths = [371] + [512] * 7 + [371]
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(torch.nn.Linear(fan_in, fan_out))
        layers.append(torch.nn.LeakyReLU(0.01))
    layers.pop()  # no activation after the final Linear layer
    return torch.nn.Sequential(*layers)


def adjust_net(
    order,
    src_dir='ugb-embedded-230504/ugb_ckpt/model_bidding_sl',
    dst_dir='ugb-embedded-230504/ugb_ckpt/model_bidding_rl',
):
    """Rescale the input columns of one checkpoint's first Linear layer and re-save it.

    Loads ``{src_dir}/{order:02d}.ckpt``, divides the first Linear layer's
    weight columns ``0..318`` by ``max(order, 1)`` and columns ``319..370`` by
    ``8 / 13``, then writes the state dict to ``{dst_dir}/{order:02d}.ckpt``
    (creating ``dst_dir`` if it does not exist). Biases and all later layers
    are left unchanged.

    Since ``W @ x == (W / c) @ (c * x)`` column by column, the saved network
    reproduces the original outputs when the matching input features are
    multiplied by the same factors.

    Args:
        order: checkpoint index (formatted as two digits in the file name);
            also the divisor for the first 319 input columns. Values below 1
            are replaced by 1, so 0 does not divide by zero.
        src_dir: directory containing the source checkpoints.
        dst_dir: directory the rescaled checkpoint is written to.
    """
    from pathlib import Path

    split = 319            # first input column of the second feature group
    left = max(order, 1)   # guard against dividing by zero for order 0
    right = 8 / 13

    name = f'{order:02d}.ckpt'
    net = _build_net()
    net.load_state_dict(torch.load(Path(src_dir) / name, map_location='cpu'))

    # Only the first Linear layer's weight has the (512, 371) shape the
    # rescaling is defined for; 1-D biases and later layers are not touched.
    first = net[0]
    print(order, first)
    # no_grad restores the caller's grad mode on exit (even on error), unlike
    # flipping torch.set_grad_enabled globally.
    with torch.no_grad():
        first.weight[:, :split] /= left
        first.weight[:, split:] /= right

    out_dir = Path(dst_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save(net.state_dict(), out_dir / name)
if __name__ == "__main__":
    # Convert checkpoints 00..14. The guard keeps importing this module free
    # of side effects (loading and writing 15 checkpoint files).
    for order in range(15):
        adjust_net(order)
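# NOTE: the lines below are web-page residue that was captured together with
# this snippet; they are not code. Kept verbatim (with English glosses) as
# comments so the module still parses.
# 04-20
# 1700
# 02-27
# 505
# 12-03
# 2490
# 03-04
# “相关推荐”对你有帮助么?  (Was the "related recommendations" section helpful to you?)
# - 非常没帮助  (very unhelpful)
# - 没帮助  (unhelpful)
# - 一般  (neutral)
# - 有帮助  (helpful)
# - 非常有帮助  (very helpful)
# 提交  (Submit)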