Problem:
How do you increase the batch size when GPU memory is too small to hold the whole batch?
In other words: how can the batch size grow without growing GPU memory?
Idea:
Split each batch into several mini-batches, call backward() on each mini-batch's loss so the gradients accumulate in the parameters' .grad fields, and run a single optimizer step per batch. (Summing the losses and backpropagating once at the end would keep every mini-batch's computation graph alive until the backward pass, saving no memory.) This way only one mini-batch's activations are held at a time: trading time for space.
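For reference, here is the accumulation pattern in isolation (a minimal sketch; model, optimizer, loss_func, mini_batches, and accum_steps are placeholder names, with accum_steps mini-batches making up one logical batch):

optimizer.zero_grad()
for x, y in mini_batches:
    loss = loss_func(model(x), y) / accum_steps  # scale so the sum matches the full-batch mean
    loss.backward()                              # gradients accumulate in .grad; this mini-batch's graph is freed
optimizer.step()                                 # one update for the whole logical batch

Because backward() runs once per mini-batch, autograd never has to hold more than one mini-batch's activations.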
PyTorch implementation:
import torch
from sklearn import metrics
from torch import nn
from torch.utils.data import DataLoader
# A simple TextRNN model: embedding -> bidirectional GRU -> max-pool -> linear
class TextRNN(nn.Module):
    def __init__(self, num_words, num_classes, embedding_dim, hidden_dim, dropout):
        super(TextRNN, self).__init__()
        self.embed = nn.Embedding(num_embeddings=num_words + 1,
                                  embedding_dim=embedding_dim,
                                  padding_idx=num_words)
        self.encode = nn.GRU(embedding_dim, hidden_dim,
                             batch_first=True, bidirectional=True)
        self.mlp = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, num_classes)  # *2 for the two GRU directions
        )

    def forward(self, x, masks):  # masks is unused here but kept in the interface
        x = self.embed(x)      # (batch, seq_len, embedding_dim)
        x, _ = self.encode(x)  # (batch, seq_len, hidden_dim * 2)
        x = x.max(1)[0]        # max-pool over the sequence dimension
        x = self.mlp(x)        # (batch, num_classes)
        return x
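As a quick sanity check of the shapes (run separately from the full script; the toy sizes here are made up, not from the original setup):

model = TextRNN(num_words=10, num_classes=3, embedding_dim=8, hidden_dim=6, dropout=0.1)
x = torch.randint(0, 10, (2, 5))  # batch of 2 sequences of length 5
masks = torch.ones(2, 5, 1)       # unused by this model, kept for the interface
print(model(x, masks).shape)      # expected: torch.Size([2, 3])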
# One epoch of training (or evaluation)
# Each batch from the loader is split into several mini-batches;
# backward() is called per mini-batch so gradients accumulate,
# and the optimizer steps once per full batch
def train_eval(cate, loader, mini_batch_size, model, optimizer, loss_func):
    model.train() if cate == "train" else model.eval()  # set train/eval mode
    preds, labels, loss_sum = [], [], 0.  # loss_sum is for reporting only
    for i, data in enumerate(loader):
        # re-batch the full batch into mini-batches
        mini_loader = DataLoader(list(zip(*data)), batch_size=mini_batch_size)
        num_mini = len(mini_loader)
        if cate == "train":
            optimizer.zero_grad()  # clear gradients once per full batch
        for j, (inputs, masks, targets) in enumerate(mini_loader):
            y = model(inputs, masks)  # forward pass on one mini-batch
            # scale so the accumulated gradient matches the full-batch mean loss
            loss = loss_func(y, targets) / num_mini
            if cate == "train":
                loss.backward()  # accumulate gradients; frees this mini-batch's graph
            # bookkeeping only, no backpropagation
            preds.append(y.max(dim=1)[1].data)  # collect predictions
            labels.append(targets.data)         # collect labels
            loss_sum += loss.item()             # collect the loss value
        if cate == "train":
            optimizer.step()  # one parameter update per full batch
    preds = torch.cat(preds).tolist()
    labels = torch.cat(labels).tolist()
    loss = loss_sum / len(loader)  # mean loss per full batch
    acc = metrics.accuracy_score(labels, preds) * 100
    return loss, acc, preds, labels
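Note that in eval mode the function above still records autograd graphs it never uses; wrapping the forward pass in torch.no_grad() avoids that cost. A minimal self-contained illustration (run separately from the full script):

m = nn.Linear(3, 2)
x = torch.randn(4, 3)
with torch.no_grad():    # no graph is recorded inside this block
    y = m(x)
print(y.requires_grad)   # False: nothing to backpropagate through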
if __name__ == '__main__':
    # model parameters
    num_words = 5000
    num_classes = 20
    embedding_dim = 300
    hidden_dim = 200
    dropout = 0.5
    # dataset parameters
    num_samples = 10000
    pad_len = 1000
    # training parameters
    batch_size = 4096     # the logical batch that one optimizer step covers
    mini_batch_size = 64  # what actually goes through the model at once
    lr = 1e-3
    weight_decay = 1e-6
    # synthetic test data (note that randint's upper bound is exclusive)
    inputs = torch.randint(0, num_words + 1, (num_samples, pad_len))
    masks = torch.randint(0, 2, (num_samples, pad_len, 1)).float()
    targets = torch.randint(0, num_classes, (num_samples,))
    dataset = list(zip(inputs, masks, targets))
    loader = DataLoader(dataset,
                        batch_size=batch_size,  # the batch that backpropagation covers
                        shuffle=True)
    # model, loss function, optimizer
    model = TextRNN(num_words=num_words, num_classes=num_classes,
                    embedding_dim=embedding_dim, hidden_dim=hidden_dim,
                    dropout=dropout)
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # start training
    for epoch in range(1, 100):
        loss, acc, preds, labels = train_eval("train", loader, mini_batch_size, model, optimizer, loss_func)
        print("-" * 50)
        print("epoch %d: loss %.4f, acc %.2f%%" % (epoch, loss, acc))
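A quick way to convince yourself the accumulation is faithful: on a tiny model, accumulated mini-batch gradients should match a single full-batch backward. A self-contained check (the sizes here are made up for illustration):

import torch
from torch import nn

lin = nn.Linear(4, 2)
crit = nn.CrossEntropyLoss()
x, t = torch.randn(8, 4), torch.randint(0, 2, (8,))

lin.zero_grad()
crit(lin(x), t).backward()              # full-batch gradient
full_grad = lin.weight.grad.clone()

lin.zero_grad()
for xb, tb in ((x[:4], t[:4]), (x[4:], t[4:])):
    (crit(lin(xb), tb) / 2).backward()  # two mini-batches, losses scaled by 1/2

print(torch.allclose(full_grad, lin.weight.grad, atol=1e-6))  # expect True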