pytorch学习笔记4-张量操作

1. 概览

  1. 张量的合并
  2. 张量的运算
  3. 张量的数值计算
  4. 张量的范数
  5. 张量的高阶操作

2. 张量的合并

# cat
a1 = torch.rand(4,3,32,32)
a2 = torch.rand(5,3,32,32)
torch.cat([a1,a2],dim=0).shape
# torch.Size([9,3,32,32])
a2 = torch.rand(4,1,32,32)
torch.cat([a1,a2],dim=1).shape #[4,4,32,32]
# cat 除了合并的那一个维度不一样其他的都必须相同 否则就会报错

# stack 
a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0).shape   # [2,32,8]
# stack 会创建一个新的维度 
# 两个班级 32名学生的8门课成绩 合并回新加一个维度,将两个班的同学区分开
# stack的两个维度必须完全一样 

# split
# split 可以根据单元长度来拆分 也可以根据数量进行拆分
c = torch.rand([2,32,8])
aa, bb = torch.split([1,1],dim=0)
aa.shape,bb.shape   # [1,32,8], [1,32,8]
aa,bb = c.split(1,dim=0)
aa.shape,bb.shape   # [1,32,8], [1,32,8]

# chunk
aa, bb = c.chunk(2,dim=0)
aa.shape,bb.shape   # [1,32,8], [1,32,8]

3. 张量的运算

# 张量的加减乘除
a = torch.rand(3,4)
b = torch.rand(4)
a+b
torch.add(a,b)
torch.all(torch.eq(a-b,torch.sub(a,b))) # tensor(True)
torch.all(torch.eq(a*b,torch.mul(a,b))) # tensor(True)
torch.all(torch.eq(a/b,torch.div(a,b))) # tensor(True)

# 矩阵乘法
a = torch.full([2,2],3)
b = torch.ones(2,2)
torch.mm(a,b) 
torch.matmul(a,b)  #这是矩阵乘法
a@b  # 这是numpy中matmul的重载写法

a = torch.rand(4,784)
b = torch.rand(4,784)
w = torch,ranf(512,784)
(x@w.t()).shape  # [4,512]

# 高维度张量的乘法
a = torch.rand(4,3,28,64)
b = torch.rand(4,3,64,32)
torch.matmul(a,b).shape  # [4,3,28,32]
b = torch.rand(4,1,64,32)
torch.matmul(a,b).shape  # [4,3,28,32]

# 张量的幂,指数对数运算
a = torch.full([2,2],3)
a.pow(2) 
a**2
aa = a**2
aa.rsqrt()   # 平方根的倒数
a = torch.exp(torch.ones(2,2))  
# [[2.7183,2.7183],
#  [2.7183,2.7183]]
torch.log(a)  # a 为2.7183 则为自然指数的对数  

4. 张量的数值计算

a = torch.tensor(3.14)
a.floor()  # 天花板函数  3.
a.cell()   # 地板函数    4.
a.trunc()  # 求整数部分  3.
a.frac()   # 求小数部分  0.1400

# round()
a = torch.tensor(3.499)
a.round()  # 3.
b = torch.tensor(3.5)
a.round()  # 4.

# clamp 裁剪  经常用于梯度裁剪
grad = torch.rand(2,3)*15
grad.max(),grad.median(),grad.clamp(10)
# 相当于求二维矩阵的最大元素 中位数的值 
# glamp(10) 表示最小为10  小于10 取10 否则取原值
grad.clamp(0,10) # 则将取值限制在0-10之间

5. 张量的范数

a = torch.full([8],1.)
b = a.view(2,4)
c = a.view(2,2,2)
a.norm(1),b.norm(1),c.norm(1)  # 8. 8. 8.
a.norm(2),b.norm(2),c.norm(2)  # 2.8284 2.8284 2.8284
b.norm(1.dim=1)  # [4.,4.]
b.norm(2,dim=1)  # [2.,2.]
c.norm(1,dim=1)  # [[2.,2.],[2.,2.]]
c.norm(2,dim=0)  # [[1.1412,1.1412],[1.1412,1.1412]]
a = torch.arange(8).view(2,4).float()
a.min(),a.max(),a.mean(),a.prod()  # 0. 7. 3.5000 0.
a.sum()
a.argmax(),a.argmin()  
# argmax 返回的是最大值的索引 而默认把你所有的数据都打平 然后求出罪之所在的索引 如果想保留原来的shape
a = torch.randn(4,10)
a.argmax(dim=1)  # 还可以求出某个维度中的最大值所在的索引
a.max(dim=1)  # 第一个维度上的最大值以及他所在的索引
a.max(dim=1,keepdim=True)  # 每个维度上可能性最大值所在的索引和他的置信度

a.topk(3,dim=1)  # 第一个维度上的top3
# topk 比之前的max可以返回更多的数据
a.topk(3,dim=1,largest=False)
# kth value 第k个value 表示第k小的数值
a.kthvalue(8,dim=1)
# 第8小意思是第3大 第10小即最大

# a>0 torch.get(a,0)  比较的类型即为byteTensor
a = torch.ones(2,3)
b = torch.randn(2,3)
torch.eq(a,b)
torch.equal(a,a)   # True

6. 张量的高阶操作

where 和 gather操作

# torch.where(condition,x,y) 如果满足条件取x矩阵的对应值 否则取y矩阵的对应值
cond = torch.tensor([0.6769,0.7271],[0.8884,0.4363]])
a = torch.zeros(2,2)
b = torch.ones(2,2)
torch.where(cond>0.5,a,b)

# torch.gather(input,dim,index,out=None) -> Tensor
# gather 用于查表操作
# [dog,cat,whale]  -> [0, 1, 2]
# [1,, 0, 1, 2] -> [cat, dog, cat, whale] 做的就是对应这个事
prob = torch.randn(4,10)
idx = prob.topk(dim=1,k=3)
idx = idx[1]
label = torch.arange(10)+100
torch.gather(label.expand(4,10),dim=1,index=idx.long())
idx
# 此时的后两行就是那种对应关系
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值