读入模型参数——且不占用显存
see_memory_usage('message')
# 4. Load the checkpoint parameters.
# map_location='cpu' keeps the loaded tensors in host RAM instead of GPU memory.
checkpoint_path = (
    '../train-output/'
    + args.model_name_or_path.split('/')[-1]
    + '/unet/diffusion_pytorch_model.bin'
)
state_dict = torch.load(checkpoint_path, map_location='cpu')
如果不加 map_location='cpu'，读入的参数就会占用 GPU 显存
see_memory_usage('message')
可以通过添加下面函数来检测GPU占用情况
import gc
import psutil
def see_memory_usage(message):
    """Print GPU and host memory usage, prefixed with *message*.

    Reports current/peak CUDA allocated memory (MA / Max_MA) and current/peak
    CUDA reserved (cached) memory (CA / Max_CA) in GB, followed by the host's
    virtual-memory usage via psutil. Peak CUDA counters are reset afterwards
    so the next call reports the peak since this one.

    Parameters
    ----------
    message : str
        Label printed before the memory report.

    Returns
    -------
    None
    """
    # Python does not garbage-collect in real time; collect explicitly so the
    # numbers below reflect memory that is actually still referenced.
    gc.collect()

    print(message)

    gib = 1024 ** 3  # bytes per GiB; hoisted to avoid repeating the constant
    # NOTE(review): original rounded Max_CA to 0 decimals while the other three
    # metrics used 2 — made consistent at 2 decimals here.
    print(
        f"MA {round(torch.cuda.memory_allocated() / gib, 2)} GB "
        f"Max_MA {round(torch.cuda.max_memory_allocated() / gib, 2)} GB "
        f"CA {round(torch.cuda.memory_reserved() / gib, 2)} GB "
        f"Max_CA {round(torch.cuda.max_memory_reserved() / gib, 2)} GB "
    )

    vm_stats = psutil.virtual_memory()
    used_GB = round((vm_stats.total - vm_stats.available) / gib, 2)
    print(f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%')

    # Reset the peak counters so subsequent reports are per-interval.
    # Guarded on is_available(): reset_peak_memory_stats raises on CPU-only
    # builds, and this helper is explicitly meant to run without touching the
    # GPU. hasattr keeps compatibility with older torch versions.
    if torch.cuda.is_available() and hasattr(torch.cuda, 'reset_peak_memory_stats'):
        torch.cuda.reset_peak_memory_stats()