PyTorch 基本语法学习(三)
经过前两个小节的学习,本小节是基础知识部分的最后一节,主要介绍 tensor 的拼接与拆分,并简单介绍一些相关的 tensor 数学运算和统计方法。
tensor的拼接 cat & stack
import torch

# cat: joins tensors along an EXISTING dimension; all other dimensions must match.
# torch.empty allocates uninitialized float32 storage -- only the shapes matter here
# (this replaces the legacy torch.FloatTensor(*sizes) constructor, which does the same).
a = torch.empty(4, 3, 28, 32, dtype=torch.float32)
b = torch.empty(9, 3, 28, 32, dtype=torch.float32)
c = torch.empty(9, 3, 28, 32, dtype=torch.float32)
d = torch.empty(4, 5, 28, 32, dtype=torch.float32)
print(torch.cat([a, b], dim=0).shape)  # torch.Size([13, 3, 28, 32])  4 + 9 along dim 0
print(torch.cat([a, d], dim=1).shape)  # torch.Size([4, 8, 28, 32])   3 + 5 along dim 1

# stack: inserts a NEW dimension of size len(tensors) at position `dim`;
# every input must have exactly the same shape.
print(torch.stack([b, c], dim=0).shape)  # torch.Size([2, 9, 3, 28, 32])
print(torch.stack([b, c], dim=1).shape)  # torch.Size([9, 2, 3, 28, 32])
print(torch.stack([b, c], dim=2).shape)  # torch.Size([9, 3, 2, 28, 32])
print(torch.stack([b, c], dim=3).shape)  # torch.Size([9, 3, 28, 2, 32])
tensor的拆分 split & chunk
import torch

# Build stacked tensors first, then take them apart again with split / chunk.
# torch.empty gives uninitialized float32 storage (same as the legacy
# torch.FloatTensor(*sizes) constructor); only the shapes matter here.
a = torch.empty(9, 3, 28, 32, dtype=torch.float32)
b = torch.empty(9, 3, 28, 32, dtype=torch.float32)
d = torch.empty(9, 3, 28, 32, dtype=torch.float32)
c = torch.stack([a, b], dim=0)
print(c.shape)  # torch.Size([2, 9, 3, 28, 32])
e = torch.stack([a, b, d], dim=0)
print(e.shape)  # torch.Size([3, 9, 3, 28, 32])

# split with a LIST: each entry is the length of one piece; the lengths
# must add up to the size of `dim`.
aa, bb = c.split([1, 1], dim=0)
print(aa.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb.shape)  # torch.Size([1, 9, 3, 28, 32])

# split with an INT: every piece has that length (the last may be shorter).
aa1, bb1 = c.split(1, dim=0)
print(aa1.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb1.shape)  # torch.Size([1, 9, 3, 28, 32])
aa2, bb2, cc2 = e.split(1, dim=0)
print(aa2.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb2.shape)  # torch.Size([1, 9, 3, 28, 32])
print(cc2.shape)  # torch.Size([1, 9, 3, 28, 32])

# Uneven pieces along dim 0 (1 + 2 = 3) ...
aa3, bb3 = e.split([1, 2], dim=0)
print(aa3.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb3.shape)  # torch.Size([2, 9, 3, 28, 32])
# ... and along dim 2, whose size is also 3.
aa4, bb4 = e.split([1, 2], dim=2)
print(aa4.shape)  # torch.Size([3, 9, 1, 28, 32])
print(bb4.shape)  # torch.Size([3, 9, 2, 28, 32])

# chunk(n): asks for n pieces instead of piece lengths. Each piece has
# ceil(size / n) entries except possibly the last, so 3 rows in 2 chunks -> 2 + 1.
aa5, bb5 = e.chunk(2, dim=0)
print(aa5.shape)  # torch.Size([2, 9, 3, 28, 32])
print(bb5.shape)  # torch.Size([1, 9, 3, 28, 32])
aa6, bb6, cc6 = e.chunk(3, dim=0)
print(aa6.shape)  # torch.Size([1, 9, 3, 28, 32])
print(bb6.shape)  # torch.Size([1, 9, 3, 28, 32])
print(cc6.shape)  # torch.Size([1, 9, 3, 28, 32])
tensor的数学运算
import torch

