在深度学习中,我们会频繁的对数据进行操作。这是我们在写代码最基本、入门的一个操作。
首先导入torch包
import torch
弄懂张量、数组、向量、矩阵中的一些基本概念
x1 = torch.arange(12)
print(x1)
输出结果如下:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
x1 =x1.shape
print(x1)
输出结果如下:
torch.Size([12])
也可以用numel()函数返回张量的个数
x1 =x1.numel()
print(x1)
输出结果如下:
12
要改变⼀个张量的形状而不改变元素数量和元素值,可以调⽤reshape函数。
x1 = x1.reshape(3, 4)
print(x1)
输出结果如下:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
关于reshape()函数
我们还可以用这种写法来进行调用reshape函数
x2 = torch.reshape(x1,(3,4))
print(x2)
输出的结果是一样的,但是在使用reshape()函数时候,首先要计算好维度,比如刚刚我们那个的张量长度是12,所以我们reshape()函数的值必须两个相乘等于12,如(1,12),(2,6),(3,4),(4,3)(6,2)(12,1)等等。如果相乘的不相等,就会出现以下错误
x2 = torch.reshape(x1,(2,4))
print(x2)
当然,如果我们想要避免算错,也可以用-1来替换其中一个数字,此时系统会根据你给的一直,计算出另一个值来,如下图
x2 = torch.reshape(x1,(2,-1))
print(x2)
输出结果是:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]])
两个的数字不能同时为-1,就像解方程一样,a*b=12 ,这个方程永远解不出来。
x2 = torch.reshape(x1,(-1,-1))
print(x2)
----------------------------------------
x3 =torch.zeros((2, 3, 2))
print(x3)
其输出结果是:
tensor([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
x3 =torch.ones((2, 3, 4))
print(x3)
输出结果如下:
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]])
关于random()函数
x4 = torch.randn(3,4)
print(x4)
输出结果是:
tensor([[ 0.6315, 2.0303, 0.0325, -0.0615],
[ 0.4824, 0.6846, -0.7828, 1.5057],
[-0.8369, 0.2498, 0.5857, 1.8156]])
random()函数的延伸版本randint()函数
x4 = torch.randint(3,10,(3,2))
print(x4)
输出结果为:
tensor([[5, 9],
[6, 3],
[4, 5]])
randint(3,10,(3,2))他的意思是输出3到10之间,一个3*2的矩阵
torch.randperm()函数
代码如下图
x5 = torch.randperm(10)
print(x5)
输出结果是:
tensor([9, 1, 5, 8, 4, 7, 6, 3, 2, 0])
返回整数从0到n-1的随机排列。torch.randperm(10)返回0-9之间的随机排列
加减乘除
x = torch.tensor([1.0, 2, 4, 8])
y = torch.tensor([2, 2, 2, 2])
print(x + y)
print(x * y)
print(x / y)
print(x * y)
print(x ** y) # **运算符是幂运输
输出结果是:
tensor([ 3., 4., 6., 10.])
tensor([ 2., 4., 8., 16.])
tensor([0.5000, 1.0000, 2.0000, 4.0000])
tensor([ 2., 4., 8., 16.])
tensor([ 1., 4., 16., 64.])
也可以求幂运输
z = torch.exp(x)
print(z )
tensor([2.7183e+00, 7.3891e+00, 5.4598e+01, 2.9810e+03])
还可以对张量所有的元素求和,使用sum函数就可
print(x)
print(x.sum())
输出结果:
tensor([1., 2., 4., 8.])
tensor(15.)
判断两个张量的元素是否相等
print(x)
print(y)
print(x==y)
输出结果是
tensor([1., 2., 4., 8.])
tensor([2, 2, 2, 2])
tensor([False, True, False, False])
还可以用cat()函数连接两个张量
X = torch.arange(12, dtype=torch.float32).reshape((3,4))
Y = torch.tensor([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
m= torch.cat((X,Y),dim=0)
n= torch.cat((X,Y),dim=1)
print(X)
print(Y)
print(m)
print(n)
输出结果是
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]])
tensor([[2., 1., 4., 3.],
[1., 2., 3., 4.],
[4., 3., 2., 1.]])
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 2., 1., 4., 3.],
[ 1., 2., 3., 4.],
[ 4., 3., 2., 1.]])
tensor([[ 0., 1., 2., 3., 2., 1., 4., 3.],
[ 4., 5., 6., 7., 1., 2., 3., 4.],
[ 8., 9., 10., 11., 4., 3., 2., 1.]])
其中dim = 0 是沿着X轴链接,dim =1 是沿着Y轴链接
索引
print(X[-1])
#结果
tensor([ 8., 9., 10., 11.])
当里面参数是-1时 表示是从矩阵的倒数第一行开始
-2是从倒数第二行
print(X[-2])
#结果
tensor([4., 5., 6., 7.])
如果超出了矩阵的X周,就会报以下错误
print(X[-4])
我们还可以截取某几行的,如下代码
print(X[1:3])
#结果
tensor([[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]])
除读取外,我们还可以通过指定索引来将元素写⼊矩阵。
X[1, 2] = 9
print(X)
#输出结果是
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 9., 7.],
[ 8., 9., 10., 11.]])