Comparing MindSpore and PyTorch -- Network Interfaces (Part 2)

II. Network interfaces and the methods they support

| Function | MindSpore Cell API | torch Module API | Notes |
| --- | --- | --- | --- |
| View submodules | cells_and_names() | modules(), named_modules() | identical |
| | cells() | children() | slightly different |
| | name_cells() | named_children() | slightly different |
| Get parameters | parameters_and_names(), get_parameters() | state_dict(destination=None) | slightly different |
| | parameters_dict() | / | |
| | / | load_state_dict(state_dict) | |
| Trainable parameters | trainable_params() | named_parameters(), parameters() | slightly different |
| Non-trainable parameters | untrainable_params(recurse=True) | buffers(), named_buffers() | slightly different |
| Add a submodule | insert_child_to_cell(child_name, child_cell) | add_module(name, module) | identical |
| Set training state | set_train(mode=True) | train(mode=True), eval() | identical |
| MindSpore-specific | init_parameters_data(auto_parallel_mode=False) | / | |
| | train_net.set_broadcast_flag() | / | |
| | set_param_ps() | / | |

The last three rows (parameter-data initialization, the broadcast flag used in data-parallel training, and parameter-server mode) are MindSpore-specific and have no direct torch counterpart.
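
For the rows marked "identical", usage maps one-to-one. A minimal sketch (the stand-in ms_net / pt_net objects and the extra_relu name are illustrative, not from the original article):

import mindspore.nn as msnn
import torch.nn as ptnn

ms_net = msnn.Cell()      # any Cell would do here
pt_net = ptnn.Module()    # any Module would do here

# Add a named submodule
ms_net.insert_child_to_cell("extra_relu", msnn.ReLU())
pt_net.add_module("extra_relu", ptnn.ReLU())

# Toggle training/inference behavior (affects BatchNorm, Dropout, ...)
ms_net.set_train(True)    # set_train(False) switches to inference
pt_net.train()            # training mode
pt_net.eval()             # inference mode, same as train(False)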

1. Viewing submodules of a network

1) Viewing top-level submodules: same functionality, different return types (see the type check after the outputs below)

  • mindspore: net.name_cells() / net.cells(): OrderedDict / odict_values
  • torch: net.named_children() / net.children(): iterators

2) Viewing all modules: identical

  • mindspore: net.cells_and_names(): iterator
  • torch: net.named_modules() / net.modules(): iterator
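
All the snippets below operate on a small two-block network called net. Its definition isn't repeated in this part, so here is a minimal sketch reconstructed from the printed outputs; the attribute names (build_block1, build_block2, conv, bn, relu) match the outputs, but the exact hyperparameters are inferred and may differ from the original code:

from collections import OrderedDict

import mindspore.nn as msnn
import torch.nn as nn


# --- torch version (assumed) ---
class ConvBNReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class TorchNet(nn.Module):
    def __init__(self):
        super().__init__()
        # The OrderedDict names the pooling layer "pool", which is why the
        # torch output below prints "build_block1.pool"
        self.build_block1 = nn.Sequential(OrderedDict([
            ("0", ConvBNReLU()),
            ("pool", nn.MaxPool2d(kernel_size=2, stride=2)),
        ]))
        self.build_block2 = nn.ModuleList([
            nn.Conv2d(64, 4, kernel_size=4),
            nn.BatchNorm2d(4),
            nn.ReLU(),
        ])

    def forward(self, x):
        x = self.build_block1(x)
        for layer in self.build_block2:
            x = layer(x)
        return x


# --- mindspore version (assumed) ---
class MSConvBNReLU(msnn.Cell):
    def __init__(self):
        super().__init__()
        self.conv = msnn.Conv2d(3, 64, kernel_size=3)
        self.bn = msnn.BatchNorm2d(64)
        self.relu = msnn.ReLU()

    def construct(self, x):
        return self.relu(self.bn(self.conv(x)))


class MSNet(msnn.Cell):
    def __init__(self):
        super().__init__()
        # SequentialCell/CellList children are indexed by position, hence
        # "build_block1.1" (not "build_block1.pool") in the mindspore output
        self.build_block1 = msnn.SequentialCell([
            MSConvBNReLU(),
            msnn.MaxPool2d(kernel_size=2, stride=1),
        ])
        self.build_block2 = msnn.CellList([
            msnn.Conv2d(64, 4, kernel_size=4),
            msnn.BatchNorm2d(4),
            msnn.ReLU(),
        ])

    def construct(self, x):
        x = self.build_block1(x)
        for layer in self.build_block2:
            x = layer(x)
        return x

In the snippets that follow, net stands for an instance of whichever framework's class is being demonstrated.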
# View top-level submodules
# mindspore
for name, cell in net.name_cells().items():
    print("*** name: ", name)
    print("*** cell: ", cell)

# torch
for name, child in net.named_children():
    print("*** name: ", name)
    print("*** child: ", child)
# mindspore: output
*** name:  build_block1
*** cell:  SequentialCell<
  (0): ConvBNReLU<
    (conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block1.0.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block1.0.bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block1.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block1.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
    (relu): ReLU<>
    >
  (1): MaxPool2d<kernel_size=2, stride=1, pad_mode=VALID>
  >
*** name:  build_block2
*** cell:  CellList<
  (0): Conv2d<input_channels=64, output_channels=4, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
  (1): BatchNorm2d<num_features=4, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block2.1.gamma, shape=(4,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block2.1.beta, shape=(4,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block2.1.moving_mean, shape=(4,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block2.1.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)>
  (2): ReLU<>
  >
# torch: output
*** name:  build_block1
*** child:  Sequential(
  (0): ConvBNReLU(
    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
*** name:  build_block2
*** child:  ModuleList(
  (0): Conv2d(64, 4, kernel_size=(4, 4), stride=(1, 1))
  (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
# View all submodules
# mindspore
for name, cell in net.cells_and_names():
    print("*** name: ", name)

# torch
for name, module in net.named_modules():
    print("*** name: ", name)
# mindspore: output
*** name:  
*** name:  build_block1
*** name:  build_block1.0
*** name:  build_block1.0.conv
*** name:  build_block1.0.bn
*** name:  build_block1.0.relu
*** name:  build_block1.1
*** name:  build_block2
*** name:  build_block2.0
*** name:  build_block2.1
*** name:  build_block2.2
# torch: output
*** name:  
*** name:  build_block1
*** name:  build_block1.0
*** name:  build_block1.0.conv
*** name:  build_block1.0.bn
*** name:  build_block1.0.relu
*** name:  build_block1.pool
*** name:  build_block2
*** name:  build_block2.0
*** name:  build_block2.1
*** name:  build_block2.2
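
The type differences called out in 1) are easy to verify (assuming ms_net and pt_net are instances of the MSNet / TorchNet sketches above):

ms_net, pt_net = MSNet(), TorchNet()

print(type(ms_net.name_cells()))      # OrderedDict
print(type(ms_net.cells()))           # odict_values view over that dict
print(type(pt_net.named_children()))  # generator
print(type(pt_net.children()))        # generator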

2. Getting parameters

1) All parameters (a save/load sketch follows these lists)

  • mindspore: net.parameters_and_names() / net.get_parameters() / net.parameters_dict(): iterator; iterator; OrderedDict
  • torch: net.state_dict(): OrderedDict

2) Viewing parameters that need optimizing (trainable)

  • mindspore: net.trainable_params(): iterator
  • torch: net.named_parameters(), net.parameters(): iterators

3) Viewing parameters that are not optimized (non-trainable)

  • mindspore: net.untrainable_params(): iterator
  • torch: net.buffers(), net.named_buffers(): iterators
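
The table pairs torch's state_dict()/load_state_dict() with MindSpore's parameter dictionaries, but on the MindSpore side persistence normally goes through checkpoint helpers. A minimal save/load sketch (file names are illustrative; depending on the MindSpore version these helpers live at the top level or under mindspore.train.serialization):

import mindspore as ms
import torch

# torch: serialize the parameter/buffer dict, then restore it
torch.save(net.state_dict(), "net.pth")
net.load_state_dict(torch.load("net.pth"))

# mindspore: save a checkpoint, reload it as a dict of Parameters,
# then copy the values back into the Cell
ms.save_checkpoint(net, "net.ckpt")
param_dict = ms.load_checkpoint("net.ckpt")
ms.load_param_into_net(net, param_dict)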
# mindspore: view all parameters
for name, params in net.parameters_and_names():
    print("*** name: ", name, "     *** params: ", params)

# torch: view all parameters
for name, state in net.state_dict().items():
    print("*** name: ", name, "     *** state: ", type(state))
# mindspore output
*** name:  build_block1.0.conv.weight      *** params:  Parameter (name=build_block1.0.conv.weight, shape=(64, 3, 3, 3), dtype=Float32, requires_grad=True)
*** name:  build_block1.0.bn.moving_mean      *** params:  Parameter (name=build_block1.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False)
*** name:  build_block1.0.bn.moving_variance      *** params:  Parameter (name=build_block1.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)
*** name:  build_block1.0.bn.gamma      *** params:  Parameter (name=build_block1.0.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True)
*** name:  build_block1.0.bn.beta      *** params:  Parameter (name=build_block1.0.bn.beta, shape=(64,), dtype=Float32, requires_grad=True)
*** name:  build_block2.0.weight      *** params:  Parameter (name=build_block2.0.weight, shape=(4, 64, 4, 4), dtype=Float32, requires_grad=True)
*** name:  build_block2.1.moving_mean      *** params:  Parameter (name=build_block2.1.moving_mean, shape=(4,), dtype=Float32, requires_grad=False)
*** name:  build_block2.1.moving_variance      *** params:  Parameter (name=build_block2.1.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)
*** name:  build_block2.1.gamma      *** params:  Parameter (name=build_block2.1.gamma, shape=(4,), dtype=Float32, requires_grad=True)
*** name:  build_block2.1.beta      *** params:  Parameter (name=build_block2.1.beta, shape=(4,), dtype=Float32, requires_grad=True)
# torch output
*** name:  build_block1.0.conv.weight      *** state:  <class 'torch.Tensor'>
*** name:  build_block1.0.conv.bias      *** state:  <class 'torch.Tensor'>
*** name:  build_block1.0.bn.weight      *** state:  <class 'torch.Tensor'>
*** name:  build_block1.0.bn.bias      *** state:  <class 'torch.Tensor'>
*** name:  build_block1.0.bn.running_mean      *** state:  <class 'torch.Tensor'>
*** name:  build_block1.0.bn.running_var      *** state:  <class 'torch.Tensor'>
*** name:  build_block1.0.bn.num_batches_tracked      *** state:  <class 'torch.Tensor'>
*** name:  build_block2.0.weight      *** state:  <class 'torch.Tensor'>
*** name:  build_block2.0.bias      *** state:  <class 'torch.Tensor'>
*** name:  build_block2.1.weight      *** state:  <class 'torch.Tensor'>
*** name:  build_block2.1.bias      *** state:  <class 'torch.Tensor'>
*** name:  build_block2.1.running_mean      *** state:  <class 'torch.Tensor'>
*** name:  build_block2.1.running_var      *** state:  <class 'torch.Tensor'>
*** name:  build_block2.1.num_batches_tracked      *** state:  <class 'torch.Tensor'>
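
Both frameworks also have unnamed variants, and their scopes differ slightly, which is part of what the "slightly different" note in the table refers to: MindSpore's get_parameters() yields every Parameter, trainable or not, while torch's parameters() yields only nn.Parameter objects and never buffers. A quick sketch, assuming the same net as above:

# mindspore: every Parameter, including requires_grad=False ones
for p in net.get_parameters():
    print(type(p), p.name)

# torch: nn.Parameter objects only; buffers such as running_mean are excluded
for p in net.parameters():
    print(type(p), tuple(p.shape))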
# mindspore: view trainable parameters
for params in net.trainable_params():
    print("*** params: ", params)

# torch: view parameters to be optimized
for name, params in net.named_parameters():
    print("*** name: ", name, "     *** params: ", type(params))
# mindspore output
*** params:  Parameter (name=build_block1.0.conv.weight, shape=(64, 3, 3, 3), dtype=Float32, requires_grad=True)
*** params:  Parameter (name=build_block1.0.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True)
*** params:  Parameter (name=build_block1.0.bn.beta, shape=(64,), dtype=Float32, requires_grad=True)
*** params:  Parameter (name=build_block2.0.weight, shape=(4, 64, 4, 4), dtype=Float32, requires_grad=True)
*** params:  Parameter (name=build_block2.1.gamma, shape=(4,), dtype=Float32, requires_grad=True)
*** params:  Parameter (name=build_block2.1.beta, shape=(4,), dtype=Float32, requires_grad=True)
# torch output
*** name:  build_block1.0.conv.weight      *** params:  <class 'torch.nn.parameter.Parameter'>
*** name:  build_block1.0.conv.bias      *** params:  <class 'torch.nn.parameter.Parameter'>
*** name:  build_block1.0.bn.weight      *** params:  <class 'torch.nn.parameter.Parameter'>
*** name:  build_block1.0.bn.bias      *** params:  <class 'torch.nn.parameter.Parameter'>
*** name:  build_block2.0.weight      *** params:  <class 'torch.nn.parameter.Parameter'>
*** name:  build_block2.0.bias      *** params:  <class 'torch.nn.parameter.Parameter'>
*** name:  build_block2.1.weight      *** params:  <class 'torch.nn.parameter.Parameter'>
*** name:  build_block2.1.bias      *** params:  <class 'torch.nn.parameter.Parameter'>
# mindspore: view non-trainable parameters
for params in net.untrainable_params():
    print("*** params: ", params)

# torch: view buffers (not optimized)
for name, buffer in net.named_buffers():
    print("*** name: ", name, "     *** buffer: ", type(buffer))
# mindspore output
*** params:  Parameter (name=build_block1.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False)
*** params:  Parameter (name=build_block1.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)
*** params:  Parameter (name=build_block2.1.moving_mean, shape=(4,), dtype=Float32, requires_grad=False)
*** params:  Parameter (name=build_block2.1.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)

# torch output
*** name:  build_block1.0.bn.running_mean      *** buffer:  <class 'torch.Tensor'>
*** name:  build_block1.0.bn.running_var      *** buffer:  <class 'torch.Tensor'>
*** name:  build_block1.0.bn.num_batches_tracked      *** buffer:  <class 'torch.Tensor'>
*** name:  build_block2.1.running_mean      *** buffer:  <class 'torch.Tensor'>
*** name:  build_block2.1.running_var      *** buffer:  <class 'torch.Tensor'>
*** name:  build_block2.1.num_batches_tracked      *** buffer:  <class 'torch.Tensor'>
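
A practical consequence of this mapping: in MindSpore a parameter moves between trainable_params() and untrainable_params() simply by toggling its requires_grad flag, whereas in torch a frozen nn.Parameter is still yielded by parameters() and never becomes a buffer. A sketch of freezing the BatchNorm weights in both frameworks (the "bn" name filter is illustrative):

# mindspore: clearing requires_grad moves the parameter into untrainable_params()
for p in net.get_parameters():
    if "bn" in p.name:
        p.requires_grad = False

# torch: freezing stops gradient updates, but the parameter still appears
# in parameters()/named_parameters(); it does NOT turn into a buffer
for name, p in net.named_parameters():
    if "bn" in name:
        p.requires_grad_(False)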