How to increase the batch size when GPU memory is insufficient

Problem:

How can we increase the batch size when GPU memory is limited?

In other words, how can we increase the batch size without needing more GPU memory?

Idea:

Split each batch of data into several mini-batches, compute the loss on each mini-batch, sum the losses, and then backpropagate.

This way only a mini-batch of data has to be pushed through the model at a time, trading time for space.
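For illustration, here is a minimal, self-contained sketch of this accumulation loop. All names in it (model, loss_func, optimizer, big_x, big_y) are placeholders invented for this sketch, not taken from the implementation below; any module, loss and optimizer would do:

import torch
from torch import nn

# placeholder model, loss and optimizer for illustration only
model = nn.Linear(10, 2)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

big_x = torch.randn(4096, 10)              # one large "logical" batch
big_y = torch.randint(0, 2, (4096,))
mini_batch_size = 64

optimizer.zero_grad()
loss = 0.
for start in range(0, len(big_x), mini_batch_size):
    x = big_x[start:start + mini_batch_size]
    y = big_y[start:start + mini_batch_size]
    loss = loss + loss_func(model(x), y)   # forward one mini-batch at a time
loss.backward()                            # single backward pass for the whole batch
optimizer.step()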

PyTorch implementation:

import torch
from sklearn import metrics
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


# A simple TextRNN model
class TextRNN(nn.Module):
    def __init__(self, num_words, num_classes, embedding_dim, hidden_dim, dropout):
        super(TextRNN, self).__init__()
        self.embed = nn.Embedding(num_embeddings=num_words + 1, embedding_dim=embedding_dim, padding_idx=num_words)
        self.encode = nn.GRU(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.mlp = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, num_classes)
        )

    def forward(self, x, masks):
        # masks is unused here; it is kept only to match the data loader interface
        x = self.embed(x)       # (batch, seq_len, embedding_dim)
        x, _ = self.encode(x)   # (batch, seq_len, hidden_dim * 2)
        x = x.max(1)[0]         # max-pooling over the time dimension
        x = self.mlp(x)         # (batch, num_classes)
        return x
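
# As a quick sanity check (not part of the original post), the model can be run on
# dummy inputs to verify the output shape; the sizes below are arbitrary assumptions:
#
#     model = TextRNN(num_words=5000, num_classes=20, embedding_dim=300, hidden_dim=200, dropout=0.5)
#     x = torch.randint(0, 5000, (8, 50))    # 8 sequences of length 50
#     masks = torch.ones(8, 50, 1)           # masks are not used inside forward
#     print(model(x, masks).shape)           # expected: torch.Size([8, 20])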


# One epoch of training:
# split each batch into several smaller mini-batches,
# compute the loss of each mini-batch and sum them,
# then backpropagate once over the whole batch.
def train_eval(cate, loader, mini_batch_size, model, optimizer, loss_func):
    model.train() if cate == "train" else model.eval()  # set training / evaluation mode
    preds, labels, loss_sum = [], [], 0.  # loss_sum is only for reporting, not for backprop

    for i, data in enumerate(loader):
        # re-batch the current batch into mini-batches
        mini_loader = DataLoader(list(zip(*data)), batch_size=mini_batch_size)

        loss = 0.  # sum of the mini-batch losses, used for backpropagation
        for j, (inputs, masks, targets) in enumerate(mini_loader):
            y = model(inputs, masks)  # forward pass
            loss += loss_func(y, targets)  # accumulate the mini-batch losses

            # bookkeeping only, not used for backprop
            preds.append(y.max(dim=1)[1].data)  # collect predictions
            labels.append(targets.data)  # collect labels
        
        # backpropagate the accumulated loss (training mode only)
        if cate == "train":
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # track the loss for reporting
        loss_sum += loss.item()

    preds = torch.cat(preds).tolist()
    labels = torch.cat(labels).tolist()
    loss = loss_sum / len(loader)
    acc = metrics.accuracy_score(labels, preds) * 100
    return loss, acc, preds, labels
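
# One scaling detail worth noting (an observation added here, not in the original post):
# nn.CrossEntropyLoss averages over each mini-batch by default, so summing the mini-batch
# losses gives roughly the number of mini-batches times the mean loss of the whole batch.
# To keep gradients on the same scale as a single large batch with mean reduction, each
# term can be divided by the number of mini-batches, e.g.:
#
#     num_mini_batches = len(mini_loader)
#     loss = 0.
#     for inputs, masks, targets in mini_loader:
#         y = model(inputs, masks)
#         loss += loss_func(y, targets) / num_mini_batches  # preserve the mean-reduction scale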


if __name__ == '__main__':
    # model hyperparameters
    num_words = 5000
    num_classes = 20
    embedding_dim = 300
    hidden_dim = 200
    dropout = 0.5

    # dataset parameters
    num_samples = 10000
    pad_len = 1000

    # training parameters
    batch_size = 4096
    mini_batch_size = 64
    lr = 1e-3
    weight_decay = 1e-6

    # build synthetic data for testing
    inputs = torch.randint(0, num_words + 1, (num_samples, pad_len))
    masks = torch.randint(0, 2, (num_samples, pad_len, 1)).float()
    targets = torch.randint(0, num_classes, (num_samples,))
    word2vec = torch.rand((num_words + 1, embedding_dim)).numpy()
    dataset = list(zip(inputs, masks, targets))
    loader = DataLoader(dataset,
                        batch_size=batch_size,  # the large batch over which backprop is performed
                        shuffle=True)

    # model, loss function, optimizer
    model = TextRNN(num_words=num_words, num_classes=num_classes,
                    embedding_dim=embedding_dim, hidden_dim=hidden_dim,
                    dropout=dropout)
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # start training
    for epoch in range(1, 100):
        loss, acc, preds, labels = train_eval("train", loader, mini_batch_size, model, optimizer, loss_func)
        print("-" * 50)
        print(epoch, loss)
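
A caveat on memory: because the summed loss keeps the autograd graph of every mini-batch alive until loss.backward() is called, the code above mainly limits the size of each forward pass. To also bound activation memory to a single mini-batch, the more common gradient-accumulation pattern calls backward() on each mini-batch and steps the optimizer once per batch. A minimal sketch of that variant of the inner loop (same variable names as in train_eval; an assumption added here, not the author's code):

        optimizer.zero_grad()
        batch_loss = 0.
        for inputs, masks, targets in mini_loader:
            y = model(inputs, masks)
            loss = loss_func(y, targets)
            loss.backward()             # gradients accumulate; this mini-batch's graph is freed
            batch_loss += loss.item()   # bookkeeping only
        optimizer.step()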