[CANN Training Camp, Season 3] Model Training and Tuning with the Ascend PyTorch Framework



Performance analysis tool: PyTorch Profiling

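# On the Ascend-adapted PyTorch 1.8.1, the Ascend documentation uses use_npu=True in place of use_cuda=True.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
	optimizer.zero_grad()  # clear stale gradients before the forward pass
	out = model(input_tensor)
	loss = loss_func(out, target)  # the loss needs both the prediction and the label
	loss.backward()
	optimizer.step()
print(prof)  # print the profiler results
prof.export_chrome_trace('profile_' + str(i) + '.json')  # i is the current step index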
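
The JSON file written by export_chrome_trace can be opened in chrome://tracing, but a quick ranking of the slowest events is often enough. The sketch below assumes the file follows the Chrome Trace Event format (either a bare list of events or an object with a "traceEvents" key) and that complete events use "ph": "X" with a "dur" field in microseconds. The function name summarize_trace is made up for this illustration.

```python
import json
from collections import defaultdict


def summarize_trace(path, top=5):
    """Return the `top` event names with the largest total duration (microseconds)."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    # Assumption: a Chrome trace is either a bare list or {"traceEvents": [...]}.
    events = data["traceEvents"] if isinstance(data, dict) else data
    totals = defaultdict(float)
    for ev in events:
        # "X" marks a complete event that carries its own duration.
        if ev.get("ph") == "X" and "dur" in ev:
            totals[ev.get("name", "<unnamed>")] += float(ev["dur"])
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)[:top]
```

Calling summarize_trace('profile_0.json') returns (name, total_us) pairs, largest first, which makes it easy to compare two runs side by side.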

Performance analysis tool: CANN Profiling

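with torch.npu.profile('./cann_prof') as prof:
	optimizer.zero_grad()  # clear stale gradients before the forward pass
	out = model(input_tensor)
	loss = loss_func(out, target)  # the loss needs both the prediction and the label
	loss.backward()
	optimizer.step()
print(prof)  # print the profiler results
prof.export_chrome_trace('profile_' + str(i) + '.json')  # i is the current step index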

$ find /usr/local/Ascend/ascend-toolkit -name msprof.py
$ python3 msprof.py export summary -dir ./results/

For example, on this machine the script was found at:
/usr/local/Ascend/ascend-toolkit/6.0.RC1/tools/profiler/profiler_tool/analysis/msprof/msprof.py
The generated profile folder is cann_profiling.

python3 /usr/local/Ascend/ascend-toolkit/6.0.RC1/tools/profiler/profiler_tool/analysis/msprof/msprof.py export summary -dir ~/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0
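
Each profiling run creates a new PROF_* folder, so typing the full path by hand gets tedious. As a small sketch, the helper below builds the same `msprof.py export summary -dir ...` command for every PROF_*/device_* folder under a profiling root. The function name build_export_commands is hypothetical, and the folder layout is assumed from the path shown above.

```python
from pathlib import Path


def build_export_commands(msprof_py, prof_root):
    """Build one `msprof.py export summary` command per PROF_*/device_* folder."""
    commands = []
    # Assumption: results are laid out as <prof_root>/PROF_*/device_<n>.
    for device_dir in sorted(Path(prof_root).glob("PROF_*/device_*")):
        if device_dir.is_dir():
            commands.append(
                ["python3", str(msprof_py), "export", "summary", "-dir", str(device_dir)]
            )
    return commands
```

Each returned list can be passed to subprocess.run(cmd, check=True) to run the export.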

Final assessment

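1. Implement MNIST handwritten digit recognition with a LeNet network in PyTorch.

Any hardware platform is acceptable, on Windows or Linux. Provide screenshots of the whole process where possible, report the final loss or accuracy, and include the printed loss and accuracy logs with screenshots. [10 points]

Reference (GitHub): https://github.com/allegrofb/LeNet.git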

git clone https://github.com/allegrofb/LeNet.git 

I ran it directly on the local GPU under WSL.

22:37:40 --- Epoch: 0	Train loss: 0.2438	Valid loss: 0.0851	Train accuracy: 97.05	Valid accuracy: 97.24
22:38:09 --- Epoch: 1	Train loss: 0.0787	Valid loss: 0.0576	Train accuracy: 98.40	Valid accuracy: 98.35

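2. Using the manual or automatic migration approach taught in the course, migrate the script above to an Ascend NPU (single machine, single card). Provide the migrated script, highlight the key points, and include screenshots. [10 points]

I used the migration tool.
The migration was first completed on a cloud server.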

# Go to the directory that contains the script conversion tool
cd /home/HwHiAiUser/Ascend/ascend-toolkit/latest/tools/ms_fmk_transplt
# Run the script conversion
./pytorch_gpu2npu.sh -i /home/HwHiAiUser/LeNet -o /home/HwHiAiUser/LeNet_output

Alternatively, use the one-line automatic migration, which does not need the tool above; just add this line to the script:

from torch_npu.contrib import transfer_to_npu

With that, the script starts running.
To get it to run end to end, use the notebook image provided in the course:
swr.cn-north-4.myhuaweicloud.com/atelier/pytorch_1_8_ascend:pytorch_1.8.1-cann_6.0.0-py_3.7-euler_2.8.3-aarch64-d910-20221116111529
Reference documentation: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/600alpha003/ptmoddevg/ptmigr/ptmigr_000002.html

Accuracy obtained on the NPU:

18:01:57 --- Epoch: 0	Train loss: 0.2450	Valid loss: 0.0777	Train accuracy: 97.43	Valid accuracy: 97.48
18:02:56 --- Epoch: 1	Train loss: 0.0792	Valid loss: 0.0597	Train accuracy: 97.98	Valid accuracy: 98.22
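
To check that the NPU run matches the GPU run, the log lines can be parsed rather than compared by eye. This is a minimal sketch that assumes every log line has the same fields as the ones printed above; the names LOG_RE and parse_log are made up for the example.

```python
import re

# Matches the metric part of a log line; the timestamp prefix is ignored.
LOG_RE = re.compile(
    r"Epoch:\s*(?P<epoch>\d+)\s+"
    r"Train loss:\s*(?P<train_loss>[\d.]+)\s+"
    r"Valid loss:\s*(?P<valid_loss>[\d.]+)\s+"
    r"Train accuracy:\s*(?P<train_acc>[\d.]+)\s+"
    r"Valid accuracy:\s*(?P<valid_acc>[\d.]+)"
)


def parse_log(text):
    """Return one dict of metrics per epoch line found in `text`."""
    rows = []
    for m in LOG_RE.finditer(text):
        row = {k: float(v) for k, v in m.groupdict().items() if k != "epoch"}
        row["epoch"] = int(m["epoch"])
        rows.append(row)
    return rows


# Last epoch of the GPU run and of the NPU run, copied from the logs in this post.
GPU_LOG = "22:38:09 --- Epoch: 1\tTrain loss: 0.0787\tValid loss: 0.0576\tTrain accuracy: 98.40\tValid accuracy: 98.35"
NPU_LOG = "18:02:56 --- Epoch: 1\tTrain loss: 0.0792\tValid loss: 0.0597\tTrain accuracy: 97.98\tValid accuracy: 98.22"

gap = parse_log(GPU_LOG)[-1]["valid_acc"] - parse_log(NPU_LOG)[-1]["valid_acc"]
print(f"Valid accuracy gap (GPU - NPU): {gap:.2f} points")
```

The same parser works on a whole log file, which is handy once the run is extended to more epochs.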

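3. Building on the first two tasks, increase the batch size and learning rate, making the batch size as large as possible while the final model accuracy still meets the target. Provide logs or screenshots. [20 points]

I set batch_size to 128 and learning_rate to 0.001: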
18:14:34 --- Epoch: 0	Train loss: 0.4303	Valid loss: 0.1372	Train accuracy: 95.22	Valid accuracy: 95.71
18:15:14 --- Epoch: 1	Train loss: 0.1162	Valid loss: 0.0763	Train accuracy: 97.45	Valid accuracy: 97.67
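Saving the model and training again, or raising the number of epochs, improves the accuracy further.
After several training runs, the model has converged.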
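
When the batch size grows, a common heuristic (not necessarily what was done in this post) is to scale the learning rate linearly with it. The sketch below shows only that arithmetic; the baseline values in the usage note are hypothetical and not taken from the LeNet repository.

```python
def scale_learning_rate(base_lr, base_batch_size, new_batch_size):
    """Linear scaling heuristic: keep lr / batch_size constant."""
    if base_batch_size <= 0 or new_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return base_lr * new_batch_size / base_batch_size
```

For instance, a hypothetical baseline of lr 0.001 at batch size 32 would suggest scale_learning_rate(0.001, 32, 128) as a starting point at batch size 128. Treat the result as an initial guess to validate against accuracy, not as a guarantee.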

  • My code
    https://gitee.com/qmckw/lenet_msft

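4. Train with a large batch size in the NPU environment, use mixed precision to accelerate training, and print the model's performance data. [40 points]

Scoring breakdown:
Reference: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/600alpha003/ptmoddevg/ptmigr/ptmigr_000017.html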

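  1. Complete mixed-precision training. [20 points]
    I added the mixed-precision logic to the training script.
    The opt_level can also be set to O1 or O2.
    For easier observation, I set the run to 10 epochs.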
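
The screenshots of the added logic did not survive, so here is a minimal sketch of what mixed-precision setup with Apex typically looks like. It assumes the Ascend-adapted Apex package, whose amp.initialize accepts the combine_grad argument mentioned later in this post. The helper names, opt_level="O2", and loss_scale=128.0 are illustrative choices, not the author's confirmed settings.

```python
def enable_mixed_precision(model, optimizer, opt_level="O2", loss_scale=128.0):
    """Wrap model and optimizer with Apex AMP (needs the Ascend build of Apex)."""
    # Imported lazily because apex is only installed in the Ascend environment.
    from apex import amp

    # combine_grad=True is an Ascend-specific option (assumption: adapted Apex).
    return amp.initialize(
        model, optimizer, opt_level=opt_level, loss_scale=loss_scale, combine_grad=True
    )


def backward_with_scaling(loss, optimizer):
    """Run backward on the scaled loss so small fp16 gradients do not underflow."""
    from apex import amp

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
```

In the training loop, call model, optimizer = enable_mixed_precision(model, optimizer) once after both objects are created, then replace loss.backward() with backward_with_scaling(loss, optimizer).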

  2. Use the CANN Profiling and PyTorch Profiling tools to produce model performance data, and provide the profiling files. [Two files, 10 points each, 20 points total]

  • CANN Profiling
    Reference: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/600alpha003/developmenttools/devtool/atlasprofiling_16_0091.html
    Analysis:
(PyTorch-1.8.1) [ma-user LeNet]$python3 /usr/local/Ascend/ascend-toolkit/6.0.RC1/tools/profiler/profiler_tool/analysis/msprof/msprof.py export summary -dir ~/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0
Sun 08 Jan 2023 22:21:11 [INFO] [MSVP] [128429] msprof_export.py: The data in "/home/ma-user/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0" has been analyzed.
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: Start to export task_time summary data ...
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: The task_time summary data of device 0 for iteration 1 has been exported to "/home/ma-user/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0/summary/task_time_0_1.csv".
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: Start to export ge_op_execute summary data ...
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: The ge_op_execute summary data of device 0 for iteration 1 has been exported to "/home/ma-user/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0/summary/ge_op_execute_0_1.csv".
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: Start to export op_summary summary data ...
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: The op_summary summary data of device 0 for iteration 1 has been exported to "/home/ma-user/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0/summary/op_summary_0_1.csv".
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: Start to export op_statistic summary data ...
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: The op_statistic summary data of device 0 for iteration 1 has been exported to "/home/ma-user/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0/summary/op_statistic_0_1.csv".
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: Start to export acl summary data ...
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: The acl summary data of device 0 for iteration 1 has been exported to "/home/ma-user/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0/summary/acl_0_1.csv".
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: Start to export acl_statistic summary data ...
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: The acl_statistic summary data of device 0 for iteration 1 has been exported to "/home/ma-user/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0/summary/acl_statistic_0_1.csv".
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: Start to export runtime_api summary data ...
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: The runtime_api summary data of device 0 for iteration 1 has been exported to "/home/ma-user/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0/summary/runtime_api_0_1.csv".
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: Start to export ai_stack_time summary data ...
Sun 08 Jan 2023 22:21:12 [INFO] [MSVP] [128429] msprof_export.py: The ai_stack_time summary data of device 0 for iteration 1 has been exported to "/home/ma-user/work/LeNet/cann_profiling/PROF_000005_20230108221343132_OAGALRFAGNGFJNKC/device_0/summary/ai_stack_time_0_1.csv".

Performance Summary Report:
1. Model/Operator Computation:
        N/A
2. Model/Operator Memory:
        N/A
3. Operator Schedule:
        1)Task wait time has reached the upper limit: [Conv2D,RealDiv,trans_TransData_30,trans_TransData_33,LogSoftmaxV2,trans_TransData_39,trans_TransData_46,trans_TransData_26,trans_TransData_48,trans_TransData_49,trans_TransData_50,SoftmaxV2_out,trans_TransData_32,trans_TransData_37,NLLLoss,trans_TransData_29,trans_TransData_38,NLLLossGrad,LogSoftmaxGrad,Sqrt,AvgPoolV2AvgPoolV2_mul_layer,ReluGrad,AvgPoolV2Grad,atomic_addr_clean-1_67_6_0,trans_TransData_51,trans_TransData_52,Add,trans_TransData_25,ZerosLike,Addcdiv,Slice,trans_TransData_35,trans_TransData_42,atomic_addr_clean-1_93_10_0,AxpyV2,atomic_addr_clean-1_91_9_0,SoftmaxV2_input,MatMul,Addcmul,trans_TransData_47,atomic_addr_clean-1_50_5_0,OnesLike,trans_TransData_11,trans_TransData_31,trans_TransData_24,atomic_addr_clean-1_77_8_0,trans_TransData_36,SoftmaxV2_new,trans_TransData_21,ReduceSum,atomic_addr_clean-1_72_7_0,Relu,trans_TransData_27,Cast,Mul,trans_TransData_28,NLLLossDiv,Conv2DBackpropInput,trans_TransData_45,trans_TransData_34,Conv2DBackpropFilter]
4. Operator Processing:
        1)please check and reduce the transData: [trans_TransData_30,trans_TransData_33,trans_TransData_39,trans_TransData_46,trans_TransData_48,trans_TransData_26,trans_TransData_49,trans_TransData_50,trans_TransData_32,trans_TransData_37,trans_TransData_29,trans_TransData_38,trans_TransData_52,trans_TransData_51,trans_TransData_25,trans_TransData_35,trans_TransData_42,trans_TransData_47,trans_TransData_11,trans_TransData_31,trans_TransData_24,trans_TransData_36,trans_TransData_21,trans_TransData_27,trans_TransData_28,trans_TransData_45,trans_TransData_34]
5. Operator Metrics:
        N/A

The generated CSV files can then be analyzed further.
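
Since the report above flags many TransData tasks, one useful first step is to total the time per operator type in op_summary_0_1.csv. This is only a sketch: the default column names "OP Type" and "Task Duration(us)" are assumptions about the CSV header and may differ between CANN versions, so they are exposed as parameters. The function name is made up for the example.

```python
import csv
from collections import defaultdict


def total_time_by_op_type(csv_path, type_col="OP Type", time_col="Task Duration(us)"):
    """Sum the duration column per operator type, largest total first."""
    totals = defaultdict(float)
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        # The column names are assumptions; fail loudly if the header differs.
        missing = {type_col, time_col} - set(reader.fieldnames or [])
        if missing:
            raise KeyError(f"columns not found in {csv_path}: {sorted(missing)}")
        for row in reader:
            try:
                duration = float(row[time_col])
            except (TypeError, ValueError):
                continue  # skip rows with an empty or non-numeric value such as "N/A"
            totals[row[type_col]] += duration
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
```

If TransData ranks near the top, format conversions are costing real time, which matches the "please check and reduce the transData" hint in the report.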
  • PyTorch Profiling
    Reference: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/600alpha003/developmenttools/devtool/atlasprofiling_16_0090.html

Printed result:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self NPU    Self NPU %     NPU total  NPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::is_floating_point         0.00%      26.680us         0.00%      26.680us      26.680us      17.000us         0.00%      17.000us      17.000us             1  
                                               aten::to         0.00%     106.952us         0.00%     195.353us     195.353us      94.000us         0.00%     194.000us     194.000us             1  
                                                   Cast         0.00%      88.401us         0.00%      88.401us      88.401us     100.000us         0.00%     100.000us     100.000us             1  
                                           aten::conv2d         0.00%      31.140us         0.00%     290.574us     290.574us      17.000us         0.00%     288.000us     288.000us             1  
                                      aten::convolution         0.00%      24.651us         0.00%     259.434us     259.434us      32.000us         0.00%     271.000us     271.000us             1  
                                     aten::_convolution         0.00%     167.602us         0.00%     234.783us     234.783us     161.000us         0.00%     239.000us     239.000us             1  
                                                 Conv2D         0.00%      67.181us         0.00%      67.181us      67.181us      78.000us         0.00%      78.000us      78.000us             1  
                                             aten::relu         0.00%      90.342us         0.00%     133.372us     133.372us      68.000us         0.00%     128.000us     128.000us             1  
                                                   Relu         0.00%      43.030us         0.00%      43.030us      43.030us      60.000us         0.00%      60.000us      60.000us             1  
                                       aten::avg_pool2d         0.00%      94.911us         0.00%     149.542us     149.542us      79.000us         0.00%     147.000us     147.000us             1  
                                              AvgPoolV2         0.00%      54.631us         0.00%      54.631us      54.631us      68.000us         0.00%      68.000us      68.000us             1  
                                           aten::conv2d         0.00%      26.310us         0.00%     209.243us     209.243us      23.000us         0.00%     206.000us     206.000us             1  
                                      aten::convolution         0.00%      23.431us         0.00%     182.933us     182.933us      25.000us         0.00%     183.000us     183.000us             1  
                                     aten::_convolution         0.00%     106.871us         0.00%     159.502us     159.502us      99.000us         0.00%     158.000us     158.000us             1  
                                                 Conv2D         0.00%      52.631us         0.00%      52.631us      52.631us      59.000us         0.00%      59.000us      59.000us             1  
                                             aten::relu         0.00%      73.661us         0.00%     112.682us     112.682us      60.000us         0.00%     114.000us     114.000us             1  
                                                   Relu         0.00%      39.021us         0.00%      39.021us      39.021us      54.000us         0.00%      54.000us      54.000us             1  
                                       aten::avg_pool2d         0.00%      89.632us         0.00%     140.212us     140.212us      87.000us         0.00%     145.000us     145.000us             1  
                                              AvgPoolV2         0.00%      50.580us         0.00%      50.580us      50.580us      58.000us         0.00%      58.000us      58.000us             1  
                                           aten::conv2d         0.00%      26.331us         0.00%     208.383us     208.383us      18.000us         0.00%     200.000us     200.000us             1  
                                      aten::convolution         0.00%      22.710us         0.00%     182.052us     182.052us      23.000us         0.00%     182.000us     182.000us             1  
                                     aten::_convolution         0.00%     103.941us         0.00%     159.342us     159.342us      93.000us         0.00%     159.000us     159.000us             1  
                                                 Conv2D         0.00%      55.401us         0.00%      55.401us      55.401us      66.000us         0.00%      66.000us      66.000us             1  
                                             aten::relu         0.00%      72.162us         0.00%     110.692us     110.692us      62.000us         0.00%     112.000us     112.000us             1  
                                                   Relu         0.00%      38.530us         0.00%      38.530us      38.530us      50.000us         0.00%      50.000us      50.000us             1  
                                          aten::flatten         0.00%      36.141us         0.00%      95.462us      95.462us      37.000us         0.00%      92.000us      92.000us             1  
                                             aten::view         0.00%      59.321us         0.00%      59.321us      59.321us      55.000us         0.00%      55.000us      55.000us             1  
                                           aten::linear         0.00%      50.940us         0.00%     846.912us     846.912us      48.000us         0.00%     847.000us     847.000us             1  
                                                aten::t         0.00%      65.962us         0.00%     112.032us     112.032us      67.000us         0.00%     112.000us     112.000us             1  
                                        aten::transpose         0.00%      29.560us         0.00%      46.070us      46.070us      28.000us         0.00%      45.000us      45.000us             1  
                                       aten::as_strided         0.00%      16.510us         0.00%      16.510us      16.510us      17.000us         0.00%      17.000us      17.000us             1  
                                            aten::addmm         0.00%     115.732us         0.00%     683.940us     683.940us     108.000us         0.00%     687.000us     687.000us             1  
                                              aten::mul         0.00%      91.380us         0.00%     397.235us     397.235us      84.000us         0.00%     402.000us     402.000us             1  
                                    format_contiguousV2         0.00%      81.461us         0.00%     140.422us     140.422us      82.000us         0.00%     145.000us     145.000us             1  
                                               Identity         0.00%      45.250us         0.00%      45.250us      45.250us      52.000us         0.00%      52.000us      52.000us             1  
                                             aten::set_         0.00%      13.711us         0.00%      13.711us      13.711us      11.000us         0.00%      11.000us      11.000us             1  
                                               aten::to         0.00%      48.111us         0.00%     116.512us     116.512us      46.000us         0.00%     116.000us     116.000us             1  
                                    aten::empty_strided         0.00%      16.420us         0.00%      16.420us      16.420us      15.000us         0.00%      15.000us      15.000us             1  
                                            aten::copy_         0.00%      51.981us         0.00%      51.981us      51.981us      55.000us         0.00%      55.000us      55.000us             1  
                                                    Mul         0.00%      48.921us         0.00%      48.921us      48.921us      57.000us         0.00%      57.000us      57.000us             1  
                                               aten::mm         0.00%      43.051us         0.00%      90.642us      90.642us      39.000us         0.00%      93.000us      93.000us             1  
                                                 MatMul         0.00%      47.591us         0.00%      47.591us      47.591us      54.000us         0.00%      54.000us      54.000us             1  
                                              aten::add         0.00%      38.400us         0.00%      80.331us      80.331us      33.000us         0.00%      84.000us      84.000us             1  
                                                    Add         0.00%      41.931us         0.00%      41.931us      41.931us      51.000us         0.00%      51.000us      51.000us             1  
                                             aten::relu         0.00%      84.181us         0.00%     124.172us     124.172us      76.000us         0.00%     119.000us     119.000us             1  
                                                   Relu         0.00%      39.991us         0.00%      39.991us      39.991us      43.000us         0.00%      43.000us      43.000us             1  
                                           aten::linear         0.00%      48.340us         0.00%     601.718us     601.718us      50.000us         0.00%     599.000us     599.000us             1  
                                                aten::t         0.00%      59.202us         0.00%      98.972us      98.972us      51.000us         0.00%      97.000us      97.000us             1  
                                        aten::transpose         0.00%      25.790us         0.00%      39.770us      39.770us      32.000us         0.00%      46.000us      46.000us             1  
                                       aten::as_strided         0.00%      13.980us         0.00%      13.980us      13.980us      14.000us         0.00%      14.000us      14.000us             1  
                                            aten::addmm         0.00%     107.072us         0.00%     454.406us     454.406us      98.000us         0.00%     452.000us     452.000us             1  
                                              aten::mul         0.00%      56.800us         0.00%     182.502us     182.502us      55.000us         0.00%     182.000us     182.000us             1  
                                               aten::to         0.00%      41.531us         0.00%      78.741us      78.741us      39.000us         0.00%      78.000us      78.000us             1  
                                    aten::empty_strided         0.00%      14.570us         0.00%      14.570us      14.570us      14.000us         0.00%      14.000us      14.000us             1  
                                            aten::copy_         0.00%      22.640us         0.00%      22.640us      22.640us      25.000us         0.00%      25.000us      25.000us             1  
                                                    Mul         0.00%      46.961us         0.00%      46.961us      46.961us      49.000us         0.00%      49.000us      49.000us             1  
                                               aten::mm         0.00%      40.161us         0.00%      86.211us      86.211us      46.000us         0.00%      92.000us      92.000us             1  
                                                 MatMul         0.00%      46.050us         0.00%      46.050us      46.050us      46.000us         0.00%      46.000us      46.000us             1  
                                              aten::add         0.00%      35.570us         0.00%      78.621us      78.621us      31.000us         0.00%      80.000us      80.000us             1  
                                                    Add         0.00%      43.051us         0.00%      43.051us      43.051us      49.000us         0.00%      49.000us      49.000us             1  
                                          aten::softmax         0.00%      35.402us         0.00%     165.983us     165.983us      32.000us         0.00%     163.000us     163.000us             1  
                                         aten::_softmax         0.00%      87.340us         0.00%     130.581us     130.581us      83.000us         0.00%     131.000us     131.000us             1  
                                              SoftmaxV2         0.00%      43.241us         0.00%      43.241us      43.241us      48.000us         0.00%      48.000us      48.000us             1  
                                aten::is_floating_point         0.00%      18.670us         0.00%      18.670us      18.670us      16.000us         0.00%      16.000us      16.000us             1  
                                               aten::to         0.00%      64.051us         0.00%     108.732us     108.732us      56.000us         0.00%     102.000us     102.000us             1  
                                                   Cast         0.00%      44.681us         0.00%      44.681us      44.681us      46.000us         0.00%      46.000us      46.000us             1  
                                aten::is_floating_point         0.00%      17.860us         0.00%      17.860us      17.860us      16.000us         0.00%      16.000us      16.000us             1  
                                               aten::to         0.00%      59.872us         0.00%      96.332us      96.332us      55.000us         0.00%      95.000us      95.000us             1  
                                                   Cast         0.00%      36.460us         0.00%      36.460us      36.460us      40.000us         0.00%      40.000us      40.000us             1  
                                      aten::log_softmax         0.00%      38.860us         0.00%     179.312us     179.312us      34.000us         0.00%     175.000us     175.000us             1  
                                     aten::_log_softmax         0.00%      96.161us         0.00%     140.452us     140.452us      95.000us         0.00%     141.000us     141.000us             1  
                                           LogSoftmaxV2         0.00%      44.291us         0.00%      44.291us      44.291us      46.000us         0.00%      46.000us      46.000us             1  
                                         aten::nll_loss         0.00%      36.780us         0.00%     379.805us     379.805us      43.000us         0.00%     382.000us     382.000us             1  
                                 aten::nll_loss_forward         0.00%     163.702us         0.00%     343.025us     343.025us     161.000us         0.00%     339.000us     339.000us             1  
                                             aten::ones         0.00%      42.612us         0.00%      85.962us      85.962us      40.000us         0.00%      86.000us      86.000us             1  
                                               OnesLike         0.00%      43.350us         0.00%      43.350us      43.350us      46.000us         0.00%      46.000us      46.000us             1  
                                                   Cast         0.00%      39.990us         0.00%      39.990us      39.990us      41.000us         0.00%      41.000us      41.000us             1  
                                                NLLLoss         0.00%      53.371us         0.00%      53.371us      53.371us      51.000us         0.00%      51.000us      51.000us             1  
                                        aten::ones_like         0.00%      66.860us         0.00%     124.081us     124.081us      29.000us         0.00%     110.000us     110.000us             1  
                                               OnesLike         0.00%      57.221us         0.00%      57.221us      57.221us      81.000us         0.00%      81.000us      81.000us             1  
                                        NllLossBackward         0.00%       1.180ms        11.57%        2.814s        2.814s       1.097ms         0.00%        2.814s        2.814s             1  
                                aten::nll_loss_backward         0.00%     347.184us        11.56%        2.813s        2.813s     283.000us         0.00%        2.813s        2.813s             1  
                                             aten::ones         0.00%      88.121us         0.00%     235.363us     235.363us      85.000us         0.00%     230.000us     230.000us             1  
                                               OnesLike         0.00%     147.242us         0.00%     147.242us     147.242us     145.000us         0.00%     145.000us     145.000us             1  
                                               aten::to         0.00%     104.962us         0.00%     156.933us     156.933us     105.000us         0.00%     157.000us     157.000us             1  
                                                   Cast         0.00%      51.971us         0.00%      51.971us      51.971us      52.000us         0.00%      52.000us      52.000us             1  
                                            NLLLossGrad        11.56%        2.813s        11.56%        2.813s        2.813s        2.813s        11.56%        2.813s        2.813s             1  
                                     LogSoftmaxBackward         0.00%     113.962us         5.15%        1.252s        1.252s     113.000us         0.00%        1.252s        1.252s             1  
                       aten::_log_softmax_backward_data         0.00%     216.022us         5.15%        1.252s        1.252s     173.000us         0.00%        1.252s        1.252s             1  
                                         LogSoftmaxGrad         5.14%        1.252s         5.14%        1.252s        1.252s        1.252s         5.14%        1.252s        1.252s             1  
torch::autograd::CppNode<at_npu::native::NPUDtypeCas...         0.00%     199.744us         4.55%        1.107s        1.107s     162.000us         0.00%        1.107s        1.107s             1  
                                                   Cast         4.55%        1.107s         4.55%        1.107s        1.107s        1.107s         4.55%        1.107s        1.107s             1  
                                          AddmmBackward         0.00%     256.195us        12.86%        3.130s        3.130s     252.500us         0.00%        3.130s        3.130s             1  
                                                aten::t         0.00%      88.221us         0.00%     159.852us     159.852us      82.000us         0.00%     160.000us     160.000us             1  
                                        aten::transpose         0.00%      42.871us         0.00%      71.631us      71.631us      50.000us         0.00%      78.000us      78.000us             1  
                                       aten::as_strided         0.00%      28.760us         0.00%      28.760us      28.760us      28.000us         0.00%      28.000us      28.000us             1  
                                             aten::conj         0.00%      23.690us         0.00%      23.690us      23.690us      24.000us         0.00%      24.000us      24.000us             1  
                                               aten::mm         0.00%     225.013us         6.71%        1.632s        1.632s     184.000us         0.00%        1.632s        1.632s             1  
                                                 MatMul         6.71%        1.632s         6.71%        1.632s        1.632s        1.632s         6.71%        1.632s        1.632s             1  
                                                aten::t         0.00%     104.941us         0.00%     183.072us     183.072us      97.000us         0.00%     178.000us     178.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 24.329s
Self NPU time total: 24.330s

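Because combine_grad=True is used, multiple add, mul, and similar operations are merged into a single operation. The instructor also explains this at around the 16-minute mark of this video: https://www.bilibili.com/video/BV1AG4y1o73U/?spm_id_from=333.788&vd_source=58d010759cc2b1d5bc7753dd8aad0710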
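
A printed profiler table like the one above is long, so ranking operators by "Self NPU" time makes hot spots easier to see. The sketch below parses the text table directly. It assumes the 11-column layout shown in this post, with columns separated by two or more spaces and times suffixed with us, ms, or s. The function names are hypothetical.

```python
import re
from collections import defaultdict

_UNITS_TO_US = {"us": 1.0, "ms": 1e3, "s": 1e6}


def to_us(text):
    """Convert a value such as '26.680us', '1.180ms' or '2.813s' to microseconds."""
    m = re.fullmatch(r"([\d.]+)(us|ms|s)", text)
    if m is None:
        raise ValueError(f"unrecognized time value: {text!r}")
    return float(m.group(1)) * _UNITS_TO_US[m.group(2)]


def self_npu_by_op(table_text, column=6):
    """Sum the 'Self NPU' column (index 6) per operator name, largest first."""
    totals = defaultdict(float)
    for line in table_text.splitlines():
        fields = re.split(r"\s{2,}", line.strip())
        # Data rows have 11 fields and end with an integer call count;
        # header, separator and summary lines are skipped.
        if len(fields) != 11 or not fields[-1].isdigit():
            continue
        value = to_us(fields[column])
        totals[fields[0]] += value
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
```

Running it on the table before and after an optimization gives two ranked lists that can be compared operator by operator.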

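5. Use a fused optimizer and tensor fusion to accelerate model training.

Use the PyTorch Profiling tool to profile the model after fusion optimization, compare it against the original profile, identify where performance improved, and explain the findings in writing. [20 points]

Reference: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/600alpha003/developmenttools/devtool/atlasprofiling_16_0090.html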

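  • Fused optimizer:
    Replace the optimizer with its NPU-affinity counterpart.
    The adapted APEX adds Ascend AI processor affinity optimizations for adadelta/adam/sgd/lamb. The resulting NPU fused optimizers behave the same as the native algorithms but run faster. To use one, replace the original optimizer with apex.optimizers.***, where *** is the optimizer name, for example NpuFusedSGD.
  • Tensor fusion:
    With tensor fusion enabled, the output includes an extra group num entry.
    The profile printed after fusion shows a clear drop in time, from more than 20 seconds to just over 1 second: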
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self NPU    Self NPU %     NPU total  NPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::is_floating_point         0.00%      24.291us         0.00%      24.291us      24.291us      18.000us         0.00%      18.000us      18.000us             1  
                                               aten::to         0.01%     195.962us         0.04%     592.988us     592.988us     188.000us         0.01%     586.000us     586.000us             1  
                                                   Cast         0.03%     397.026us         0.03%     397.026us     397.026us     398.000us         0.03%     398.000us     398.000us             1  
                                           aten::conv2d         0.00%      37.151us         0.05%     692.320us     692.320us      38.000us         0.00%     692.000us     692.000us             1  
                                      aten::convolution         0.00%      48.190us         0.04%     655.169us     655.169us      39.000us         0.00%     654.000us     654.000us             1  
                                     aten::_convolution         0.01%     212.093us         0.04%     606.979us     606.979us     213.000us         0.01%     615.000us     615.000us             1  
                                                 Conv2D         0.03%     394.886us         0.03%     394.886us     394.886us     402.000us         0.03%     402.000us     402.000us             1  
                                             aten::relu         0.01%     102.352us         0.02%     355.215us     355.215us      89.000us         0.01%     350.000us     350.000us             1  
                                                   Relu         0.02%     252.863us         0.02%     252.863us     252.863us     261.000us         0.02%     261.000us     261.000us             1  
                                       aten::avg_pool2d         0.01%     106.632us         0.03%     389.186us     389.186us      98.000us         0.01%     390.000us     390.000us             1  
                                              AvgPoolV2         0.02%     282.554us         0.02%     282.554us     282.554us     292.000us         0.02%     292.000us     292.000us             1  
                                           aten::conv2d         0.00%      24.911us         0.03%     418.326us     418.326us      24.000us         0.00%     416.000us     416.000us             1  
                                      aten::convolution         0.00%      21.710us         0.03%     393.415us     393.415us      19.000us         0.00%     392.000us     392.000us             1  
                                     aten::_convolution         0.01%     127.022us         0.02%     371.705us     371.705us     121.000us         0.01%     373.000us     373.000us             1  
                                                 Conv2D         0.02%     244.683us         0.02%     244.683us     244.683us     252.000us         0.02%     252.000us     252.000us             1  
                                             aten::relu         0.01%      84.621us         0.01%     220.423us     220.423us      80.000us         0.01%     221.000us     221.000us             1  
                                                   Relu         0.01%     135.802us         0.01%     135.802us     135.802us     141.000us         0.01%     141.000us     141.000us             1  
                                       aten::avg_pool2d         0.01%     101.191us         0.02%     311.174us     311.174us      99.000us         0.01%     317.000us     317.000us             1  
                                              AvgPoolV2         0.01%     209.983us         0.01%     209.983us     209.983us     218.000us         0.01%     218.000us     218.000us             1  
                                           aten::conv2d         0.00%      26.080us         0.03%     458.256us     458.256us      23.000us         0.00%     455.000us     455.000us             1  
                                      aten::convolution         0.00%      23.870us         0.03%     432.176us     432.176us      20.000us         0.00%     432.000us     432.000us             1  
                                     aten::_convolution         0.01%     137.002us         0.03%     408.306us     408.306us     132.000us         0.01%     412.000us     412.000us             1  
                                                 Conv2D         0.02%     271.304us         0.02%     271.304us     271.304us     280.000us         0.02%     280.000us     280.000us             1  
                                             aten::relu         0.01%      84.741us         0.02%     227.763us     227.763us      81.000us         0.01%     225.000us     225.000us             1  
                                                   Relu         0.01%     143.022us         0.01%     143.022us     143.022us     144.000us         0.01%     144.000us     144.000us             1  
                                          aten::flatten         0.00%      37.760us         0.01%      97.661us      97.661us      40.000us         0.00%      95.000us      95.000us             1  
                                             aten::view         0.00%      59.901us         0.00%      59.901us      59.901us      55.000us         0.00%      55.000us      55.000us             1  
                                           aten::linear         0.00%      66.613us         0.10%       1.448ms       1.448ms      58.000us         0.00%       1.436ms       1.436ms             1  
                                                aten::t         0.00%      61.540us         0.01%     111.371us     111.371us      60.000us         0.00%     111.000us     111.000us             1  
                                        aten::transpose         0.00%      30.560us         0.00%      49.831us      49.831us      30.000us         0.00%      51.000us      51.000us             1  
                                       aten::as_strided         0.00%      19.271us         0.00%      19.271us      19.271us      21.000us         0.00%      21.000us      21.000us             1  
                                            aten::addmm         0.01%     123.260us         0.08%       1.270ms       1.270ms     103.000us         0.01%       1.267ms       1.267ms             1  
                                              aten::mul         0.01%     100.721us         0.05%     691.340us     691.340us      96.000us         0.01%     694.000us     694.000us             1  
                                    format_contiguousV2         0.01%      95.401us         0.02%     274.974us     274.974us      95.000us         0.01%     275.000us     275.000us             1  
                                               Identity         0.01%     163.383us         0.01%     163.383us     163.383us     169.000us         0.01%     169.000us     169.000us             1  
                                             aten::set_         0.00%      16.190us         0.00%      16.190us      16.190us      11.000us         0.00%      11.000us      11.000us             1  
                                               aten::to         0.00%      64.660us         0.01%     129.292us     129.292us      61.000us         0.00%     125.000us     125.000us             1  
                                    aten::empty_strided         0.00%      18.001us         0.00%      18.001us      18.001us      16.000us         0.00%      16.000us      16.000us             1  
                                            aten::copy_         0.00%      46.631us         0.00%      46.631us      46.631us      48.000us         0.00%      48.000us      48.000us             1  
                                                    Mul         0.01%     186.353us         0.01%     186.353us     186.353us     198.000us         0.01%     198.000us     198.000us             1  
                                               aten::mm         0.00%      53.251us         0.02%     255.274us     255.274us      53.000us         0.00%     255.000us     255.000us             1  
                                                 MatMul         0.01%     202.023us         0.01%     202.023us     202.023us     202.000us         0.01%     202.000us     202.000us             1  
                                              aten::add         0.00%      49.091us         0.01%     200.273us     200.273us      36.000us         0.00%     215.000us     215.000us             1  
                                                    Add         0.01%     151.182us         0.01%     151.182us     151.182us     179.000us         0.01%     179.000us     179.000us             1  
                                             aten::relu         0.01%      91.631us         0.01%     209.563us     209.563us      90.000us         0.01%     209.000us     209.000us             1  
                                                   Relu         0.01%     117.932us         0.01%     117.932us     117.932us     119.000us         0.01%     119.000us     119.000us             1  
                                           aten::linear         0.00%      47.990us         0.07%       1.087ms       1.087ms      37.000us         0.00%       1.082ms       1.082ms             1  
                                                aten::t         0.00%      62.412us         0.01%     104.772us     104.772us      65.000us         0.00%     106.000us     106.000us             1  
                                        aten::transpose         0.00%      26.400us         0.00%      42.360us      42.360us      25.000us         0.00%      41.000us      41.000us             1  
                                       aten::as_strided         0.00%      15.960us         0.00%      15.960us      15.960us      16.000us         0.00%      16.000us      16.000us             1  
                                            aten::addmm         0.01%     116.582us         0.06%     934.083us     934.083us     128.000us         0.01%     939.000us     939.000us             1  
                                              aten::mul         0.00%      66.741us         0.02%     345.084us     345.084us      65.000us         0.00%     346.000us     346.000us             1  
                                               aten::to         0.00%      46.010us         0.01%      84.231us      84.231us      41.000us         0.00%      83.000us      83.000us             1  
                                    aten::empty_strided         0.00%      14.251us         0.00%      14.251us      14.251us      15.000us         0.00%      15.000us      15.000us             1  
                                            aten::copy_         0.00%      23.970us         0.00%      23.970us      23.970us      27.000us         0.00%      27.000us      27.000us             1  
                                                    Mul         0.01%     194.112us         0.01%     194.112us     194.112us     198.000us         0.01%     198.000us     198.000us             1  
                                               aten::mm         0.00%      50.711us         0.02%     289.744us     289.744us      43.000us         0.00%     288.000us     288.000us             1  
                                                 MatMul         0.02%     239.033us         0.02%     239.033us     239.033us     245.000us         0.02%     245.000us     245.000us             1  
                                              aten::add         0.00%      43.991us         0.01%     182.673us     182.673us      31.000us         0.00%     177.000us     177.000us             1  
                                                    Add         0.01%     138.682us         0.01%     138.682us     138.682us     146.000us         0.01%     146.000us     146.000us             1  
                                          aten::softmax         0.00%      36.540us         0.02%     268.704us     268.704us      33.000us         0.00%     267.000us     267.000us             1  
                                         aten::_softmax         0.01%     101.522us         0.02%     232.164us     232.164us     103.000us         0.01%     234.000us     234.000us             1  
                                              SoftmaxV2         0.01%     130.642us         0.01%     130.642us     130.642us     131.000us         0.01%     131.000us     131.000us             1  
                                aten::is_floating_point         0.00%      19.770us         0.00%      19.770us      19.770us      20.000us         0.00%      20.000us      20.000us             1  
                                               aten::to         0.00%      73.920us         0.02%     235.703us     235.703us      83.000us         0.01%     238.000us     238.000us             1  
                                                   Cast         0.01%     161.783us         0.01%     161.783us     161.783us     155.000us         0.01%     155.000us     155.000us             1  
                                aten::is_floating_point         0.00%      17.360us         0.00%      17.360us      17.360us      14.000us         0.00%      14.000us      14.000us             1  
                                               aten::to         0.00%      66.411us         0.01%     110.642us     110.642us      71.000us         0.00%     114.000us     114.000us             1  
                                                   Cast         0.00%      44.231us         0.00%      44.231us      44.231us      43.000us         0.00%      43.000us      43.000us             1  
                                      aten::log_softmax         0.00%      43.441us         0.02%     261.164us     261.164us      35.000us         0.00%     255.000us     255.000us             1  
                                     aten::_log_softmax         0.01%     102.071us         0.01%     217.723us     217.723us     102.000us         0.01%     220.000us     220.000us             1  
                                           LogSoftmaxV2         0.01%     115.652us         0.01%     115.652us     115.652us     118.000us         0.01%     118.000us     118.000us             1  
                                         aten::nll_loss         0.00%      40.341us         0.05%     715.120us     715.120us      45.000us         0.00%     714.000us     714.000us             1  
                                 aten::nll_loss_forward         0.01%     187.422us         0.04%     674.779us     674.779us     177.000us         0.01%     669.000us     669.000us             1  
                                             aten::ones         0.00%      45.000us         0.01%     117.291us     117.291us      49.000us         0.00%     122.000us     122.000us             1  
                                               OnesLike         0.00%      72.291us         0.00%      72.291us      72.291us      73.000us         0.00%      73.000us      73.000us             1  
                                                   Cast         0.01%     141.912us         0.01%     141.912us     141.912us     142.000us         0.01%     142.000us     142.000us             1  
                                                NLLLoss         0.02%     228.154us         0.02%     228.154us     228.154us     228.000us         0.02%     228.000us     228.000us             1  
                                        aten::ones_like         0.00%      62.451us         0.01%     132.672us     132.672us      45.000us         0.00%     126.000us     126.000us             1  
                                               OnesLike         0.00%      70.221us         0.00%      70.221us      70.221us      81.000us         0.01%      81.000us      81.000us             1  
                                        NllLossBackward         0.11%       1.601ms         4.43%      66.609ms      66.609ms       1.575ms         0.10%      66.568ms      66.568ms             1  
                                aten::nll_loss_backward         0.02%     308.166us         4.32%      65.007ms      65.007ms     302.000us         0.02%      64.993ms      64.993ms             1  
                                             aten::ones         0.01%      76.751us         0.01%     168.812us     168.812us     103.000us         0.01%     167.000us     167.000us             1  
                                               OnesLike         0.01%      92.061us         0.01%      92.061us      92.061us      64.000us         0.00%      64.000us      64.000us             1  
                                               aten::to         0.01%     106.041us         0.01%     183.412us     183.412us      77.000us         0.01%     150.000us     150.000us             1  
                                                   Cast         0.01%      77.371us         0.01%      77.371us      77.371us      73.000us         0.00%      73.000us      73.000us             1  
                                            NLLLossGrad         4.28%      64.347ms         4.28%      64.347ms      64.347ms      64.374ms         4.28%      64.374ms      64.374ms             1  
                                     LogSoftmaxBackward         0.01%      96.322us         2.92%      44.012ms      44.012ms      89.000us         0.01%      44.003ms      44.003ms             1  
                       aten::_log_softmax_backward_data         0.01%     158.882us         2.92%      43.916ms      43.916ms     127.000us         0.01%      43.914ms      43.914ms             1  
                                         LogSoftmaxGrad         2.91%      43.757ms         2.91%      43.757ms      43.757ms      43.787ms         2.91%      43.787ms      43.787ms             1  
torch::autograd::CppNode<at_npu::native::NPUDtypeCas...         0.01%     139.882us         3.85%      57.936ms      57.936ms     111.000us         0.01%      57.939ms      57.939ms             1  
                                                   Cast         3.84%      57.796ms         3.84%      57.796ms      57.796ms      57.828ms         3.84%      57.828ms      57.828ms             1  
                                          AddmmBackward         0.02%     241.843us        11.70%     176.013ms     176.013ms     229.000us         0.02%     176.007ms     176.007ms             1  
                                                aten::t         0.01%      86.211us         0.01%     152.042us     152.042us      86.000us         0.01%     152.000us     152.000us             1  
                                        aten::transpose         0.00%      39.840us         0.00%      65.831us      65.831us      40.000us         0.00%      66.000us      66.000us             1  
                                       aten::as_strided         0.00%      25.991us         0.00%      25.991us      25.991us      26.000us         0.00%      26.000us      26.000us             1  
                                             aten::conj         0.00%      24.740us         0.00%      24.740us      24.740us      23.000us         0.00%      23.000us      23.000us             1  
                                               aten::mm         0.01%     209.302us         5.54%      83.362ms      83.362ms     158.000us         0.01%      83.367ms      83.367ms             1  
                                                 MatMul         5.53%      83.153ms         5.53%      83.153ms      83.153ms      83.209ms         5.53%      83.209ms      83.209ms             1  
                                                aten::t         0.01%     110.472us         0.01%     196.483us     196.483us     108.000us         0.01%     194.000us     194.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.505s
Self NPU time total: 1.505s
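
With a table this long it is hard to spot the expensive operators by eye. The following is a minimal sketch, not part of the original course material, that parses the text printed by the profiler and ranks operators by their accumulated Self NPU time. The helper names `to_us` and `top_ops` are made up for this illustration, and `SAMPLE` is only a handful of rows copied from the table above (whitespace collapsed), so the ranking covers those rows only.

```python
import re

# Conversion factors from the units used in the table to microseconds.
UNIT_TO_US = {"us": 1.0, "ms": 1e3, "s": 1e6}


def to_us(text):
    """Convert a duration such as '397.026us' or '1.448ms' to microseconds."""
    match = re.fullmatch(r"([\d.]+)(us|ms|s)", text)
    if match is None:
        raise ValueError(f"unrecognized duration: {text!r}")
    return float(match.group(1)) * UNIT_TO_US[match.group(2)]


def top_ops(table_text, n=3):
    """Return the n operators with the largest summed Self NPU time (in us)."""
    totals = {}
    for line in table_text.splitlines():
        cols = line.split()
        # A data row has 11 whitespace-separated columns and a percentage in
        # column 2; rulers, the header row and the summary lines do not.
        if len(cols) != 11 or not cols[1].endswith("%"):
            continue
        name, self_npu = cols[0], cols[6]  # column 7 is "Self NPU"
        totals[name] = totals.get(name, 0.0) + to_us(self_npu)
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)[:n]


# A few rows excerpted from the table above, plus lines that must be skipped.
SAMPLE = """
-----------  ------------  ------------  ------------
Cast 0.03% 397.026us 0.03% 397.026us 397.026us 398.000us 0.03% 398.000us 398.000us 1
aten::relu 0.01% 102.352us 0.02% 355.215us 355.215us 89.000us 0.01% 350.000us 350.000us 1
NLLLossGrad 4.28% 64.347ms 4.28% 64.347ms 64.347ms 64.374ms 4.28% 64.374ms 64.374ms 1
Cast 3.84% 57.796ms 3.84% 57.796ms 57.796ms 57.828ms 3.84% 57.828ms 57.828ms 1
MatMul 5.53% 83.153ms 5.53% 83.153ms 83.153ms 83.209ms 5.53% 83.209ms 83.209ms 1
Self NPU time total: 1.505s
"""

for name, self_npu_us in top_ops(SAMPLE):
    print(f"{name:<12} {self_npu_us / 1000:.3f} ms")
```

Running the same function on the profiler output saved before and after a change gives a quick list of which operators to compare first.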

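Comparing the profiles shows that the time spent in the most expensive operators drops markedly. The tuning approach is to compare the per-operator cost, replace operators with semantically equivalent alternatives, and choose fused operators and fused optimizers, which reduces the overall time.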
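
The fused optimizer swap described earlier (replacing the original optimizer with apex.optimizers.NpuFusedSGD) might look roughly like the sketch below. Several things here are assumptions rather than facts taken from the course: the toy `nn.Linear` model and random data stand in for LeNet and MNIST, the `amp.initialize` arguments (`opt_level`, `loss_scale`, and `combine_grad=True` for tensor fusion) reflect my understanding of the Ascend-adapted APEX and should be checked against the installed version, and the CPU fallback exists only so the script also runs on a machine without an NPU.

```python
import torch
import torch.nn as nn

try:
    import torch_npu  # noqa: F401  (Ascend adapter; registers the "npu" device)
    from apex import amp
    from apex.optimizers import NpuFusedSGD  # fused optimizer named in the text
    HAVE_NPU = torch.npu.is_available()
except (ImportError, AttributeError):
    HAVE_NPU = False  # no Ascend stack found: fall back to plain PyTorch on CPU

device = torch.device("npu:0" if HAVE_NPU else "cpu")

torch.manual_seed(0)
model = nn.Linear(4, 2).to(device)    # hypothetical toy model, not LeNet
x = torch.randn(8, 4, device=device)  # hypothetical random inputs
y = torch.randn(8, 2, device=device)  # hypothetical random targets

if HAVE_NPU:
    # The only change to the optimizer line: torch.optim.SGD -> NpuFusedSGD.
    optimizer = NpuFusedSGD(model.parameters(), lr=0.1)
    # Assumption: the fused optimizer is used together with amp, and
    # combine_grad=True turns on tensor fusion. The values are illustrative.
    model, optimizer = amp.initialize(
        model, optimizer, opt_level="O2", loss_scale=128.0, combine_grad=True
    )
else:
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def train_step():
    """Run one optimization step and return the loss as a Python float."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    if HAVE_NPU:
        # amp scales the loss before backward to protect small gradients.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()
    return loss.item()


losses = [train_step() for _ in range(5)]
print(f"device={device}, first loss={losses[0]:.4f}, last loss={losses[-1]:.4f}")
```

The rest of the training loop stays the same, which is why the swap is cheap to try: profile once with the native optimizer, once with the fused one, and compare.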
