一文读懂PyTorch张量基础(附代码)

本文介绍了PyTorch中的Tensor类,它类似于Numpy中的ndarray,它构成了在PyTorch中构建神经网络的基础。

我们已经知道张量到底是什么了,并且知道如何用Numpy的ndarray来表示它们,现在我们看看如何在PyTorch中表示它们。

自从Facebook在2017年初将PyTorch开源以来,它已经在机器学习领域取得了令人瞩目的成绩。它可能没有像TensorFlow那样被广泛采用 --- 它的最初发布时间早于PyTorch一年,背后有Google的支持,并且当神经网络工具迎来新的潮流时,它已经将自己确立为了金牌标准。但PyTorch在研究领域受到了广泛的关注,这种关注大部分来自与Torch本身的关系,以及它的动态计算图。

72b34a02e906eb1903485603701062c847edd353

尽管最近我的注意力都在PyTorch上,但这篇文章并不是PyTorch的教程。它更多地是介绍PyTorch的Tensor类,这与Numpy的ndarray类似。

张量基础

让我们来看一下PyTorch的张量基础知识,从创建张量开始(使用Tensor类):

import torch

# Create a Torch tensor

t = torch.Tensor([[1, 2, 3], [4, 5, 6]])

t

tensor([[ 1., 2., 3.],

[ 4., 5., 6.]])

你可以使用两种方式转置一个张量:

# Transpose

t.t()

# Transpose (via permute)

t.permute(-1,0)

两者都会产生如下输出结果:

tensor([[ 1., 4.],

[ 2., 5.],

[ 3., 6.]])

请注意,两种方式都不会导致原始张量的改变。

用view重新塑造张量:

# Reshape via view
t.view(3,2)

tensor([[ 1., 2.],

[ 3., 4.],

[ 5., 6.]])

另一个例子:

# View again...

t.view(6,1)

tensor([[ 1.],

[ 2.],

[ 3.],

[ 4.],

[ 5.],

[ 6.]])

很明显,Numpy所遵循的数学约定延续到了PyTorch张量中(我具体指的是行和列的标记符号)。

原文链接

阅读更多
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭
关闭