最近在用pytorch训练图文相关性模型,图片特征使用resnet抽取,文案特征使用bert抽取,把这两个特征合并为一个特征送入浅层神经网络中。
数据量:1千万条。
机器配置:单机4块GPU
现状:单个GPU跑一轮需要15个小时左右。
pytorch1.1对单机多GPU支持的很好,直接一条命令解决:
nn.DataParallel(model).cuda()
方案1:
直接有两个模型生成特征,拼接为最终特征,把最终特征传给预训练模型,这个时候发现4个GPU使用情况不一样,GPU 0使用特别高,其他3个GPU使用比较低
方案2:
把resnet 和 bert模型以及评判准则内置到预训练模型中,发现4个GPU使用基本均匀了,一轮4个小时左右,提升很明显。
GPU上训练的模型保存为CPU可执行的模型:
torch.save(model.module.user_model.state_dict(), "model_state_dict")
model.load_state_dict(torch.load("model_state_dict"))