用pytorch计算神经网络模型需要占用的显存

用pytorch计算神经网络模型需要占用的显存

#--------------------------------------------------------------------#
#作用:计算模型需要占用的显存,方便知道显卡够不够用
#使用方法:将模型初始化之后,传入Calculate_gpu_memory()即可
#--------------------------------------------------------------------#

import torch
import numpy as np
import torchvision
import torch.nn as nn

def Calculate_gpu_memory(Model,train_batch_size,img_wide,img_height):
    print("----------------计算模型要占用的显存------------")
    #step1#------------------------------------------------------------------计算模型参数占用的显存
    type_size = 4 #因为参数是float32,也就是4B
    para = sum([np.prod(list(p.size())) for p in Model.parameters()])
    print("Model {}:params:{:4f}M".format(Model._get_name(),para * type_size/1000/1000))
    #step2#------------------------------------------------------------------------计算模型的中间变量会占用的显存
    input = torch.ones((train_batch_size, 3, img_wide, img_height))
    input.requires_grad_(requires_grad=False)
    #遍历模型的每一个网络层(注意:一般模型都是嵌套建立的,这里只考虑了小于等于2层嵌套结构)
    mods = list(Model.named_children())
    out_sizes = []
    for i in range(0, len(mods)):
            mod = list(mods[i][1].named_children())
            if mod != []:
                for j in range(0, len(mod)):
                    m = mod[j][1]
                    #注意这里,如果relu激活函数是inplace则不用计算
                    if isinstance(m,nn.ReLU):  
                        if m.inplace:
                            continue
                    print("网络层(不包括池化层,inplace为True的激活函数):",m)
                    try: #一般不会把展平操作记录到里面去,因为没有在__init__中初始化,所以这里需要加上,如果不加上,将不能继续计算
                        out = m(input)
                    except RuntimeError:
                        input = torch.flatten(input, 1)
                        out = m(input)
                    out_sizes.append(np.array(out.size()))
                    if mod[j][0] not in ["rpn_score","rpn_loc"]: 
                        input = out
            else:
                m = mods[i][1]
                #注意这里,如果relu激活函数是inplace则不用计算
                if isinstance(m,nn.ReLU):  
                    if m.inplace:
                        continue
                print("网络层(不包括池化层,inplace为True的激活函数):",m)
                try:
                    out = m(input)
                except RuntimeError:
                    input = torch.flatten(input, 1)
                    out = m(input)
                out_sizes.append(np.array(out.size()))

                if mods[j][0] not in ["rpn_score","rpn_loc"]:
                    input = out
    #统计每一层网络中间变量需要占用的显存
    total_nums = 0
    for i in range(len(out_sizes)):
        s = out_sizes[i]
        nums = np.prod(np.array(s))
        total_nums += nums
    print('Model {} : intermedite variables: {:3f} M (without backward)'
            .format(Model._get_name(), total_nums * type_size / 1000 / 1000))
    print('Model {} : intermedite variables: {:3f} M (with backward)'
            .format(Model._get_name(), total_nums * type_size*2 / 1000 / 1000))
    print("----------------显存计算完毕------------")


#------------------------------------------------------------------------#
#测试,下面的代码不会影响上面的函数被其他python文件导入
if __name__=="__main__":
    vgg16 = torchvision.models.vgg16(pretrained=False)
    print(vgg16)
    Calculate_gpu_memory(vgg16,4,448,448)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ai_Taoism

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值