map_location
参数在 torch.load
函数中用于指定加载模型时的设备映射。它的作用是控制模型权重被加载到哪个设备上。
具体作用
- 避免显存不足: 在加载模型时,如果不指定
map_location
,默认会将模型加载到保存时所在的设备(通常是 GPU)。如果目标机器的 GPU 显存不足,则可能会引发内存不足错误。使用map_location="cpu"
可以将模型加载到 CPU 上,避免这种情况。 - 跨设备加载: 如果保存模型时使用的是 GPU,但在加载时需要使用 CPU 或者另一台没有 GPU 的机器,则可以通过
map_location
参数指定新的设备。 - 多设备环境: 在多设备环境中,有时需要将模型从一个 GPU 移动到另一个 GPU,或从 GPU 移动到 CPU,这时可以使用
map_location
参数进行设备映射。