1、和Numpy交互
(1)数据转换
import torch
import numpy as np
a = np.array([[1,2], [3,4]]) # 创建一个np array
b = torch.from_numpy(a) # 根据np array创建torch 张量
c = b.numpy() # 根据张量, 导出np array
(2)Torch中的运算
- API:
torch
中tensor
的运算和numpy array
运算很相似,比如np.abs() --> torch.abs()
np.sin() --> torch.sin()等
- 矩阵相乘:
data = [[1,2], [3,4]]
tensor = torch.FloatTensor(data) # 转换成32位浮点 tensor
torch.mm(tensor, tensor) # 张量乘法,即矩阵乘法matric × matric
2、变量Variable
(1)什么是Variable与Variable的组成
- 在
Torch
中的Variable
由三部分组成:
data
部分是Torch
的 Tensor
;
grad
部分是这个Variable
的梯度缓存区;
creator
部分是这个Variable
的创造节点
- 如果用一个
Variable
进行计算, 那返回的也是一个同类型的Variable
.
(2)使用
对于式子: y= wx+b
- 导入包
import torch
from torch.autograd import Variable # torch 中 Variable 模块
(3) Variable里面的数据
- 直接
print(variable)
只会输出Variable
形式的数据,在很多时候是用不了的(比如想要用plt
画图), 所以我们要转换一下, 将它变成tensor
形式. - 获取 tensor 数据:
print(variable.data) # tensor
形式 - 也可以转而numpy数据:
print(variable.data.numpy()) # numpy
形式
x = torch.linspace(-5, 5, 200) # 在指定间隔返回相同间隔的数字,在闭区间-5至5生成200个相同间隔的数
x = Variable(x)
x_np = x.data.numpy() # 换成 numpy array, 出图时用
3、PyTorch中的激活函数
- 导入包:
import torch.nn.functional as F # 激活函数都在这
- 平时要用到的就这几个.
relu, sigmoid, tanh, softplus
- 激活函数的输入输出都是
variable
y_relu = F.relu(x).data.numpy()
y_sigmoid = F.sigmoid(x).data.numpy()
y_tanh = F.tanh(x).data.numpy()
y_softplus = F.softplus(x).data.numpy()
# y_softmax = F.softmax(x)
4、PyTorch中的数据加载器和batch
(1) 生成数据生成并构建Dataset
子类
import torch
import torch.utils.data as Data
torch.manual_seed(1)
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10) # 输入数据
y = torch.linspace(10, 1, 10) # 输出数据
# 打包成TensorDataset对象,成为标准数据集
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
(2) 生成batch
数据
# 创建数据加载器
loader = Data.DataLoader(
dataset=torch_dataset, # TensorDataset类型数据集
batch_size=BATCH_SIZE, # mini batch size
shuffle=True, # 设置随机洗牌
num_workers=2, # 加载数据的进程个数
)
for epoch in range(3): # 训练3轮
for step, (batch_x, batch_y) in enumerate(loader): # 每一步
# 在这里写训练代码...
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())
5、保存和加载模型
- 保存和加载整个网络
# 保存和加载整个模型, 包括: 网络结构, 模型参数等
torch.save(resnet, 'model.pkl')
model = torch.load('model.pkl')
6、加载预训练模型
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
reference:
https://cloud.tencent.com/developer/article/1087096
专知国庆特刊-PyTorch手把手深度学习教程系列01-一文带你入门优雅的PyTorch