How to increase the batch size when GPU memory is insufficient

Problem:

How can we increase the batch size when GPU memory is insufficient?

In other words, how can we train with a larger batch size without needing more GPU memory?

 

Idea:

Split each batch into several mini-batches, compute the loss on each mini-batch and backpropagate it immediately so that the gradients accumulate across the mini-batches, then take a single optimizer step for the whole batch (gradient accumulation).

This way GPU memory only ever holds the activations of one mini-batch at a time: we trade time for space.
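
The core of the trick is the gradient-accumulation loop sketched below. This is a minimal sketch, not the full program: the names model, optimizer, loss_func and mini_loader are placeholders assumed to exist, and the complete runnable example follows in the PyTorch implementation.

optimizer.zero_grad()                        # clear gradients once per large batch
for inputs, targets in mini_loader:          # iterate over the mini-batches
    loss = loss_func(model(inputs), targets)
    loss.backward()                          # frees this mini-batch's graph;
                                             # gradients accumulate in .grad
optimizer.step()                             # one parameter update for the whole batch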

 

PyTorch implementation:

import torch
from sklearn import metrics
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


# A simple TextRNN classifier (embedding -> bidirectional GRU -> max-pooling -> MLP)
class TextRNN(nn.Module):
    def __init__(self, num_words, num_classes, embedding_dim, hidden_dim, dropout):
        super(TextRNN, self).__init__()
        self.embed = nn.Embedding(num_embeddings=num_words + 1, embedding_dim=embedding_dim, padding_idx=num_words)
        self.encode = nn.GRU(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.mlp = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, num_classes)
        )

    def forward(self, x, masks):   # masks is kept for the interface but not used here
        x = self.embed(x)          # (batch, seq_len) -> (batch, seq_len, embedding_dim)
        x, _ = self.encode(x)      # (batch, seq_len, hidden_dim * 2)
        x = x.max(1)[0]            # max-pool over the time dimension
        x = self.mlp(x)            # (batch, num_classes)
        return x


# One epoch of training (or evaluation).
# Each batch coming out of `loader` is split into several mini-batches;
# the loss of each mini-batch is backpropagated immediately, so its computation
# graph can be freed and the gradients accumulate in the parameters.
# A single optimizer step is then taken for the whole (large) batch.
def train_eval(cate, loader, mini_batch_size, model, optimizer, loss_func):
    model.train() if cate == "train" else model.eval()
    preds, labels, loss_sum = [], [], 0.
    for i, data in enumerate(loader):
        # re-wrap the current batch so it can be iterated mini-batch by mini-batch
        mini_loader = DataLoader(list(zip(*data)), batch_size=mini_batch_size)
        optimizer.zero_grad()
        batch_loss = 0.
        for inputs, masks, targets in mini_loader:
            y = model(inputs, masks)
            loss = loss_func(y, targets)
            if cate == "train":
                # backward per mini-batch: only one mini-batch of activations
                # lives in memory at a time; gradients add up in the .grad buffers
                loss.backward()
            preds.append(y.argmax(dim=1).detach())
            labels.append(targets.detach())
            batch_loss += loss.item()
        if cate == "train":
            optimizer.step()  # one parameter update per large batch
        loss_sum += batch_loss
        print(i, batch_loss)
    preds = torch.cat(preds).tolist()
    labels = torch.cat(labels).tolist()
    loss = loss_sum / len(loader)
    acc = metrics.accuracy_score(labels, preds) * 100
    return loss, acc, preds, labels


if __name__ == '__main__':
    # Model hyperparameters
    num_words = 5000
    num_classes = 20
    embedding_dim = 300
    hidden_dim = 200
    dropout = 0.5

    # Dataset parameters
    num_samples = 10000
    pad_len = 1000

    # Training parameters
    batch_size = 4096
    mini_batch_size = 64
    lr = 1e-3
    weight_decay = 1e-6

    # Build synthetic test data
    inputs = torch.randint(0, num_words + 1, (num_samples, pad_len))
    masks = torch.randint(0, 2, (num_samples, pad_len, 1)).float()  # randint's upper bound is exclusive
    targets = torch.randint(0, num_classes, (num_samples,))         # labels in [0, num_classes)
    word2vec = torch.rand((num_words + 1, embedding_dim)).numpy()   # placeholder embeddings (not used below)
    dataset = list(zip(inputs, masks, targets))
    loader = DataLoader(dataset,
                        batch_size=batch_size,  # the large batch: one optimizer step per such batch
                        shuffle=True)

    # Model, loss function, optimizer
    model = TextRNN(num_words=num_words, num_classes=num_classes,
                    embedding_dim=embedding_dim, hidden_dim=hidden_dim,
                    dropout=dropout)
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Start training
    for epoch in range(1, 100):
        loss, acc, preds, labels = train_eval("train", loader, mini_batch_size, model, optimizer, loss_func)
        print("-" * 50)
        print(epoch, loss)
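
A note on loss scaling: with nn.CrossEntropyLoss and its default mean reduction, accumulating gradients over K mini-batches gives the gradient of the sum of K mini-batch means, which is roughly K times the gradient of the full-batch mean. If you want the update to match a genuine large-batch mean loss, divide each mini-batch loss by the number of mini-batches (e.g. loss = loss_func(y, targets) / len(mini_loader)) before calling backward(); otherwise the difference amounts to a rescaled effective learning rate.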