版本 : pytorch 1.0.0.dev20181128
一、Pytorch简介
可以在GPU高速运行的numpy替代库
- 一个灵活快速的深度学习框架
- 主要包含 tensor 和 nn.module 两个核心模块
########### Tensor定义 ###########
x1 = torch.empty(2, 3)
x2 = torch.rand(2, 3)
x3 = torch.zeros(2, 3, dtype=torch.long)
x4 = torch.tensor([5.5, 3])
print(x1)
'''
tensor([[0., 0., 0.],
[0., 0., 0.]])
'''
print(x2)
'''
tensor([[0.2616, 0.6785, 0.9963],
[0.3585, 0.5911, 0.1730]])
'''
print(x3)
'''
tensor([[0, 0, 0],
[0, 0, 0]])
'''
print(x4)
'''
tensor([5.5000, 3.0000])
'''
########### Tensor常用属性 ###########
x = torch.empty(2, 3)
print(x)
'''
tensor([[0.0000e+00, 0.0000e+00, 1.6809e+36],
[7.9594e-43, 0.0000e+00, 0.0000e+00]])
'''
print(x.size()) # torch.Size([2, 3]), 本质是tuple
print(x.type()) # torch.FloatTensor
print(x.requires_grad) # False
print(x.grad) # None
print(x.grad_fn) # None
########### Tensor操作举例:加法 ###########
x = torch.ones(2, 3)
y = torch.zeros(2, 3)
# 方式一
print(x + y)
# 方式二
print(torch.add(x, y))
# 方式三
result = torch.empty(2, 3)
torch.add(x, y, out=result)
print(result)
# 方式四:Any operation that mutates a tensor