pytorch学习总结

1.prtorch有不同种类的数据类型:32位浮点型,torch.FloatTensor;64位浮点型,torch.DoubleFloatTensor; 16位整形,torch.ShortTensor ; 32位整型,torch.IntTensor ; 64位整型,torch.LongTensor

2.一个torch.autograd.Variable有三个属性:data , grad , grad_fn

  对于此函数的理解

import torch
from torch.autograd import Variable

w=Variable(torch.Tesnor([1]),requires_grad=True)
x=Variable(torch.Tesnor([2]),requires_grad=True)
b=Variable(torch.Tensor([3]),requires_grad=True)

其中这个函数会构建一个计算图,是这个变量包含三个部分的内容:data,grad,grad_fn。

再来构造一个函数:

y=w*x+b*b

这个函数就相当于一个完整的计算图,其中w,x,b就是叶节点,而y就是根节点。

y.backward()

接下来就是反向传播,在方向传播的时候会在每个Variable的计算图里,会计算出对于每个变量的梯度,分别给每个变量里的grad,grad_fn进行赋值。其实也是y.backward()决定了求导的方向,以及怎么求导。

print(x.gard)
#会输出1
print(w.grad)
#会输出2
print(b.grad)
#会输出6

3.对于数据的读取和预处理主要用到了,torch.utils.data.dataset这个抽象函数

4.torch.utils.data.dataloader函数读取数据:

from torch.utils.data import dataloader as Dataloader
dataiter=Dataloader(myDataset,batch_size=32,shuffle=True,collate_fn=default_collate)

5.在pytorch所有的层结构和损失函数都来自于torch.nn,所有的模型构建都从这个基类nn.Module继承的,于是有了如下模板:

class net_name(nn.Module):
  def __init__(self,other_arguments):
    super(net_name,self).__init__()
    
    self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding)

  def forward(self,x)
    x=self.conv1(x)
    return x

6.torchvision.datasets.ImageFolder是用来进行读取图片的模块。

from torchvision import ImageFolder
dataset=ImageFolder(root='root_path',transform=None,loader=default_loader)

7.优化参数:

import torch.optim as optim

optimizer=torch.optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

在优化之前一般要进行梯度清零:optimizer.zeros(),然后通过loss.backward()反向传播,自动求导每个参数的梯度,最后只需要,optimizer.step()就可以通过梯度作进一步的更新。

在这里着重讲一下三个步骤:1)。optimizer.zeros()========>为了使Variable里面的grad清零

                                               2)。 loss.backward()========>方向传播,其实就是为反过来求导求出每个Variable的grad的值

                                              3)。 optimizer.step()=========>用上一步计算的grad,进行参数更新

8.保存模型:

  torch.save( model , ' ./model.pth ')=============>保存整个模型。

  torch.save( model.state_dict() , ' ./model_state.pth ') =========>保存模型的状态和参数。

9.加载模型;

  load_model=torch.load(' model.pth ')=======>完整的加载比赛

  model.load_state_dict( torch.load( 'model_state.pth' ) )==========>加载模型的参数

最为重要的是第二种方法要先定义model的结构,然后导入model的参数,举个例子,model=ResNet14(),定义了model后再载入参数。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值