神经网络打印模型参数及参数名字和数量

神经网络打印模型参数及参数名字和数量

在设计和优化神经网络模型性能时,很多时候需要考虑模型的参数量和计算复杂度,下面一个栗子可以帮助我们快速查看模型的参数。
** 举个栗子,如有错误,欢迎大家批评指正 **
本文链接:神经网络打印模型参数及参数名字和数量
https://blog.csdn.net/leiduifan6944/article/details/103690228

exp:

import torch
from torch import nn


class Net(nn.Module):
	def __init__(self):
		super().__init__()
		self.fc1 = nn.Linear(3*4*4, 3*5*5)
		self.conv1 = nn.Sequential(
			nn.Conv2d(3, 4, 1, 1),		# conv1.0
			nn.BatchNorm2d(4),			# conv1.1
			nn.LeakyReLU(),				# conv1.2

			nn.Conv2d(4, 4, 3, 1),		# conv1.3
			nn.BatchNorm2d(4),			# conv1.4
			nn.LeakyReLU(),				# conv1.5
		)

		self.fc2 = nn.Linear(4*3*3, 10)

	def forward(self, entry):
		entry = entry.reshape(-1, 3*4*4)
		fc1_out = self.fc1(entry)
		fc1_out = fc1_out.reshape(-1, 3, 5, 5)
		conv1_out = self.conv1(fc1_out)
		conv1_out = conv1_out.reshape(-1, 4*3*3)
		fc2_out = self.fc2(conv1_out)

		return fc2_out


if __name__ == '__main__':
	x = torch.Tensor(2, 3, 4, 4)
	net = Net()

	out = net(x)
	print('%14s : %s' % ('out.shape', out.shape))
	print('---------------华丽丽的分隔线---------------')
	# -------------方法1--------------
	sum_ = 0
	for name, param in net.named_parameters():
		mul = 1
		for size_ in param.shape:
			mul *= size_							# 统计每层参数个数
		sum_ += mul									# 累加每层参数个数
		print('%14s : %s' % (name, param.shape))  	# 打印参数名和参数数量
		# print('%s' % param)						# 这样可以打印出参数,由于过多,我就不打印了
	print('参数个数:', sum_)						# 打印参数量
	
	# -------------方法2--------------
	for param in net.parameters():
		print(param.shape)
		# print(param)

	# -------------方法3--------------
	params = list(net.parameters())
	for param in params:
		print(param.shape)
		# print(param)

以下是方法1的输出效果:

(方法2和方法3没贴出效果,个人比较喜欢用方法1,因为可以看到当前打印的是哪一层网络的参数)

     out.shape : torch.Size([2, 10])
---------------华丽丽的分隔线---------------
    fc1.weight : torch.Size([75, 48])
      fc1.bias : torch.Size([75])
conv1.0.weight : torch.Size([4, 3, 1, 1])
  conv1.0.bias : torch.Size([4])
conv1.1.weight : torch.Size([4])
  conv1.1.bias : torch.Size([4])
conv1.3.weight : torch.Size([4, 4, 3, 3])
  conv1.3.bias : torch.Size([4])
conv1.4.weight : torch.Size([4])
  conv1.4.bias : torch.Size([4])
    fc2.weight : torch.Size([10, 36])
      fc2.bias : torch.Size([10])
参数个数: 4225

  • 5
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值