1. 计算网络模型的参数数量
def view_model_param(MODEL_NAME, net_params):
model = gnn_model(MODEL_NAME, net_params)
total_param = 0
print("MODEL DETAILS:\n")
#print(model)
for param in model.parameters():
# print(param.data.size())
total_param += np.prod(list(param.data.size()))
print('MODEL/Total parameters:', MODEL_NAME, total_param)
return total_param
这个函数返回的参数单位是“个”(所有参数的总数),怎么将其转化为MB(兆字节)或者K?
要将这个数字转换为 MB(兆字节)或 KB(千字节),需要考虑每个参数通常用 4 字节(32 位浮点数)或 8 字节(64 位浮点数)来表示。接下来,将总参数数量乘以每个参数所占用的字节数(例如 4 字节),然后除以相应的单位转换因子。
2. 查看模型参数的数据类型
在 PyTorch 中,可以通过检查模型参数的数据类型来确定它们是单精度(32 位浮点数)还是双精度(64 位浮点数)。
# 遍历模型的所有参数并检查数据类型:
for param in model.parameters():
print("Parameter data type:", param.dtype)
在 PyTorch 中,数据类型表示为:
- torch.float32 或 torch.float:单精度(32 位浮点数)
- torch.float64 或 torch.double:双精度(64 位浮点数)
通常情况下,神经网络模型中的参数默认为单精度(32位浮点数)。但是,在某些情况下,如梯度累积或更高精度的计算要求,您可能需要使用双精度(64位浮点数)。
3. 转化为字节 | MB(兆字节)或 KB(千字节)
在神经网络模型中,权重(参数)通常用浮点数表示。浮点数有不同的精度,例如单精度(32位)和双精度(64位)。根据所使用的精度,每个参数需要的存储空间也会有所不同。
-
32位浮点数(单精度):每个参数需要4个字节(8位一个字节,共32位)来存储。
-
64位浮点数(双精度):每个参数需要8个字节(8位一个字节,共64位)来存储。
-
为了计算模型参数所需的内存大小,我们需要知道每个参数所占用的字节数。根据所使用的精度,将总参数数量乘以每个参数所占用的字节数可以得到总字节数。
接下来,我们可以使用单位转换因子将字节数转换为其他单位(例如兆字节或千字节):
- 1 KB(千字节) = 1024字节
- 1 MB(兆字节) = 1024 KB = 1024 * 1024字节
因此,为了将总字节数转换为兆字节或千字节,我们需要将总字节数除以相应的单位转换因子。
这就是为什么我们需要考虑参数的字节大小和使用单位转换因子来计算模型参数所需的内存。
要将参数数量转换为兆字节(MB)和千字节(KB),您可以进行以下操作:
def view_model_param(MODEL_NAME, net_params):
model = gnn_model(MODEL_NAME, net_params)
total_param = 0
print("MODEL DETAILS:\n")
#print(model)
for param in model.parameters():
# print(param.data.size())
total_param += np.prod(list(param.data.size()))
print('MODEL/Total parameters:', MODEL_NAME, total_param)
# 假设每个参数是一个 32 位浮点数(4 字节)
bytes_per_param = 4
# 计算总字节数
total_bytes = total_param * bytes_per_param
# 转换为兆字节(MB)和千字节(KB)
total_megabytes = total_bytes / (1024 * 1024)
total_kilobytes = total_bytes / 1024
print("Total parameters in MB:", total_megabytes)
print("Total parameters in KB:", total_kilobytes)
return total_param
即设模型的参数数量为params,对于模型参数类型是foalt32而言:
KB(千字节)= parmas4/1024
MB(兆字节)= parmas4/1024/1024