文章目录
问题说明
- 需要在分布式节点间传递Pytorch模型参数,以实现分布式模型训练
- 需要为模型参数提供序列化和反序列化
- 需要支持所有模型,不是特定几种
传输tensor
Pytorch nn.Module模型参数在哪?
state_dict = AlexNet().state_dict()
# 计算参数大小
sum = 0
for key in state_dict.keys():
sum += sys.getsizeof(state_dict[key].storage())
- 模型的参数储存在state_dict()中,返回dict, key是可学习参数的名称,如“conv1.weight”, "fc8.bias"等,value就是对应的tensor, 默认FloatTensor类型.
错误的想法: 转换成list, 使用json dumps转换成字符串
def state_dict_serialization(tensor_dict)