tensor
torch.tensor(data, # 数据 可以使list numpy
dtype=None, # 数据类型 默认与data一致
device=None, # 所在设备 cuda / cpu
requires_grad=False, # 是否需要梯度
pin_memory = False ,# 是否存于锁页内存)
flag = True
if flag:
arr = np.ones((3,3))
print('ndarray的数据类型:',arr.dtype)
t = torch.tensor(arr,device='cuda')
print(t)
张量拼接与切分
torch.cat()
将张量按维度dim 进行拼接 不会扩展张量的维度
tensors 张量序列
dim 要拼接的维度
torch.stack()
在新创建的维度dim 上进行拼接 会扩展张量的维度
tensors 张量序列
dim 要拼接的维度
flag = True
if flag:
t = torch.ones((2,3