# Element-wise add / subtract / multiply / divide. The plain operators
# +, -, *, / are equivalent to torch.add / sub / mul / div, so either form works.
# b (shape (4,)) broadcasts against a (shape (2, 3, 4)) along the last dimension.
a = torch.ones(2, 3, 4)
b = torch.ones(4)
print(torch.all(torch.eq(a + b, torch.add(a, b))))  # tensor(True)
print(torch.all(torch.eq(a - b, torch.sub(a, b))))  # tensor(True)
print(torch.all(torch.eq(a * b, torch.mul(a, b))))  # tensor(True)
print(torch.all(torch.eq(a / b, torch.div(a, b))))  # tensor(True)

# Matrix multiplication: prefer torch.matmul (or the @ operator),
# because torch.mm only accepts 2-D matrices.
a = torch.tensor([[3, 3], [3, 3]], dtype=torch.float32)
print(a.shape)  # torch.Size([2, 2])
b = torch.ones(2, 2)
print(torch.mm(a, b))  # tensor([[6., 6.],
                       #         [6., 6.]])
print(torch.all(torch.eq(torch.mm(a, b), torch.matmul(a, b))))  # tensor(True)

# Element-wise power -- NOT a matrix square: every entry is squared on its own.
# (The matrix square a @ a would be torch.linalg.matrix_power(a, 2).)
a = torch.tensor([[3, 3], [3, 3]], dtype=torch.float32)
aa = a ** 2
print(torch.all(torch.eq(a.pow(2), a ** 2)))  # tensor(True)
print(aa.sqrt())  # element-wise square root: tensor([[3., 3.],
                  #                                   [3., 3.]])
print(torch.exp(a))  # e ** a, element-wise: tensor([[20.0855, 20.0855],
                     #                               [20.0855, 20.0855]])
print(torch.log(a))  # natural log, element-wise: tensor([[1.0986, 1.0986],
                     #                                    [1.0986, 1.0986]])

# Rounding helpers on a scalar tensor.
b = torch.tensor(3.14)
print(b.floor())  # round down:        tensor(3.)
print(b.ceil())   # round up:          tensor(4.)
print(b.trunc())  # integer part:      tensor(3.)
print(b.frac())   # fractional part:   tensor(0.1400)
tensor的统计
import torch

# Norms. a, b and c are views of the same eight 1.0 values with different shapes.
a = torch.full([8], 1.0)
b = a.view(2, 4)
c = a.view(2, 2, 2)
print(b)  # tensor([[1., 1., 1., 1.],
          #         [1., 1., 1., 1.]])
print(c)  # tensor([[[1., 1.],
          #          [1., 1.]],
          #         [[1., 1.],
          #          [1., 1.]]])

# Without `dim`, norm(p) treats the whole tensor as one flat vector, so the result
# does not depend on the shape: L1 = sum(|x|) = 8, L2 = sqrt(sum(x**2)) = sqrt(8).
print(a.norm(1))  # tensor(8.)
print(b.norm(1))  # tensor(8.)
print(c.norm(1))  # tensor(8.)
print(a.norm(2))  # tensor(2.8284)
print(b.norm(2))  # tensor(2.8284)
print(c.norm(2))  # tensor(2.8284)

# With `dim`, the norm is taken along that dimension only, and that dimension is removed.
print(b.norm(1, dim=1))  # tensor([4., 4.])
print(b.norm(2, dim=1))  # tensor([2., 2.])
print(c.norm(1, dim=0))  # tensor([[2., 2.],
                         #         [2., 2.]])
print(c.norm(2, dim=0))  # tensor([[1.4142, 1.4142],
                         #         [1.4142, 1.4142]])
# NOTE: Tensor.norm follows torch.norm, which the PyTorch docs mark as deprecated;
# new code should prefer torch.linalg.vector_norm(x, ord, dim=...).

# Whole-tensor reductions on a 2x4 tensor holding 0..7.
a = torch.arange(8.0).view(2, 4)
print(a)  # tensor([[0., 1., 2., 3.],
          #         [4., 5., 6., 7.]])
print(a.min())     # tensor(0.)
print(a.max())     # tensor(7.)
print(a.mean())    # tensor(3.5000)
print(a.prod())    # tensor(0.)  -- zero because one element is 0
print(a.sum())     # tensor(28.)
print(a.argmax())  # tensor(7)   -- index into the FLATTENED tensor, not (row, col)
print(a.argmin())  # tensor(0)
截至目前基本的语法已经学习完成了,下一步就需要上深度学习模型了。加油加油…