Why PyTorch?
- Computation graphs come in two flavors: dynamic graphs (PyTorch) and static graphs (TensorFlow). A dynamic graph is built with ordinary Python code, so there is no need for a dedicated graph-construction language like TF's; dynamic graphs are also more flexible to modify and, generally speaking, save memory. See the sketch below for the define-by-run idea.
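A minimal sketch of define-by-run: the graph is traced as ordinary Python executes, so control flow such as if/for can depend on the data. The module name and sizes here are made up for illustration.

import torch
import torch.nn as nn

class DynamicNet(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        # ordinary Python control flow decides the graph on every call
        h = self.linear(x)
        if h.mean() > 0:          # data-dependent branch
            h = self.linear(h)    # reuse the same layer a second time
        return h

net = DynamicNet(8)
out = net(torch.randn(4, 8))      # the graph is built during this call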
Configuration
- Check whether a GPU is available: torch.cuda.is_available(); count the GPUs: torch.cuda.device_count() (see the snippet below).
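A quick check, e.g. in an interactive session:

import torch

print(torch.cuda.is_available())    # True if a usable GPU (and a CUDA build of PyTorch) is present
print(torch.cuda.device_count())    # number of visible GPUs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')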
Basic operations
- Defining tensors and operating on them (watch out for max/sum, item, view, conversion to and from NumPy, and cuda); see the examples below.
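A few small examples of the operations mentioned above (shapes and values are arbitrary):

import torch

x = torch.randn(2, 3)
x.sum()                      # scalar tensor; x.sum(dim=0) reduces over rows
x.max(dim=1)                 # returns (values, indices) when a dim is given
x.sum().item()               # .item() turns a one-element tensor into a Python number
x.view(3, 2)                 # reshape without copying (the tensor must be contiguous)
a = x.numpy()                # shares memory with x (CPU tensors only)
b = torch.from_numpy(a)      # back to a tensor, still sharing memory
if torch.cuda.is_available():
    x_gpu = x.cuda()         # or x.to('cuda'); call .cpu() before .numpy() on GPU tensors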
Modeling workflow
Using random data as an example, build a two-layer NN. Once the data is ready, the basic steps are:
- Write the model class, including the __init__ and forward functions; then instantiate it
- Choose a loss_fn and an optimizer
- Train:
    - forward pass
    - optimizer.zero_grad()
    - loss.backward()
    - optimizer.step()
import torch
import torch.nn as nn
from time import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

N, D_in, H, D_out = 64, 1000, 100, 10
X = torch.randn(N, D_in, device=device)  # on the GPU; on my machine this was actually slower, possibly due to host-device interaction
y = torch.randn(N, D_out, device=device)

class TwoLayerNN(nn.Module):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNN, self).__init__()
        # define the model architecture
        self.linear1 = nn.Linear(D_in, H, bias=True)
        self.linear2 = nn.Linear(H, D_out, bias=True)

    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred

model = TwoLayerNN(D_in, H, D_out).to(device)  # move the model to the same device as the data
loss_fn = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # Adam lr is usually 1e-3 to 1e-4; adaptive optimizers can start a bit larger and adjust as training goes on

EPOCH = 5000
a = time()
for i in range(EPOCH):
    # Forward pass
    y_pred = model(X)
    # Compute loss (MSE)
    loss = loss_fn(y_pred, y)
    # Backward pass: compute gradients
    optimizer.zero_grad()
    loss.backward()
    # Update parameters
    optimizer.step()
print(f'{EPOCH} epochs took {time() - a:.1f}s, final loss {loss.item():.4f}')
DataLoader
Rough usage: define a Dataset class that implements the three functions below; for the arguments available when instantiating DataLoader, see the docs.
from torch.utils.data import DataLoader, Dataset

class CorpusDataset(Dataset):
    def __init__(self, root_dir, file):  # read the file
        with open(root_dir + file, 'r', encoding='utf8') as f:
            self.data = f.readlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):  # return the sample (here, one line of text) at the given index
        return self.data[idx]

dataset = CorpusDataset(root_dir='./dataset/', file='corpus_havestp.txt')
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
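To consume the batches, iterate over the DataLoader; a minimal usage sketch, assuming the dataset above was created successfully:

for batch in dataloader:
    # with batch_size=2 and string samples, batch is a list of 2 raw text lines;
    # tokenization would normally happen here or in a custom collate_fn
    print(batch)
    break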
Torchtext
Tips:
with torch.no_grad(): disables tracking of gradients in autograd inside the block (note the parentheses: no_grad is used as a context manager).
model.eval(): changes the forward() behaviour of the module it is called on, e.g. it disables dropout and makes batch norm use its running (population) statistics instead of per-batch statistics. See the sketch below.
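The two are usually combined when evaluating; a minimal sketch, assuming the TwoLayerNN model, data, and loss_fn from the example above:

model.eval()                      # switch dropout/batch norm layers to inference behaviour
with torch.no_grad():             # no graph is built, so no gradients are tracked and memory use drops
    val_pred = model(X)
    val_loss = loss_fn(val_pred, y)
model.train()                     # switch back before resuming training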