import torch
a = torch.rand(3,4)
# tensor([[0.0422, 0.9152, 0.5486, 0.8753],
# [0.3918, 0.6056, 0.0634, 0.9498],
# [0.4444, 0.2209, 0.8165, 0.1533]])
b = torch.rand(4)
# tensor([0.4145, 0.2913, 0.1655, 0.3705])
a + b
# tensor([[0.4567, 1.2065, 0.7141, 1.2458],
# [0.8064, 0.8969, 0.2289, 1.3203],
# [0.8589, 0.5122, 0.9819, 0.5238]])
torch.add(a,b)
# tensor([[0.4567, 1.2065, 0.7141, 1.2458],
# [0.8064, 0.8969, 0.2289, 1.3203],
# [0.8589, 0.5122, 0.9819, 0.5238]])
torch.all(torch.eq(a+b,torch.add(a,b))) # 判断是不是每一个位置对应的值都形同
tensor(True)
torch.all(torch.eq(a-b,torch.sub(a,b)))
tensor(True)
torch.all(torch.eq(a*b,torch.mul(a,b)))
tensor(True)
torch.all(torch.eq(a/b,torch.div(a,b)))
tensor(True)
矩阵相乘matmul
注意: * 表示element-wise,相同位置的相乘,
而matmul表示矩阵相乘
a = torch.full([2,2],3)
# tensor([[3., 3.],
# [3., 3.]])
b = torch.ones(2,2)
# tensor([[1., 1.],
# [1., 1.]])
torch.mm(a,b) # 矩阵相乘
#tensor([[6., 6.],
# [6., 6.]])
torch.matmul(a,b) # 矩阵相乘
#tensor([[6., 6.],
# [6., 6.]])
a@b # 矩阵相乘
#tensor([[6., 6.],
# [6., 6.]])
上面三种方式表示矩阵相乘,torch.mm
只能支持2D的矩阵相乘,一般使用torch.matmul
或者a@b
来表示。
a = torch.rand(4,3,26,64)
b = torch.rand(4,3,64,32)
torch.matmul(a,b).shape # 4维矩阵相乘
torch.Size([4, 3, 26, 32])
b = torch.rand(4,1,64,32) # 自动进行Broadcasting
对于上面的4维矩阵相乘,前面的两维保持不变,后面的两维进行相乘。
指数与平方根
a = torch.full([2,2],3)
# tensor([[3, 3],
# [3, 3]])
a.pow(2) # 平方
# tensor([[9, 9],
# [9, 9]])
b = a ** 2 # 平方
# tensor([[9, 9],
# [9, 9]])
b ** (0.5) # 开根号
# tensor([[3., 3.],
# [3., 3.]])
torch.sqrt(b.to(torch.double)) # 开根号
# tensor([[3., 3.],
# [3., 3.]], dtype=torch.float64)
torch.rsqrt(b.to(torch.double)) # 开根号之后求倒数
# tensor([[0.3333, 0.3333],
# [0.3333, 0.3333]], dtype=torch.float64)
这里求平方根的时候为什么不能直接使用b.sqrt()
?如果使用b.sqrt()
会报如下错误:RuntimeError: sqrt_vml_cpu not implemented for 'Long'
。原因是Long类型的数据不支持log对数运算, 为什么Tensor是Long类型? 因为创建List数组时默认使用的是int, 所以从List转成torch.Tensor后, 数据类型变成了Long。
以e为底的指数函数
a = torch.exp(torch.ones(2,2))
# tensor([[2.7183, 2.7183],
# [2.7183, 2.7183]])
torch.log(a) # 默认以e为底
# tensor([[1., 1.],
# [1., 1.]])
floor、ceil、trunc、frac
a.floor(),a.ceil(),a.trunc(),a.frac()
(tensor(3.), tensor(4.), tensor(3.), tensor(0.1400))
- floor:向下取整
- ceil:向上取整
- trunc:裁剪整数部分
- frac:裁剪小数部分
四舍五入
a = torch.tensor(3.499)
a.round()
# tensor(3.)
a = torch.tensor(3.50)
a.round()
# tensor(4.)