提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
pytorch基础学习6
一、语法 x = x.view(x.size(0), -1)
文档:view(*args) → Tensor
返回一个有相同数据但大小不同的tensor。 返回的tensor必须有与原tensor相同的数据和相同数目的元素,但可以有不同的大小。一个tensor必须是连续的contiguous()才能被查看。
x.size(0)
输入值是x,x.size是以元组的形式返回X的数据格式,也就是X的维度形状。以图像处理的输入为例,通常第一个维度是batch size。
那么语法可以转换为x = x.view(batch size,-1)
x.view()函数,表示将输入转换为括号里想要输出的tensor数据形状,x.view(y,-1)中y表示返回的是y行,-1表示的是n列(根据实际返回的数据个数,以自适应的方式返回)。
故x = x.view(x.size(0), -1) 的意为将输入tensor x转换为tensor(batch size行,自适应的n列)
类似于numpy.reshape()函数
二、语法 torch.max()
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
1.部分参数:
input:输入值,数据为Tensor格式;
dim:压缩第几个维度,若dim = 0 ,将输入值压缩成一行,若是二维的数据,即可认为求取每列的最大值;dim = 1,即压缩列,求取每行的最大值;
keepdim:保持维度,默认值为False,不保持原始维度;
2.返回值
返回值通常是2个元素的元组。第一个元素是返回输入值的最大值,数据格式是tensor ;第二个元素是返回最大值的索引值,数据格式为tensor。
3.如何将返回值以Python能直接处理的数据类型取出,通常使用函数item()。如:
torch.max([2,3,4]).item()
返回值为4