PyTorch Study Notes [6]: Basic Operations
1. Add
There are two ways to add tensors, the + operator and torch.add, and they give the same result.
# Add
import torch
a = torch.rand(3,4)
b = torch.rand(4)
print(a+b)
# torch.add gives the same result
print(torch.add(a,b))
out:
tensor([[0.8559, 1.4637, 1.3004, 0.7409],
[0.7110, 1.4589, 1.3426, 0.5950],
[1.0352, 1.1449, 1.6043, 1.5138]])
tensor([[0.8559, 1.4637, 1.3004, 0.7409],
[0.7110, 1.4589, 1.3426, 0.5950],
[1.0352, 1.1449, 1.6043, 1.5138]])
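Note that `a` has shape (3, 4) while `b` has shape (4,), so the addition above relies on broadcasting: `b` is treated as if it were repeated along the first dimension. Here is a minimal sketch with fixed values (the values and the names `broadcast_sum` and `explicit_sum` are made up for illustration):

```python
import torch

a = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # shape (3, 4)
b = torch.tensor([10., 20., 30., 40.])                   # shape (4,)

# Broadcasting: b is added to every row of a
broadcast_sum = a + b

# The same result with b expanded explicitly to shape (3, 4)
explicit_sum = a + b.unsqueeze(0).expand(3, 4)

print(broadcast_sum.shape)                       # torch.Size([3, 4])
print(torch.equal(broadcast_sum, explicit_sum))  # True
```

The same rule applies to sub, mul and div in the following sections.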
2. Sub
# Sub
print(a-b)
print(torch.sub(a,b))
# compare the two ways of subtracting
print(torch.all(torch.eq(a-b,torch.sub(a,b))))
out:
tensor([[ 0.1167, -0.0386, -0.4151, -0.3182],
[-0.0282, -0.0434, -0.3730, -0.4641],
[ 0.2959, -0.3574, -0.1112, 0.4546]])
tensor([[ 0.1167, -0.0386, -0.4151, -0.3182],
[-0.0282, -0.0434, -0.3730, -0.4641],
[ 0.2959, -0.3574, -0.1112, 0.4546]])
tensor(True)
3. mul and div
# mul
print(torch.all(torch.eq(a*b,torch.mul(a,b))))
# div
print(torch.all(torch.eq(a/b,torch.div(a,b))))
out:
tensor(True)
tensor(True)
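The `/` operator and a plain `torch.div` both perform true division. If integer-style division is wanted instead, `torch.div` accepts a `rounding_mode` argument. A small sketch, assuming a reasonably recent PyTorch version (I believe `rounding_mode` was added in 1.8); the input values are made up for illustration:

```python
import torch

x = torch.tensor([7, -7])

# True division: the result is floating point (3.5 and -3.5)
true_div = x / 2

# Floor division: rounds toward negative infinity (3 and -4)
floor_div = torch.div(x, 2, rounding_mode="floor")

# Truncating division: rounds toward zero (3 and -3)
trunc_div = torch.div(x, 2, rounding_mode="trunc")

print(true_div)
print(floor_div)
print(trunc_div)
```

The two rounding modes only differ for negative quotients.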
4. matmul
# matmul
a = torch.full([2,2],3.)  # float fill value, so a has the same dtype as b
b = torch.ones(2,2)
print(torch.mm(a,b))
print(torch.matmul(a,b))
print(a@b)
out:
tensor([[6., 6.],
[6., 6.]])
tensor([[6., 6.],
[6., 6.]])
tensor([[6., 6.],
[6., 6.]])
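It is easy to confuse `*` (element-wise) with `@` (matrix product), and `torch.mm` only accepts 2-D tensors while `torch.matmul` also handles batched inputs. A short sketch with illustrative values and shapes (none of them come from the notes above):

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[5., 6.], [7., 8.]])

elementwise = a * b   # [[ 5, 12], [21, 32]]
matrix_prod = a @ b   # [[19, 22], [43, 50]]

# torch.matmul multiplies the last two dimensions and treats the
# leading dimensions as batch dimensions
batch = torch.rand(4, 3, 28, 64)
weights = torch.rand(4, 3, 64, 32)
batched = torch.matmul(batch, weights)

print(elementwise)
print(matrix_prod)
print(batched.shape)  # torch.Size([4, 3, 28, 32])
```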
5. pow and sqrt
print(a)
print(a.pow(2))
print(a**2)
aa = a**2
print(aa.sqrt())
out:
tensor([[3., 3.],
[3., 3.]])
tensor([[9., 9.],
[9., 9.]])
tensor([[9., 9.],
[9., 9.]])
tensor([[3., 3.],
[3., 3.]])
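As a small addition, the square root can also be written as a power of 0.5, and `rsqrt` gives the reciprocal of the square root. A sketch with made-up values:

```python
import torch

x = torch.tensor([1., 4., 9., 16.])

root = x.sqrt()          # 1, 2, 3, 4
root_via_pow = x ** 0.5  # same values via the power operator
inv_root = x.rsqrt()     # 1 / sqrt(x)

print(torch.allclose(root, root_via_pow))  # True
print(torch.allclose(inv_root, 1 / root))  # True
```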
6. Exp and log
# Exp and log
a = torch.exp(torch.ones(2,2))
print(a)
print(torch.log(a))
out:
tensor([[2.7183, 2.7183],
[2.7183, 2.7183]])
tensor([[1., 1.],
[1., 1.]])
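`torch.log` is the natural logarithm, which is why log(e) prints as 1 above. For other bases there are `torch.log2` and `torch.log10`, and any other base can be obtained with the change-of-base formula. A brief sketch (the values and the base 3 are chosen only for illustration):

```python
import torch

x = torch.tensor([1., 8., 1000.])

natural = torch.log(x)   # base e
base2 = torch.log2(x)    # log2(8) is 3
base10 = torch.log10(x)  # log10(1000) is 3

# Change of base: log_3(x) = ln(x) / ln(3)
base3 = torch.log(x) / torch.log(torch.tensor(3.))

print(natural, base2, base10, base3)
```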
7. Approximation
# Approximation
a = torch.tensor(3.14)
print(a.floor(),a.ceil(),a.trunc(),a.frac())
out:
tensor(3.) tensor(4.) tensor(3.) tensor(0.1400)
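For a positive number, floor and trunc give the same answer, so the difference only shows up with negative input. A quick sketch using -3.14 (a value picked for illustration):

```python
import torch

x = torch.tensor(-3.14)

floored = x.floor()    # -4: rounds toward negative infinity
ceiled = x.ceil()      # -3: rounds toward positive infinity
truncated = x.trunc()  # -3: drops the fractional part (rounds toward zero)
fraction = x.frac()    # about -0.14: keeps the sign of x

print(floored, ceiled, truncated, fraction)
```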
8. round
a = torch.tensor(3.499)
print(a.round())
a = torch.tensor(3.5)
print(a.round())
out:
tensor(3.)
tensor(4.)
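One detail the 3.5 example does not reveal: for values exactly halfway between two integers, `round` goes to the nearest even integer ("round half to even"), so 2.5 becomes 2 rather than 3. A small sketch with illustrative values:

```python
import torch

halves = torch.tensor([0.5, 1.5, 2.5, 3.5, -0.5, -1.5])
rounded = halves.round()

# Ties go to the nearest even integer: 0, 2, 2, 4, -0, -2
print(rounded)
```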
9. clamp
clamp clips every value that falls outside the given range to the nearest boundary of that range.
# clamp
grad = torch.rand(2,3)*15
# print max
print(grad.max())
# print median
print(grad.median())
# lower bound only: every element below 10 becomes 10
print(grad.clamp(10))
# clamp to [2, 10]: values above 10 become 10, values below 2 become 2
print(grad.clamp(2,10))
out:
tensor(10.0630)
tensor(3.4961)
tensor([[10.0630, 10.0000, 10.0000],
[10.0000, 10.0000, 10.0000]])
tensor([[10.0000, 3.4961, 2.0000],
[ 2.0000, 5.1788, 8.3820]])
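Because the example above uses random values, here is the same idea with a fixed tensor so the effect of each bound is easy to check by eye. This is only a sketch; the values and the names `lower_only`, `upper_only` and `both` are made up for illustration:

```python
import torch

grad = torch.tensor([[-12.0, 0.5, 3.0],
                     [7.5, 10.0, 25.0]])

lower_only = grad.clamp(min=0.0)   # -12 becomes 0, the rest is unchanged
upper_only = grad.clamp(max=10.0)  # 25 becomes 10, the rest is unchanged
both = grad.clamp(-10.0, 10.0)     # -12 becomes -10 and 25 becomes 10

print(lower_only)
print(upper_only)
print(both)
```

clamp returns a new tensor and leaves `grad` untouched; the in-place variant is `clamp_`.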