numpy 与 tensor对比

类型(Types)

NumpyPyTorch
np.ndarraytorch.Tensor
np.float32torch.float32; torch.float
np.float64torch.float64; torch.double
np.float16torch.float16; torch.half
np.int8torch.int8
np.uint8torch.uint8
np.int16torch.int16; torch.short
np.int32torch.int32; torch.int
np.int64torch.int64; torch.long

构造器(Constructor)

零和一(Ones and zeros)

NumpyPyTorch
np.empty((2, 3))torch.empty(2, 3)
np.empty_like(x)torch.empty_like(x)
np.eyetorch.eye
np.identitytorch.eye
np.onestorch.ones
np.ones_liketorch.ones_like
np.zerostorch.zeros
np.zeros_liketorch.zeros_like

从已知数据构造

NumpyPyTorch
np.array([[1, 2], [3, 4]])torch.tensor([[1, 2], [3, 4]])
np.array([3.2, 4.3], dtype=np.float16)np.float16([3.2, 4.3])torch.tensor([3.2, 4.3], dtype=torch.float16)
x.copy()x.clone()
np.fromfile(file)torch.tensor(torch.Storage(file))
np.frombuffer
np.fromfunction
np.fromiter
np.fromstring
np.loadtorch.load
np.loadtxt
np.concatenatetorch.cat

数值范围

NumpyPyTorch
np.arange(10)torch.arange(10)
np.arange(2, 3, 0.1)torch.arange(2, 3, 0.1)
np.linspacetorch.linspace
np.logspacetorch.logspace

构造矩阵

NumpyPyTorch
np.diagtorch.diag
np.triltorch.tril
np.triutorch.triu

参数

NumpyPyTorch
x.shapex.shape
x.stridesx.stride()
x.ndimx.dim()
x.datax.data
x.sizex.nelement()
x.dtypex.dtype

索引

NumpyPyTorch
x[0]x[0]
x[:, 0]x[:, 0]
x[indices]x[indices]
np.take(x, indices)torch.take(x, torch.LongTensor(indices))
x[x != 0]x[x != 0]

形状(Shape)变换

NumpyPyTorch
x.reshapex.reshape; x.view
x.resize()x.resize_
x.resize_as_
x.transposex.transpose or x.permute
x.flattenx.view(-1)
x.squeeze()x.squeeze()
x[:, np.newaxis]; np.expand_dims(x, 1)x.unsqueeze(1)

数据选择

NumpyPyTorch
np.put
x.putx.put_
x = np.array([1, 2, 3])x.repeat(2) # [1, 1, 2, 2, 3, 3]x = torch.tensor([1, 2, 3])x.repeat(2) # [1, 2, 3, 1, 2, 3]x.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1) # [1, 1, 2, 2, 3, 3]
np.tile(x, (3, 2))x.repeat(3, 2)
np.choose
np.sortsorted, indices = torch.sort(x, [dim])
np.argsortsorted, indices = torch.sort(x, [dim])
np.nonzerotorch.nonzero
np.wheretorch.where
x[::-1]

数值计算

NumpyPyTorch
x.minx.min
x.argminx.argmin
x.maxx.max
x.argmaxx.argmax
x.clipx.clamp
x.roundx.round
np.floor(x)torch.floor(x); x.floor()
np.ceil(x)torch.ceil(x); x.ceil()
x.tracex.trace
x.sumx.sum
x.cumsumx.cumsum
x.meanx.mean
x.stdx.std
x.prodx.prod
x.cumprodx.cumprod
x.all(x == 1).sum() == x.nelement()
x.any(x == 1).sum() > 0

数值比较

NumpyPyTorch
np.lessx.lt
np.less_equalx.le
np.greaterx.gt
np.greater_equalx.ge
np.equalx.eq
np.not_equalx.ne

pytorch与tensorflow API速查表

方法名称pytrochtensorflownumpy
裁剪torch.clamp(x, min, max)tf.clip_by_value(x, min, max)np.clip(x, min, max)
取最小值torch.min(x, dim)[0]tf.min(x, axis)np.min(x , axis)
取两个tensor的最大值torch.max(x, y)tf.maximum(x, y)np.maximum(x, y)
取两个tensor的最小值torch.min(x, y)torch.minimum(x, y)np.minmum(x, y)
取最大值索引torch.max(x, dim)[1]tf.argmax(x, axis)np.argmax(x, axis)
取最小值索引torch.min(x, dim)[1]tf.argmin(x, axis)np.argmin(x, axis)
比较(x > y)torch.gt(x, y)tf.greater(x, y)np.greater(x, y)
比较(x < y)torch.le(x, y)tf.less(x, y)np.less(x, y)
比较(x==y)torch.eq(x, y)tf.equal(x, y)np.equal(x, y)
比较(x!=y)torch.ne(x, y)tf.not_equal(x, y)np.not_queal(x , y)
取符合条件值的索引torch.nonzero(cond)tf.where(cond)np.where(cond)
多个tensor聚合torch.cat([x, y], dim)tf.concat([x,y], axis)np.concatenate([x,y], axis)
堆叠成一个tensortorch.stack([x1, x2], dim)tf.stack([x1, x2], axis)np.stack([x, y], axis)
tensor切成多个tensortorch.split(x1, split_size_or_sections, dim)tf.split(x1, num_or_size_splits, axis)np.split(x1, indices_or_sections, axis)
-torch.unbind(x1, dim)tf.unstack(x1,axis)NULL
随机扰乱torch.randperm(n) 1tf.random_shuffle(x)np.random.shuffle(x) 2 np.random.permutation(x ) 3
前k个值torch.topk(x, n, sorted, dim)tf.nn.top_k(x, n, sorted)NULL
  1. 该方法只能对0~n-1自然数随机扰乱,所以先对索引随机扰乱,然后再根据扰乱后的索引取相应的数据得到扰乱后的数据
  2. 该方法会修改原值,没有返回值
  3. 该方法不会修改原值,返回扰乱后的值
  • 13
    点赞
  • 75
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值