四则运算
import torch

# Element-wise arithmetic demo. `a` is 2x2 and `b` is 1x2, so in every
# operation below `b` is broadcast across the rows of `a` (the single row
# [10, 20] is reused for each row of `a`).
a = torch.tensor([
    [1, 2],
    [3, 4],
])
b = torch.tensor([
    [10, 20],
])

# For each operator: first confirm the operator form and the torch.* function
# form agree element-wise, then print the result under a label that names the
# expression actually being printed.
print("torch.all(torch.eq(a+b, torch.add(a,b))):",
      torch.all(torch.eq(a + b, torch.add(a, b))))
print("a+b:\n{}\n".format(a + b))

print("torch.all(torch.eq(a-b, torch.sub(a,b))):",
      torch.all(torch.eq(a - b, torch.sub(a, b))))
# The label must say "a-b": this line prints the difference, not a product.
print("a-b:\n{}\n".format(a - b))

print("torch.all(torch.eq(a*b, torch.mul(a,b))):",
      torch.all(torch.eq(a * b, torch.mul(a, b))))
print("a*b:\n{}\n".format(a * b))

# `/` (like torch.div with no rounding_mode) performs true division, so the
# integer inputs produce a floating-point result.
print("torch.all(torch.eq(a/b, torch.div(a,b))):",
      torch.all(torch.eq(a / b, torch.div(a, b))))
# The label must say "a/b": this line prints the quotient, not a product.
print("a/b:\n{}\n".format(a / b))
矩阵相乘
import torch

# 2-D matrix product demo: a (2x1 column) times b (1x2 row) gives a 2x2 matrix.
a = torch.tensor([[1], [3]])
b = torch.tensor([[10, 20]])

# Three equivalent spellings of the same product, printed one after another.
for label, multiply in (
    ("torch.mm(a, b)", torch.mm),
    ("torch.matmul(a, b)", torch.matmul),
    ("a@b", lambda lhs, rhs: lhs @ rhs),
):
    print(f"{label}:\n{multiply(a, b)}\n")
大于2维的矩阵相乘
import torch

# Batched matmul on 4-D tensors: the last two dims are multiplied as matrices
# and the leading dims act as batch dims.

# Case 1: batch dims match exactly, (4, 3, 28, 64) @ (4, 3, 64, 32).
a1, b1 = torch.rand(4, 3, 28, 64), torch.rand(4, 3, 64, 32)
c1 = a1 @ b1
print("c1.shape: ", c1.shape)

# Case 2: a2's size-1 dim is broadcast against b2's size-3 dim.
a2, b2 = torch.rand(4, 1, 28, 64), torch.rand(4, 3, 64, 32)
c2 = a2 @ b2
print("c2.shape: ", c2.shape)
幂运算
import torch

a = torch.tensor([[1, 2], [3, 4]])

# Power-related ops, keyed by the label to print. Each expression is wrapped
# in a lambda so it is evaluated right before it is printed, in order.
# rsqrt is the reciprocal square root, 1 / sqrt(a).
power_demos = {
    "a.pow(2)": lambda: a.pow(2),
    "a**2": lambda: a ** 2,
    "a.pow(0.5)": lambda: a.pow(0.5),
    "a.sqrt()": lambda: a.sqrt(),
    "a.rsqrt()": lambda: a.rsqrt(),
    "a**0.5": lambda: a ** 0.5,
}
for label, compute in power_demos.items():
    print(f"{label}:\n{compute()}\n")
exp 与 log（指数与对数）
import torch

a = torch.tensor([[1, 2], [3, 4]])

# exp and log are inverses: log(exp(a)) gives back a as floats, up to rounding.
a_exp = a.exp()
print(f"torch.exp(a):\n{a_exp}\n")
print(f"torch.log(a_exp):\n{a_exp.log()}\n")
近似值
import torch

# Scalar (0-d) tensor used to compare the rounding-style methods.
a = torch.tensor(1.67)

# Call each method by name and print it under the label "a.<name>():".
for method_name in ("floor", "ceil", "trunc", "frac", "round"):
    print(f"a.{method_name}():", getattr(a, method_name)())
最大值、最小值、中位数
import torch

# 2x3 tensor of random values scaled into [0, 20).
a = torch.rand(2, 3) * 20
print(f"a:\n{a}\n")

# Whole-tensor reductions (no dim argument). With an even element count
# (6 here), median returns the lower of the two middle values, not their mean.
for reduction in ("max", "median", "min"):
    print(f"a.{reduction}(): ", getattr(a, reduction)())
限制区间
import torch

# 2x3 tensor of random values scaled into [0, 20).
a = torch.rand(2, 3) * 20
print(f"a:\n{a}\n")

# clamp(min) only raises values below the lower bound to 10;
# clamp(min, max) additionally caps values above 10.
lower_bounded = a.clamp(10)
print(f"a.clamp(10):\n{lower_bounded}\n")
both_bounded = a.clamp(5, 10)
print(f"a.clamp(5, 10):\n{both_bounded}\n")