<div id="article_content" class="article_content csdn-tracking-statistics" data-pid="blog" data-mod="popu_307" data-dsm="post">
<div class="markdown_views" deep="7">
<h3 id="本节对pytorch的参考资料以及相关内容进行总结"><a name="t0" target="_blank"></a>本节对pytorch的参考资料以及相关内容进行总结</h3>
<h4 id="参考资料1-pytorch-tutorial"><a name="t1" target="_blank"></a><a href="https://github.com/yunjey/pytorch-tutorial" target="_blank">参考资料1 pytorch tutorial</a></h4>
<h4 id="参考资料2-莫烦tutorial"><a name="t2" target="_blank"></a><a href="https://morvanzhou.github.io/tutorials/machine-learning/torch/" target="_blank">参考资料2 莫烦tutorial</a></h4>
<h4 id="参考资料3-pytorch官方文档"><a name="t3" target="_blank"></a><a href="http://pytorch.org/docs/0.2.0/notes/autograd.html" target="_blank">参考资料3 pytorch官方文档</a></h4>
<h4 id="参考资料4-pytorch中文文档"><a name="t4" target="_blank"></a><a href="https://pytorch-cn.readthedocs.io/zh/latest/#pytorch" target="_blank">参考资料4 pytorch中文文档</a></h4>
<h3 id="安装"><a name="t5" target="_blank"></a>安装</h3>
<h4 id="参考传送门我装的是cuda75的版本注意不支持window下的安装"><a name="t6" target="_blank"></a>参考<a href="http://blog.csdn.net/zeroqiaoba/article/details/75192025" target="_blank">传送门</a>,我装的是cuda7.5的版本(注意:不支持window下的安装)</h4>
<h3 id="注意要点"><a name="t7" target="_blank"></a>注意要点</h3>
<ul>
<li><p>如果你要保存map,请务必保存成numpy格式的数据,因为torch格式的数据非常占用内存资源,这是一个很神奇的现象。不然会内存炸裂的。 <br data-filtered="filtered">
<img src="https://img-blog.csdn.net/20171125195907119?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvemVyb1FpYW9iYQ==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast" alt="这里写图片描述" title=""></p></li>
<li><p>pytorch解除占用时,请用</p></li>
</ul>
<pre class="prettyprint" name="code"><code class="hljs lasso has-numbering">fuser <span class="hljs-attribute">-v</span> /dev/nvidia<span class="hljs-subst">*</span> 查看所有进程 </code><ul class="pre-numbering"><li>1</li></ul></pre>
<ul>
<li>pytorch保存模型时,请先转到cpu下,否则会out of memory</li>
</ul>
<pre class="prettyprint" name="code"><code class="hljs avrasm has-numbering">model<span class="hljs-preprocessor">.cpu</span>()
torch<span class="hljs-preprocessor">.save</span>(model<span class="hljs-preprocessor">.state</span>_dict(), modelPath)
model<span class="hljs-preprocessor">.cuda</span>(args<span class="hljs-preprocessor">.cuda</span>)</code><ul class="pre-numbering"><li>1</li><li>2</li><li>3</li></ul></pre>
<ul>
<li>pytorch加载模型到cpu上,请使用</li>
</ul>
<pre class="prettyprint" name="code"><code class="hljs avrasm has-numbering">model<span class="hljs-preprocessor">.load</span>_state_dict(torch<span class="hljs-preprocessor">.load</span>(model_file, map_location=lambda storage, loc: storage))</code><ul class="pre-numbering"><li>1</li></ul></pre>
<ul>
<li>pytorch对模型验证,请使用</li>
</ul>
<pre class="prettyprint" name="code"><code class="hljs php has-numbering">model.<span class="hljs-keyword">eval</span>()</code><ul class="pre-numbering"><li>1</li></ul></pre>
<h3 id="基本语法"><a name="t8" target="_blank"></a>基本语法</h3>
<pre class="prettyprint" name="code"><code class="hljs livecodeserver has-numbering"><span class="hljs-comment">#删除网络最后一列</span>
resnet = models.resnet152(pretrained=True)
modules = list(resnet.children())[:-<span class="hljs-number">1</span>] <span class="hljs-comment"># delete the last fc layer.</span>
self.resnet = nn.Sequential(*modules) <span class="hljs-comment"># 从list转model</span>
<span class="hljs-comment">#定义类时,class modelT_1(object)和modelT_1(nn.Module)是不一样的。一个可以直接用net(input)调用forward里面的函数,另一个必须用net.forward(input)才能调用forward里面的函数。否则会报错:not callable</span>
<span class="hljs-comment">#torch和np.array()之间的转换</span>
torch_data = torch.from_numpy(np_data) <span class="hljs-comment"># np.array -> tensor</span>
tensor2array = torch_data.numpy() <span class="hljs-comment"># tensor -> array</span>
<span class="hljs-built_in">variable</span>.data() <span class="hljs-comment"># variable -> tensor</span>
<span class="hljs-built_in">variable</span>.data[<span class="hljs-number">0</span>] <span class="hljs-comment"># variable -> number</span>
tensor = torch.FloatTensor(data) <span class="hljs-comment"># 转成float类型</span>
<span class="hljs-comment">#基本操作</span>
torch.mm(tensor, tensor) <span class="hljs-comment"># 矩阵内积</span>
tensor.dot(tensor) <span class="hljs-comment"># 必须是一维输入才可以处理</span>
<span class="hljs-comment">#变量定义</span>
<span class="hljs-built_in">from</span> torch.autograd import Variable
tensor = torch.FloatTensor([[<span class="hljs-number">1</span>,<span class="hljs-number">2</span>],[<span class="hljs-number">3</span>,<span class="hljs-number">4</span>]])
<span class="hljs-built_in">variable</span> = Variable(tensor, requires_grad=True) <span class="hljs-comment"># 参与误差反向传播</span>
<span class="hljs-comment">#要搭建一个计算图 computational graph 进行整体的误差反向传播,需要利用variable进行搭建。计算误差反向传播时,计算的梯度是边上的梯度,不是结点的梯度,结点的误差需要需要利用链式法则继续传递到之前的结点。</span>
v_out = torch.mean(<span class="hljs-built_in">variable</span>*<span class="hljs-built_in">variable</span>)<span class="hljs-comment"># 逐个元素平方</span>
v_out.backward() <span class="hljs-comment"># 计算关于variabled的梯度</span>
<span class="hljs-comment"># 获取batch是根据的输入数据,一般这种加载数据的方式只用在用train得到batch的数据上才可以。一般输入tensor形式的lable和data,读入也是tensor形式的,只有在训练时候才转化成variable,有一点需要注意,torch中的cnn,通道在前面,</span>
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
loader = Data.DataLoader(
dataset=torch_dataset, <span class="hljs-comment"># torch TensorDataset format</span>
batch_size=BATCH_SIZE, <span class="hljs-comment"># mini batch size</span>
shuffle=True, <span class="hljs-comment"># random shuffle for training</span>
num_workers=<span class="hljs-number">2</span>, <span class="hljs-comment"># subprocesses for loading data</span>
)
<span class="hljs-keyword">for</span> epoch <span class="hljs-operator">in</span> range(<span class="hljs-number">3</span>): <span class="hljs-comment"># train entire dataset 3 times</span>
<span class="hljs-keyword">for</span> step, (batch_x, batch_y) <span class="hljs-operator">in</span> enumerate(loader): <span class="hljs-comment"># for each training step</span>
<span class="hljs-comment"># zip函数,用于得到类似tuple(a,b,c)</span>
<span class="hljs-comment"># volatile=True是让test数据不参与梯度的计算,加速测试;test的输入是一个tensor,train的输入由于要梯度反传,所以应该是一个variable.</span>
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=<span class="hljs-number">1</span>), volatile=True).type(torch.FloatTensor)[:<span class="hljs-number">2000</span>]/<span class="hljs-number">255.</span>
<span class="hljs-comment"># 返回[batch,1,a,b]大小的tensor</span>
enumerate(train_loader) <span class="hljs-comment">#这里才会执行enumrate时候的normalize</span>
<span class="hljs-keyword">for</span> images, labels <span class="hljs-operator">in</span> test_loader <span class="hljs-comment"># 就可以读取一个元组</span>
x.view(-<span class="hljs-number">1</span>, <span class="hljs-number">28</span>, <span class="hljs-number">28</span>),就是resize函数
<span class="hljs-comment"># 测试方法</span>
print xx.size()
break 进入torch网络内部
type(<span class="hljs-built_in">param</span>.data) <span class="hljs-comment"># 显示模型类型</span>
<span class="hljs-comment"># 损失,需要a,b维度相同</span>
nn.MSELoss(<span class="hljs-operator">a</span>,b)
<span class="hljs-comment"># 连续预测,将上一步的hidden状态输入到下一步中。pytorch的图是动态的,rnn的输入时刻可以是动态的。</span>
h_state = Variable(h_state.data)
<span class="hljs-comment">#看懂了我的第一个代码,gan和cgan,大致对其有了了解,gan就是一个多目标优化问题,这种类型的代码都可以参照它完成。</span>
<span class="hljs-comment">#使用gpu,就是加入.cude(),包括三部分,tensor,model以及运算。</span>
<span class="hljs-comment">#dropout层的加入很简单,</span>
torch.nn.Dropout(<span class="hljs-number">0.5</span>) <span class="hljs-comment"># drop 50% of the neuron</span>
测试的时候需要转成net_overfitting.eval(),得到输出test_pred_ofit = net_overfitting(test_x)。训练的时候要转成net_dropped.train(),两种模式切换。
<span class="hljs-comment">#Batch normalization也需要设定net.eval()和net.train()两种方式。BatchNorm2d(),对于卷积,设置为channel个数。每个batch计算均值和方差。常见结构:卷积+bn+relu+pool</span>
!!注意eval()只在bn和dropout上有用!!
<span class="hljs-comment"># 使用预训练的模型结构</span>
resnet = torchvision.models.resnet18(pretrained=True)<span class="hljs-comment"># 挺好用的,下载时还会显示下载进度。</span>
<span class="hljs-comment"># 冻结模型一些层</span>
<span class="hljs-keyword">for</span> <span class="hljs-built_in">param</span> <span class="hljs-operator">in</span> resnet.parameters(): <span class="hljs-built_in">param</span>.requires_grad = False
<span class="hljs-comment"># 替换开头几层,例如将fc层替换为新的fc。直接定义的层,in_features和out_features给出单元数目</span>
resnet.fc = nn.Linear(resnet.fc.in_features, <span class="hljs-number">100</span>)
模型训练和预测都用了<span class="hljs-built_in">variable</span>类型
<span class="hljs-comment"># 保存建议保存参数</span>
<span class="hljs-comment"># Save and load the entire model.</span>
torch.save(resnet, <span class="hljs-string">'model.pkl'</span>)
model = torch.<span class="hljs-built_in">load</span>(<span class="hljs-string">'model.pkl'</span>)
<span class="hljs-comment"># Save and load only the model parameters(recommended).</span>
torch.save(resnet.state_dict(), <span class="hljs-string">'params.pkl'</span>)
resnet.load_state_dict(torch.<span class="hljs-built_in">load</span>(<span class="hljs-string">'params.pkl'</span>))
读入数据的格式是[batch,channel,height,width],可以用view进行reshape,
torch.<span class="hljs-built_in">max</span> 返回两个值,一个是位置,一个是值
<span class="hljs-comment">#训练过程中,对图片增强,增加一些旋转,裁剪,在测试时,一般不做裁剪处理。多个图结构之间可以相互调用,具体参考残差网络的连接方式。</span>
<span class="hljs-comment"># 优化器可以在训练过程中修改,比如修改学习率等等,比较自由</span>
</code><ul class="pre-numbering"><li>1</li><li>2</li><li>3</li><li>4</li><li>5</li><li>6</li><li>7</li><li>8</li><li>9</li><li>10</li><li>11</li><li>12</li><li>13</li><li>14</li><li>15</li><li>16</li><li>17</li><li>18</li><li>19</li><li>20</li><li>21</li><li>22</li><li>23</li><li>24</li><li>25</li><li>26</li><li>27</li><li>28</li><li>29</li><li>30</li><li>31</li><li>32</li><li>33</li><li>34</li><li>35</li><li>36</li><li>37</li><li>38</li><li>39</li><li>40</li><li>41</li><li>42</li><li>43</li><li>44</li><li>45</li><li>46</li><li>47</li><li>48</li><li>49</li><li>50</li><li>51</li><li>52</li><li>53</li><li>54</li><li>55</li><li>56</li><li>57</li><li>58</li><li>59</li><li>60</li><li>61</li><li>62</li><li>63</li><li>64</li><li>65</li><li>66</li><li>67</li><li>68</li><li>69</li><li>70</li><li>71</li><li>72</li><li>73</li><li>74</li><li>75</li><li>76</li><li>77</li><li>78</li><li>79</li><li>80</li><li>81</li><li>82</li><li>83</li><li>84</li><li>85</li><li>86</li><li>87</li><li>88</li><li>89</li><li>90</li><li>91</li><li>92</li><li>93</li><li>94</li><li>95</li><li>96</li><li>97</li><li>98</li><li>99</li><li>100</li><li>101</li></ul></pre> </div>
<link rel="stylesheet" href="http://csdnimg.cn/release/phoenix/production/markdown_views-68a8aad09e.css">
<script>
$(".MathJax").remove();
</script>
<script type="text/javascript" src="//static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
</div>
<div class="markdown_views" deep="7">
<h3 id="本节对pytorch的参考资料以及相关内容进行总结"><a name="t0" target="_blank"></a>本节对pytorch的参考资料以及相关内容进行总结</h3>
<h4 id="参考资料1-pytorch-tutorial"><a name="t1" target="_blank"></a><a href="https://github.com/yunjey/pytorch-tutorial" target="_blank">参考资料1 pytorch tutorial</a></h4>
<h4 id="参考资料2-莫烦tutorial"><a name="t2" target="_blank"></a><a href="https://morvanzhou.github.io/tutorials/machine-learning/torch/" target="_blank">参考资料2 莫烦tutorial</a></h4>
<h4 id="参考资料3-pytorch官方文档"><a name="t3" target="_blank"></a><a href="http://pytorch.org/docs/0.2.0/notes/autograd.html" target="_blank">参考资料3 pytorch官方文档</a></h4>
<h4 id="参考资料4-pytorch中文文档"><a name="t4" target="_blank"></a><a href="https://pytorch-cn.readthedocs.io/zh/latest/#pytorch" target="_blank">参考资料4 pytorch中文文档</a></h4>
<h3 id="安装"><a name="t5" target="_blank"></a>安装</h3>
<h4 id="参考传送门我装的是cuda75的版本注意不支持window下的安装"><a name="t6" target="_blank"></a>参考<a href="http://blog.csdn.net/zeroqiaoba/article/details/75192025" target="_blank">传送门</a>,我装的是cuda7.5的版本(注意:不支持window下的安装)</h4>
<h3 id="注意要点"><a name="t7" target="_blank"></a>注意要点</h3>
<ul>
<li><p>如果你要保存map,请务必保存成numpy格式的数据,因为torch格式的数据非常占用内存资源,这是一个很神奇的现象。不然会内存炸裂的。 <br data-filtered="filtered">
<img src="https://img-blog.csdn.net/20171125195907119?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvemVyb1FpYW9iYQ==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast" alt="这里写图片描述" title=""></p></li>
<li><p>pytorch解除占用时,请用</p></li>
</ul>
<pre class="prettyprint" name="code"><code class="hljs lasso has-numbering">fuser <span class="hljs-attribute">-v</span> /dev/nvidia<span class="hljs-subst">*</span> 查看所有进程 </code><ul class="pre-numbering"><li>1</li></ul></pre>
<ul>
<li>pytorch保存模型时,请先转到cpu下,否则会out of memory</li>
</ul>
<pre class="prettyprint" name="code"><code class="hljs avrasm has-numbering">model<span class="hljs-preprocessor">.cpu</span>()
torch<span class="hljs-preprocessor">.save</span>(model<span class="hljs-preprocessor">.state</span>_dict(), modelPath)
model<span class="hljs-preprocessor">.cuda</span>(args<span class="hljs-preprocessor">.cuda</span>)</code><ul class="pre-numbering"><li>1</li><li>2</li><li>3</li></ul></pre>
<ul>
<li>pytorch加载模型到cpu上,请使用</li>
</ul>
<pre class="prettyprint" name="code"><code class="hljs avrasm has-numbering">model<span class="hljs-preprocessor">.load</span>_state_dict(torch<span class="hljs-preprocessor">.load</span>(model_file, map_location=lambda storage, loc: storage))</code><ul class="pre-numbering"><li>1</li></ul></pre>
<ul>
<li>pytorch对模型验证,请使用</li>
</ul>
<pre class="prettyprint" name="code"><code class="hljs php has-numbering">model.<span class="hljs-keyword">eval</span>()</code><ul class="pre-numbering"><li>1</li></ul></pre>
<h3 id="基本语法"><a name="t8" target="_blank"></a>基本语法</h3>
<pre class="prettyprint" name="code"><code class="hljs livecodeserver has-numbering"><span class="hljs-comment">#删除网络最后一列</span>
resnet = models.resnet152(pretrained=True)
modules = list(resnet.children())[:-<span class="hljs-number">1</span>] <span class="hljs-comment"># delete the last fc layer.</span>
self.resnet = nn.Sequential(*modules) <span class="hljs-comment"># 从list转model</span>
<span class="hljs-comment">#定义类时,class modelT_1(object)和modelT_1(nn.Module)是不一样的。一个可以直接用net(input)调用forward里面的函数,另一个必须用net.forward(input)才能调用forward里面的函数。否则会报错:not callable</span>
<span class="hljs-comment">#torch和np.array()之间的转换</span>
torch_data = torch.from_numpy(np_data) <span class="hljs-comment"># np.array -> tensor</span>
tensor2array = torch_data.numpy() <span class="hljs-comment"># tensor -> array</span>
<span class="hljs-built_in">variable</span>.data() <span class="hljs-comment"># variable -> tensor</span>
<span class="hljs-built_in">variable</span>.data[<span class="hljs-number">0</span>] <span class="hljs-comment"># variable -> number</span>
tensor = torch.FloatTensor(data) <span class="hljs-comment"># 转成float类型</span>
<span class="hljs-comment">#基本操作</span>
torch.mm(tensor, tensor) <span class="hljs-comment"># 矩阵内积</span>
tensor.dot(tensor) <span class="hljs-comment"># 必须是一维输入才可以处理</span>
<span class="hljs-comment">#变量定义</span>
<span class="hljs-built_in">from</span> torch.autograd import Variable
tensor = torch.FloatTensor([[<span class="hljs-number">1</span>,<span class="hljs-number">2</span>],[<span class="hljs-number">3</span>,<span class="hljs-number">4</span>]])
<span class="hljs-built_in">variable</span> = Variable(tensor, requires_grad=True) <span class="hljs-comment"># 参与误差反向传播</span>
<span class="hljs-comment">#要搭建一个计算图 computational graph 进行整体的误差反向传播,需要利用variable进行搭建。计算误差反向传播时,计算的梯度是边上的梯度,不是结点的梯度,结点的误差需要需要利用链式法则继续传递到之前的结点。</span>
v_out = torch.mean(<span class="hljs-built_in">variable</span>*<span class="hljs-built_in">variable</span>)<span class="hljs-comment"># 逐个元素平方</span>
v_out.backward() <span class="hljs-comment"># 计算关于variabled的梯度</span>
<span class="hljs-comment"># 获取batch是根据的输入数据,一般这种加载数据的方式只用在用train得到batch的数据上才可以。一般输入tensor形式的lable和data,读入也是tensor形式的,只有在训练时候才转化成variable,有一点需要注意,torch中的cnn,通道在前面,</span>
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
loader = Data.DataLoader(
dataset=torch_dataset, <span class="hljs-comment"># torch TensorDataset format</span>
batch_size=BATCH_SIZE, <span class="hljs-comment"># mini batch size</span>
shuffle=True, <span class="hljs-comment"># random shuffle for training</span>
num_workers=<span class="hljs-number">2</span>, <span class="hljs-comment"># subprocesses for loading data</span>
)
<span class="hljs-keyword">for</span> epoch <span class="hljs-operator">in</span> range(<span class="hljs-number">3</span>): <span class="hljs-comment"># train entire dataset 3 times</span>
<span class="hljs-keyword">for</span> step, (batch_x, batch_y) <span class="hljs-operator">in</span> enumerate(loader): <span class="hljs-comment"># for each training step</span>
<span class="hljs-comment"># zip函数,用于得到类似tuple(a,b,c)</span>
<span class="hljs-comment"># volatile=True是让test数据不参与梯度的计算,加速测试;test的输入是一个tensor,train的输入由于要梯度反传,所以应该是一个variable.</span>
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=<span class="hljs-number">1</span>), volatile=True).type(torch.FloatTensor)[:<span class="hljs-number">2000</span>]/<span class="hljs-number">255.</span>
<span class="hljs-comment"># 返回[batch,1,a,b]大小的tensor</span>
enumerate(train_loader) <span class="hljs-comment">#这里才会执行enumrate时候的normalize</span>
<span class="hljs-keyword">for</span> images, labels <span class="hljs-operator">in</span> test_loader <span class="hljs-comment"># 就可以读取一个元组</span>
x.view(-<span class="hljs-number">1</span>, <span class="hljs-number">28</span>, <span class="hljs-number">28</span>),就是resize函数
<span class="hljs-comment"># 测试方法</span>
print xx.size()
break 进入torch网络内部
type(<span class="hljs-built_in">param</span>.data) <span class="hljs-comment"># 显示模型类型</span>
<span class="hljs-comment"># 损失,需要a,b维度相同</span>
nn.MSELoss(<span class="hljs-operator">a</span>,b)
<span class="hljs-comment"># 连续预测,将上一步的hidden状态输入到下一步中。pytorch的图是动态的,rnn的输入时刻可以是动态的。</span>
h_state = Variable(h_state.data)
<span class="hljs-comment">#看懂了我的第一个代码,gan和cgan,大致对其有了了解,gan就是一个多目标优化问题,这种类型的代码都可以参照它完成。</span>
<span class="hljs-comment">#使用gpu,就是加入.cude(),包括三部分,tensor,model以及运算。</span>
<span class="hljs-comment">#dropout层的加入很简单,</span>
torch.nn.Dropout(<span class="hljs-number">0.5</span>) <span class="hljs-comment"># drop 50% of the neuron</span>
测试的时候需要转成net_overfitting.eval(),得到输出test_pred_ofit = net_overfitting(test_x)。训练的时候要转成net_dropped.train(),两种模式切换。
<span class="hljs-comment">#Batch normalization也需要设定net.eval()和net.train()两种方式。BatchNorm2d(),对于卷积,设置为channel个数。每个batch计算均值和方差。常见结构:卷积+bn+relu+pool</span>
!!注意eval()只在bn和dropout上有用!!
<span class="hljs-comment"># 使用预训练的模型结构</span>
resnet = torchvision.models.resnet18(pretrained=True)<span class="hljs-comment"># 挺好用的,下载时还会显示下载进度。</span>
<span class="hljs-comment"># 冻结模型一些层</span>
<span class="hljs-keyword">for</span> <span class="hljs-built_in">param</span> <span class="hljs-operator">in</span> resnet.parameters(): <span class="hljs-built_in">param</span>.requires_grad = False
<span class="hljs-comment"># 替换开头几层,例如将fc层替换为新的fc。直接定义的层,in_features和out_features给出单元数目</span>
resnet.fc = nn.Linear(resnet.fc.in_features, <span class="hljs-number">100</span>)
模型训练和预测都用了<span class="hljs-built_in">variable</span>类型
<span class="hljs-comment"># 保存建议保存参数</span>
<span class="hljs-comment"># Save and load the entire model.</span>
torch.save(resnet, <span class="hljs-string">'model.pkl'</span>)
model = torch.<span class="hljs-built_in">load</span>(<span class="hljs-string">'model.pkl'</span>)
<span class="hljs-comment"># Save and load only the model parameters(recommended).</span>
torch.save(resnet.state_dict(), <span class="hljs-string">'params.pkl'</span>)
resnet.load_state_dict(torch.<span class="hljs-built_in">load</span>(<span class="hljs-string">'params.pkl'</span>))
读入数据的格式是[batch,channel,height,width],可以用view进行reshape,
torch.<span class="hljs-built_in">max</span> 返回两个值,一个是位置,一个是值
<span class="hljs-comment">#训练过程中,对图片增强,增加一些旋转,裁剪,在测试时,一般不做裁剪处理。多个图结构之间可以相互调用,具体参考残差网络的连接方式。</span>
<span class="hljs-comment"># 优化器可以在训练过程中修改,比如修改学习率等等,比较自由</span>
</code><ul class="pre-numbering"><li>1</li><li>2</li><li>3</li><li>4</li><li>5</li><li>6</li><li>7</li><li>8</li><li>9</li><li>10</li><li>11</li><li>12</li><li>13</li><li>14</li><li>15</li><li>16</li><li>17</li><li>18</li><li>19</li><li>20</li><li>21</li><li>22</li><li>23</li><li>24</li><li>25</li><li>26</li><li>27</li><li>28</li><li>29</li><li>30</li><li>31</li><li>32</li><li>33</li><li>34</li><li>35</li><li>36</li><li>37</li><li>38</li><li>39</li><li>40</li><li>41</li><li>42</li><li>43</li><li>44</li><li>45</li><li>46</li><li>47</li><li>48</li><li>49</li><li>50</li><li>51</li><li>52</li><li>53</li><li>54</li><li>55</li><li>56</li><li>57</li><li>58</li><li>59</li><li>60</li><li>61</li><li>62</li><li>63</li><li>64</li><li>65</li><li>66</li><li>67</li><li>68</li><li>69</li><li>70</li><li>71</li><li>72</li><li>73</li><li>74</li><li>75</li><li>76</li><li>77</li><li>78</li><li>79</li><li>80</li><li>81</li><li>82</li><li>83</li><li>84</li><li>85</li><li>86</li><li>87</li><li>88</li><li>89</li><li>90</li><li>91</li><li>92</li><li>93</li><li>94</li><li>95</li><li>96</li><li>97</li><li>98</li><li>99</li><li>100</li><li>101</li></ul></pre> </div>
<link rel="stylesheet" href="http://csdnimg.cn/release/phoenix/production/markdown_views-68a8aad09e.css">
<script>
$(".MathJax").remove();
</script>
<script type="text/javascript" src="//static.blog.csdn.net/mdeditor/public/res/bower-libs/MathJax/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
</div>