1.prtorch有不同种类的数据类型:32位浮点型,torch.FloatTensor;64位浮点型,torch.DoubleFloatTensor; 16位整形,torch.ShortTensor ; 32位整型,torch.IntTensor ; 64位整型,torch.LongTensor
2.一个torch.autograd.Variable有三个属性:data , grad , grad_fn
对于此函数的理解
import torch
from torch.autograd import Variable
w=Variable(torch.Tesnor([1]),requires_grad=True)
x=Variable(torch.Tesnor([2]),requires_grad=True)
b=Variable(torch.Tensor([3]),requires_grad=True)
其中这个函数会构建一个计算图,是这个变量包含三个部分的内容:data,grad,grad_fn。
再来构造一个函数:
y=w*x+b*b
这个函数就相当于一个完整的计算图,其中w,x,b就是叶节点,而y就是根节点。
y.backward()
接下来就是反向传播,在方向传播的时候会在每个Variable的计算图里,会计算出对于每个变量的梯度,分别给每个变量里的grad,grad_fn进行赋值。其实也是y.backward()决定了求导的方向,以及怎么求导。
print(x.gard)
#会输出1
print(w.grad)
#会输出2
print(b.grad)
#会输出6
3.对于数据的读取和预处理主要用到了,torch.utils.data.dataset这个抽象函数
4.torch.utils.data.dataloader函数读取数据:
from torch.utils.data import dataloader as Dataloader
dataiter=Dataloader(myDataset,batch_size=32,shuffle=True,collate_fn=default_collate)
5.在pytorch所有的层结构和损失函数都来自于torch.nn,所有的模型构建都从这个基类nn.Module继承的,于是有了如下模板:
class net_name(nn.Module):
def __init__(self,other_arguments):
super(net_name,self).__init__()
self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding)
def forward(self,x)
x=self.conv1(x)
return x
6.torchvision.datasets.ImageFolder是用来进行读取图片的模块。
from torchvision import ImageFolder
dataset=ImageFolder(root='root_path',transform=None,loader=default_loader)
7.优化参数:
import torch.optim as optim
optimizer=torch.optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
在优化之前一般要进行梯度清零:optimizer.zeros(),然后通过loss.backward()反向传播,自动求导每个参数的梯度,最后只需要,optimizer.step()就可以通过梯度作进一步的更新。
在这里着重讲一下三个步骤:1)。optimizer.zeros()========>为了使Variable里面的grad清零
2)。 loss.backward()========>方向传播,其实就是为反过来求导求出每个Variable的grad的值
3)。 optimizer.step()=========>用上一步计算的grad,进行参数更新
8.保存模型:
torch.save( model , ' ./model.pth ')=============>保存整个模型。
torch.save( model.state_dict() , ' ./model_state.pth ') =========>保存模型的状态和参数。
9.加载模型;
load_model=torch.load(' model.pth ')=======>完整的加载比赛
model.load_state_dict( torch.load( 'model_state.pth' ) )==========>加载模型的参数
最为重要的是第二种方法要先定义model的结构,然后导入model的参数,举个例子,model=ResNet14(),定义了model后再载入参数。