学习pytorch中的一些笔记(2)
一、torch.expand(*sizes)
torch.expand()用来扩展某张量某维度的值,先看代码示例:
import torch
x = torch.tensor([[1, 2, 3]])
print(x.size())
y = x.expand(2,3)
print(y)
print(y.size())
输出:
torch.Size([1, 3])
tensor([[1, 2, 3],
[1, 2, 3]])
torch.Size([2, 3])
官网文档如下:
通过官方文档,我觉得需要注意的一点是,expand()只能对维度为1的那个维度进行扩张,如果不是1,则无法进行扩展。示例如下:
import torch
x = torch.tensor([[1,