def load_network(self, load_path, network, strict=True, map_location=None):
    """Load a saved state dict from ``load_path`` into ``network``.

    Checkpoints written by PyTorch >= 0.4.1 contain ``num_batches_tracked``
    buffers for BatchNorm layers, which PyTorch 0.4.0 modules do not have, so
    a strict load of such a checkpoint on 0.4.0 fails on unexpected keys.
    Those keys are dropped only when the target network does not expect
    them; on newer PyTorch versions the saved counters are still restored.

    Args:
        load_path: Path passed to ``torch.load``. Assumes the file holds a
            plain state dict (not a wrapper dict) -- TODO confirm with callers.
        network: Module to load into; an ``nn.DataParallel`` wrapper is
            unwrapped so keys match the inner module.
        strict: Forwarded to ``load_state_dict``.
        map_location: Forwarded to ``torch.load``. ``None`` (the default)
            keeps the original behaviour; e.g. ``'cpu'`` allows loading a GPU
            checkpoint on a CPU-only machine.

    Raises:
        RuntimeError: From ``load_state_dict`` when ``strict`` is true and
            the remaining keys do not match the network.
    """
    # Local import keeps this block self-contained; a plain dict cannot carry
    # the ``_metadata`` attribute set below.
    from collections import OrderedDict

    if isinstance(network, nn.DataParallel):
        network = network.module
    state_dict = torch.load(load_path, map_location=map_location)
    expected_keys = set(network.state_dict().keys())
    filtered = OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        # Drop the counter only when this PyTorch's modules lack it.
        if not (key.endswith('num_batches_tracked') and key not in expected_keys)
    )
    # Rebuilding the dict loses the per-module version info PyTorch attaches
    # to saved state dicts; keep it so version-aware loaders see it.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        filtered._metadata = metadata
    network.load_state_dict(filtered, strict=strict)
PyTorch 0.4 版本加载 0.4.1、1.0 及更高版本保存的模型（model）
最新推荐文章于 2023-04-07 14:34:02 发布