【week1】张量的操作与线性回归
张量的拼接与切分
torch.cat()
功能:将张量按维度dim进行拼接。
参数:
- tensor:张量序列
- dim:要拼接的维度
torch.stack()
功能:在新创建的维度dim进行拼接
参数:
- tensor:张量序列
- dim:要拼接的维度
注:stack会扩充张量的维度,cat不会改变张量的维度
torch.chunk()
功能:将张量按维度dim进行平均切分
返回值:张量列表
**注:**若不能整除,最后切分出来的张量会小于其他的张量。不能整除的时候,张量的维度是向上取整的。最后切分完的维度相加要等于被切分的张量维度。
参数:
- input:要切分的张量
- chunks:要切分的份数
- dim:要切分的维度
torch.split()
功能:将张量按维度dim进行切分
返回值:张量的列表
参数:
- tensor:要切分的张量
- split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
- dim:要切分的维度
张量索引
torch.index_select()
功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量
- input:要索引的张量
- dim:要索引的维度
- index:要索引数据的序号
注:index是一个tensor的时候,数据dtype只能是torch.long类型。
torch.masked_select()
功能:在mask中的ture进行索引
返回值:一维张量
参数:
- input:要索引的张量
- mask:与input同形状的布尔类型的张量
注:t = torch.randint(0, 9, size=(3, 3))
mask = t.ge(5)
t_select = torch.masked_select(t, mask)
ge是大于等于5;gt大于;le小于等于;lt小于
张量变换
torch.shape()
功能:变换张量形状
注:当张量在内存中是连续时,新张量与input共享数据内存
参数:
- input:要变换的张量
- shape:新张量的形状
注:shape为-1的时候表示该维度的大小由程序自己计算。
t = torch.randperm(8)
t_reshape = torch.reshape(t, (-1, 2, 2))
torch.transpose()
功能:交换两个张量的维度
参数:
- input:要变换的张量
- dim0:要交换的维度
- dim1:要交换的维度
注:常用在图像预处理中,有时候读取是c×h×w,要转为h×w×c,只能相邻维度转换
torch.t()
功能:2维张量转置,对矩阵而言,等价于
torch.transpose(input,0,1)
torch.squeeze()
功能:压缩长度为1的维度(轴)
参数:
- dim:若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除
torch.unsqueeze()
功能:依据dim扩展维度
- dim:扩展的维度
张量的数学运算
一、加减运算
二、对数,指数,幂函数
三、三角函数
torch.add()
功能:逐元素计算input+alpha×other
参数:
- input:第一个张量
- alpha:乘项因子
- other:第二个张量
注:torch.addciv()加法除法
torch.addcmul() 加法乘法
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XCmYvIMK-1604730036592)(F:\XN\学习笔记\深度之眼\1.jpg)]
线性回归
求解步骤:
1.确定模型
Model :y = wx + b
2.选择损失函数
MSE = [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YWN2yOZB-1604730036598)(F:\XN\学习笔记\深度之眼\2.jpg)]
3.求解梯度并更新w,b:
w = w-LR*w.grad
.确定模型
Model :y = wx + b
2.选择损失函数
MSE = [外链图片转存中…(img-YWN2yOZB-1604730036598)]
3.求解梯度并更新w,b:
w = w-LR*w.grad
b = b-LR*w.grad