1. Tensor(张量)
pytorch中的张量与numpy的ndarray相对应,都表示一个多维的矩阵,张量与ndarray可以相互转化。
Tensor有五种数据类型:32位浮点型torch.FloatTensor、64位浮点型torch.DoubleTensor、16位整型torch.ShortTensor、32位整型torch.IntTensor、64位整型torch.LongTensor。
可以通过以下方式定义张量:
a = torch.Tensor([[1,2],[2,3],[3,4]])#默认是32位浮点型
b = torch.LongTensor([[1,2],[2,3],[3,4]])#长整型
c = torch.zeros((3,2))#3*2的零矩阵
d = torch.randn((3,3))#3*3的随机矩阵
Tensor可以像numpy一样用索引找到元素,也可以进行转化。
a[0,1] = 100 #设置第一行第二列的值为100
numpy_b = b.numpy()#Tensor转化为numpy
e = numpy.array([[1,1],[2,2])
torch_e = torch.from_numpy(e)#numpy转化为Tensor
f_torch_e = torch_e.float()#改变数据类别
2. Variable(变量)
变量是神经网络计算图中的一个概念,Variable提供了一个自动求导功能,变量与张量本质上没有什么不同,但变量会被放入一个计算图中,然后进行自动求导。
Variable在torch.autograd.Variable中,一个Tensor转化为Variable可以直接使用Variable()函数。
Variable有三个属性data,grad和grad_fn。
data可以将Variable中的tensor数值取出,grad_fn代表着得到此变量的操作,grad_fn是这个Variable的反向传播梯度。
Variable反向传播,获得梯度函数如下(ps:看了几遍没看懂,懂了再回来改吧)
x = torch.randn([1])
x = Variable(x, requires_gard=True)
x.backward()
3. Dataset(数据集)
数据集是任何机器学习都需要的,PyTorch中可以定义通过继承抽象类torch.utils.data.Dataset来实现自己的数据集。其中必须要定义__len__()和__getitem__()两个函数。
class myDataset(Dataset):
def __init__(self, csv_file, txt_file, root_dir, other_file):
self.csv_data = pd.read_csv(csv_file)
with open(txt_file, 'r') as f:
datalist = f.readline()
self.txt_data = datalist
self.root_dir = root_dir
def __len__(self):#返回数据集的大小
return len(self.csv_data)
def __getitem__(self, idx):#根据下标返回数据
data = (self.csv_data[]idx], self.txt_data[idx])
return data
PyTorch还可以使用loader来进行多线程的读取数据,通过torch.utils.data.DataLoader来定义。
4. nn.Module(模组)
PyTorch中建立神经网络,所有的层结构与损失函数都来自于torch.nn,所有的模型都要继承这个基类nn.Module。
class net_name(nn.Module):
def __init__(self, other_arguments):
super(net_name,self).__init__()
self.convl = nn.Conv2d(in_channels, out_channels, kernel_size)
def forward(self, x ):
x = self.convl(x)
return x
5. 模型的保存和加载
保存模型的方式有两种:1.整个模型保存,2.仅仅保存模型的参数。
torch.save(model, './model.pth')#整个保存
torch.save(modle, './model_state.pth')#保存参数
模型的加载也是这两种。
load_model = torch.load('model.pth')
mode.load_state_dic(torch.load('model_state.pth'))
&spm=1001.2101.3001.5002&articleId=106720282&d=1&t=3&u=a5d74225d333455c98591ebc4a064407)
8万+

被折叠的 条评论
为什么被折叠?



