02张量的操作及线性回归

本文详细介绍了PyTorch中的张量操作,包括张量的拼接与切分、索引以及变换,如torch.cat()、torch.stack()、torch.index_select()、torch.reshape()等。此外,还探讨了张量的数学运算,如加减乘除、对数、指数和三角函数。最后,文章阐述了线性回归的基本概念和求解步骤。
摘要由CSDN通过智能技术生成

一、张量的操作

1.1 张量拼接与切分

1.1.1 torch.cat()

torch.cat(tensors, dim=0, out=None)

功能: 将张量按维度dim进行拼接

  • tensors: 张量序列
  • dim: 要拼接的维度
t = torch.ones((2, 3))

t_0 = torch.cat([t, t], dim=0)
t_1 = torch.cat([t, t, t], dim=1)

print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0, t_0.shape, t_1, t_1.shape))

在这里插入图片描述

1.1.2 torch.stack()

torch.stack(tensors, dim=0, out=None)

功能: 在新创建的维度dim上进行拼接

  • tensors: 张量序列
  • dim: 要拼接的维度
t = torch.ones((2, 3))

t_stack = torch.stack([t, t], dim=2)

print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))

在这里插入图片描述

注意:

  • cat()不会扩张张量的维度,而stack()则会
  • 当dim指定为0,那么已有的维度会后移,如原(2,3)的t,t1=torch.stack([t, t], dim=0),t1维度为(2,2,3)

1.1.3 torch.chunk()

torch.chunk(input, chunks, dim=0)

功能: 将张量按维度dim进行平均切分
返回值: 张量列表
注意事项: 若不能整除, 最后一份张量小于其他张量

  • input: 要切分的张量
  • chunks: 要切分的份数
  • dim: 要切分的维度
a = torch.ones((2, 7))  # 7
print(a)
list_of_tensors = torch.chunk(a, dim=1, chunks=3)  # 3

for idx, t in enumerate(list_of_tensors):
    print("第{}个张量:{}, shape is {}".format(idx + 1, t, t.shape))

在这里插入图片描述

1.1.4 torch.split()

torch.split(tensor,split_size_or_sections, dim=0)

功能: 将张量按维度dim进行切分
返回值: 张量列表

  • tensor: 要切分的张量
  • split_size_or_sections: 为int时, 表示每一份的长度; 为list时, 按list元素切分
  • dim : 要切分的维度
t = torch.ones((2, 5))

# list_of_tensors = torch.split(t, 2, dim=1)
# for idx, t in enumerate(list_of_tensors):
#     print("第{}个张量:{}, shape is {}".format(idx + 1, t, t.shape))

list_of_tensors = torch.split(t, [2, 1, 2], dim=1)
for idx, t in enumerate(list_of_tensors):
    print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

在这里插入图片描述

1.2 张量索引

1.2.1 torch.index_select()

torch.index_select(input, dim, index, out=None)

功能: 在维度dim上,按index索引数据
返回值: 依index索引数据拼接的张量
注意事项:index数据类型必须是torch.long

  • input: 要索引的张量
  • dim: 要索引的维度
  • index: 要索引数据的序号
t = torch.randint(0, 9, size=(3, 3))
idx = torch.tensor([0, 2], dtype=torch.long)
t_select = torch.index_select(t, dim=1, index=idx)
print("t:\n{}\nt_select:\n{}".format(t, t_select))

在这里插入图片描述

1.2.2 torch.masked_select()

torch.masked_select(input, mask, out=None)

功能: 按mask中的True进行索引
返回值: 一维张量

  • input: 要索引的张量
  • mask: 与i
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值