
<div id="article_content" class="article_content clearfix">
        <link rel="stylesheet" href="https://csdnimg.cn/release/blogv2/dist/mdeditor/css/editerView/ck_htmledit_views-1a85854398.css">
                <div id="content_views" class="markdown_views prism-atom-one-dark">
                    <svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
                        <path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path>
                    </svg>
                    <blockquote>
 <p>A key step in deploying a PyTorch model is converting the saved <code>.pth</code> model to ONNX. This post records how to do it.</p>
</blockquote>
<h3><a name="t0"></a><a id="_onnx_2"></a>Export to ONNX</h3>
<ul><li>Build your PyTorch model and load its weights</li></ul>
<pre class="prettyprint"><code class="prism language-python has-numbering" οnclick="mdcp.copyCode(event)" style="position: unset;">import torch
model = create_model(num_classes=2)  # create_model: the author's own model constructor
model.load_state_dict(torch.load(model_path, map_location='cpu')["model"])
model.eval()  # switch to inference mode before exporting
</code></pre>
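<p>The loading line above indexes the checkpoint with <code>["model"]</code>, so it assumes the <code>.pth</code> file holds a dict with the <code>state_dict</code> stored under that key. A minimal round-trip sketch of that checkpoint format, assuming a hypothetical <code>create_model</code> stand-in (a single <code>nn.Linear</code> layer) and an illustrative file name <code>checkpoint.pth</code>:</p>
<pre class="prettyprint"><code class="prism language-python">import torch
from torch import nn


def create_model(num_classes):
    # hypothetical stand-in for the author's create_model(); any nn.Module behaves the same way
    return nn.Linear(4, num_classes)


# saving: store the state_dict under the "model" key, which is what the loading code indexes
model = create_model(num_classes=2)
torch.save({"model": model.state_dict()}, "checkpoint.pth")

# loading: rebuild the same architecture first, then copy the saved weights into it
restored = create_model(num_classes=2)
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
restored.load_state_dict(checkpoint["model"])
restored.eval()
</code></pre>
<p>If a checkpoint was saved as a bare <code>state_dict</code> instead, the <code>["model"]</code> indexing has to be dropped; by default <code>load_state_dict</code> raises an error on missing or unexpected keys, which usually points to such a mismatch.</p>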
<ul><li>Export the ONNX file</li></ul>
<pre class="prettyprint"><code class="prism language-python has-numbering" οnclick="mdcp.copyCode(event)" style="position: unset;">dummy_input = torch.randn(1, 3, 256, 256, device='cpu')  # example input: batch of 1, 3 channels, 256x256
# run the model once on dummy_input (tracing) and write the graph to faster_rcnn.onnx with opset 11
torch.onnx.export(model, dummy_input, "faster_rcnn.onnx", verbose=True, opset_version=11)
</code></pre>
<blockquote>
 <p>This saves the model to <code>faster_rcnn.onnx</code> in the current directory.</p>
</blockquote>
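<p>The export call above fixes the input shape to 1x3x256x256 and leaves the input and output names to the exporter. <code>torch.onnx.export</code> also accepts <code>input_names</code>, <code>output_names</code> and <code>dynamic_axes</code>, and the separate <code>onnx</code> package (installed with <code>pip install onnx</code>) can validate the written file. A sketch under these assumptions: a small hypothetical stand-in network rather than Faster R-CNN, an illustrative file name <code>demo_export.onnx</code>, illustrative names <code>images</code>/<code>features</code>, and <code>opset_version</code> left at the installed PyTorch default:</p>
<pre class="prettyprint"><code class="prism language-python">import onnx
import torch
from torch import nn

# hypothetical stand-in network so the sketch runs without the Faster R-CNN weights
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()).eval()
dummy_input = torch.randn(1, 3, 256, 256)

torch.onnx.export(
    model,
    dummy_input,
    "demo_export.onnx",
    input_names=["images"],      # readable input name, later usable as the key in sess.run
    output_names=["features"],
    # mark dimension 0 as dynamic so the batch size is not frozen to 1 by the trace
    dynamic_axes={"images": {0: "batch"}, "features": {0: "batch"}},
)

# reload the file and run ONNX's structural validation on it
onnx_model = onnx.load("demo_export.onnx")
onnx.checker.check_model(onnx_model)
print([inp.name for inp in onnx_model.graph.input])
</code></pre>
<p>Without <code>dynamic_axes</code>, ONNX Runtime will typically reject inputs whose shape differs from the dummy input used for the export.</p>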
<h4><a id="_onnx__20"></a>Verifying the ONNX model</h4>
<ul><li>Install the <code>onnxruntime</code> library</li></ul>
<pre class="prettyprint"><code class="has-numbering" οnclick="mdcp.copyCode(event)" style="position: unset;">pip install onnxruntime
</code></pre>
<ul><li>Load the ONNX model and run a test</li></ul>
<pre class="prettyprint"><code class="prism language-python has-numbering" οnclick="mdcp.copyCode(event)" style="position: unset;">import onnxruntime
import torch


def to_numpy(tensor):
    # detach first if the tensor tracks gradients, then move it to the CPU and convert
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


# test input with the same shape that was used for the export
dummy_input = torch.randn(1, 3, 256, 256, device='cpu')

# the file written by the export step
onnx_path = "faster_rcnn.onnx"
# netron.start(onnx_path)  # optionally visualize the network with the netron Python package
sess = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# ONNX network output; the input name is read from the session instead of being hard-coded
input_name = sess.get_inputs()[0].name
onnx_out = sess.run(None, {input_name: to_numpy(dummy_input)})
print(onnx_out)

model.eval()
with torch.no_grad():
    # PyTorch model output for the same input
    torch_out = model(dummy_input)
    print(torch_out)
</code></pre>
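<p>Printing both outputs only allows a visual comparison. A sketch of an automatic element-wise check with <code>numpy.testing.assert_allclose</code>, assuming a small hypothetical stand-in network, an illustrative file name <code>compare_demo.onnx</code>, an input named <code>images</code>, and tolerances of <code>rtol=1e-3</code>, <code>atol=1e-5</code> chosen for float32 output:</p>
<pre class="prettyprint"><code class="prism language-python">import numpy as np
import onnxruntime
import torch
from torch import nn

# hypothetical stand-in network; swap in your own model and exported file
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()).eval()
dummy_input = torch.randn(1, 3, 64, 64)
torch.onnx.export(model, dummy_input, "compare_demo.onnx", input_names=["images"])

sess = onnxruntime.InferenceSession("compare_demo.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"images": dummy_input.numpy()})

with torch.no_grad():
    torch_out = model(dummy_input)

# element-wise check instead of eyeballing printed values; raises AssertionError on mismatch
np.testing.assert_allclose(torch_out.numpy(), onnx_out[0], rtol=1e-3, atol=1e-5)
print("ONNX Runtime output matches PyTorch within tolerance")
</code></pre>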
<ul><li>Output:</li></ul>
<pre class="prettyprint"><code class="has-numbering" οnclick="mdcp.copyCode(event)" style="position: unset;">onnx_out
[array([[  0.       ,  93.246    , 228.95842  , 256.       ],
       [  0.       ,   2.6370468, 209.39705  , 148.17822  ]],
      dtype=float32), array([1, 1], dtype=int64), array([0.1501071 , 0.07568519], dtype=float32)]

torch_out
[{'boxes': tensor([[  0.0000,  93.2459, 228.9584, 256.0000],
        [  0.0000,   2.6370, 209.3971, 148.1782]]), 'labels': tensor([1, 1]), 'scores': tensor([0.1501, 0.0757])}]

</code></pre>
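<p>The two printouts agree numerically but are structured differently: ONNX Runtime returns a flat list of arrays, while the PyTorch model returns a list with one dict per image. A sketch that pairs the ONNX outputs with the dict keys and compares each field; the order <code>boxes, labels, scores</code> is taken from the printout above and is an assumption for other models, and both helper names are hypothetical:</p>
<pre class="prettyprint"><code class="prism language-python">import numpy as np

# output order as printed above for this model: boxes, labels, scores
OUTPUT_KEYS = ("boxes", "labels", "scores")


def onnx_to_detection_dict(onnx_out, keys=OUTPUT_KEYS):
    """Pair ONNX Runtime's flat list of arrays with the key names of the PyTorch output dict."""
    if len(onnx_out) != len(keys):
        raise ValueError(f"expected {len(keys)} outputs, got {len(onnx_out)}")
    return dict(zip(keys, onnx_out))


def compare_detections(onnx_dict, torch_dict, rtol=1e-3, atol=1e-4):
    """Assert that every field matches; PyTorch tensors are converted with .numpy() first."""
    for key, onnx_value in onnx_dict.items():
        torch_value = torch_dict[key]
        if hasattr(torch_value, "numpy"):
            torch_value = torch_value.numpy()
        np.testing.assert_allclose(onnx_value, torch_value, rtol=rtol, atol=atol)
</code></pre>
<p>With <code>onnx_out</code> and <code>torch_out</code> from the verification code, the call would be <code>compare_detections(onnx_to_detection_dict(onnx_out), torch_out[0])</code>. Note that <code>assert_allclose</code> also fails on a shape mismatch, for example if the two runtimes keep a different number of boxes.</p>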
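<h4><a id="_69"></a>Finding your network's input layer name</h4>
<ul><li>If you are not familiar with a network, you may not know the name of its input layer. Visualize the network with Netron to find the input layer name, then use that name as the key of the input dictionary passed to the ONNX Runtime session (<code>sess.run</code>).</li></ul>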
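<p>Besides Netron, the input names can also be read programmatically. ONNX Runtime exposes them through <code>sess.get_inputs()</code>, whose items have a <code>name</code> attribute. A sketch using the <code>onnx</code> package instead (assumed installed via <code>pip install onnx</code>); <code>list_graph_inputs</code> is a hypothetical helper, and the one-node demo graph with an input named <code>images</code> exists only to exercise it:</p>
<pre class="prettyprint"><code class="prism language-python">import onnx
from onnx import TensorProto, helper


def list_graph_inputs(onnx_path):
    """Return the names of the real graph inputs, skipping weights (initializers)."""
    model = onnx.load(onnx_path)
    # some exporters also list weights as graph inputs; those are not fed via sess.run
    initializer_names = {init.name for init in model.graph.initializer}
    return [inp.name for inp in model.graph.input if inp.name not in initializer_names]


# tiny demo graph: out = images + bias, where bias is a weight that is also listed as an input
images = helper.make_tensor_value_info("images", TensorProto.FLOAT, [1, 3, 4, 4])
bias_input = helper.make_tensor_value_info("bias", TensorProto.FLOAT, [1])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 3, 4, 4])
bias = helper.make_tensor("bias", TensorProto.FLOAT, [1], [0.5])
node = helper.make_node("Add", ["images", "bias"], ["out"])
graph = helper.make_graph([node], "demo", [images, bias_input], [out], initializer=[bias])
onnx.save(helper.make_model(graph), "demo_inputs.onnx")

print(list_graph_inputs("demo_inputs.onnx"))  # ['images']
</code></pre>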
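<h2><a name="t1"></a><a id="font_colorred_size6_face_font_73"></a><font size="6" face="微软雅黑" color="red">Note!!!</font></h2>
<ul><li> <p>When a PyTorch model is converted to ONNX, the exporter is trace-based: it runs the model once and exports only the operators that actually took part in that run. This also means that if your model is dynamic, for example if some operations change depending on the input data, the exported result will not be accurate. Likewise, a trace may only be valid for one specific input size (this is one of the reasons an explicit input is required for tracing). We recommend inspecting the model trace to make sure the traced operators are reasonable. (Source: <a href="https://pytorch.apachecn.org/docs/0.3/onnx.html">PyTorch documentation</a>)</p> </li><li> <p>In other words, if a module contains an if/else-style branch, the export follows whichever branch the example input happens to take. The generated ONNX model keeps only that branch, and the logic of the other branches in the original PyTorch model is lost.</p> </li></ul>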
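<p>The branch-freezing effect can be reproduced with a toy module. This sketch uses <code>torch.jit.trace</code> rather than <code>torch.onnx.export</code> so that it runs without onnxruntime; the TorchScript-based ONNX exporter records the model through the same tracing mechanism, so an exported graph would keep the same single branch. <code>BranchyNet</code> is a hypothetical example:</p>
<pre class="prettyprint"><code class="prism language-python">import torch
from torch import nn


class BranchyNet(nn.Module):
    """Hypothetical toy module with a data-dependent if/else branch."""

    def forward(self, x):
        if x.sum().gt(0):  # which branch runs depends on the input values
            return x * 2
        return x + 10


model = BranchyNet().eval()
positive = torch.ones(2)
negative = -torch.ones(2)

# tracing records only the branch taken for the example input (x * 2 here);
# PyTorch emits a TracerWarning because a tensor is converted to a Python bool
traced = torch.jit.trace(model, positive)

print(model(negative))   # eager model takes the else-branch: tensor([9., 9.])
print(traced(negative))  # traced graph still multiplies by 2: tensor([-2., -2.])
</code></pre>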
<h3><a name="t2"></a><a id="_79"></a>参考资料</h3>
<ul><li>https://pytorch.apachecn.org/docs/0.3/onnx.html</li><li>https://www.jianshu.com/p/5a0a09fbdeba</li></ul>
                </div>
                <link href="https://csdnimg.cn/release/blogv2/dist/mdeditor/css/editerView/markdown_views-d7a94ec6ab.css" rel="stylesheet">
                <link href="https://csdnimg.cn/release/blogv2/dist/mdeditor/css/style-49037e4d27.css" rel="stylesheet">
        </div>
