Model Training Acceleration Methods

  • Learning rate setting

    • lr = 0.00125*num_gpu*samples_per_gpu
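
      For example, with 8 GPUs and 2 samples per GPU this rule gives lr = 0.02; the base learning rate scales linearly with the total batch size. A minimal sketch, where num_gpu and samples_per_gpu are placeholders for your actual setup:

      num_gpu = 8           # number of GPUs used for training
      samples_per_gpu = 2   # per-GPU batch size
      lr = 0.00125 * num_gpu * samples_per_gpu  # -> 0.02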
      

  • Data loading acceleration

  • data prefetch with a background thread (via the prefetch_generator package)

      # pip install prefetch_generator
      from torch.utils.data import DataLoader
      from prefetch_generator import BackgroundGenerator
      
      
      # Use DataLoaderX as a drop-in replacement for DataLoader:
      # BackgroundGenerator prefetches upcoming batches in a background
      # thread so the training loop does not wait on batch assembly.
      class DataLoaderX(DataLoader):
          def __iter__(self):
              return BackgroundGenerator(super().__iter__())
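
      A hypothetical usage (train_set and the loader arguments are placeholders); pin_memory pairs well with the non_blocking copies used further below:

      train_loader = DataLoaderX(train_set, batch_size=64, shuffle=True,
                                 num_workers=4, pin_memory=True)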
    
  • Overlapping host-to-device copies with cuda.Stream (the data_prefetcher pattern from NVIDIA Apex's ImageNet example)

    """ 该代码是在使用amp半精度计算的条件下:否则加
    if args.fp16:
    	self.mean = self.mean.half()
    	self.std = self.std.half()
    """
    class DataPrefetcher():
        def __init__(self, loader, opt):
            self.loader = iter(loader)
            self.opt = opt
            self.stream = torch.cuda.Stream()
            self.preload()
    
        def preload(self):
            try:
                self.batch = next(self.loader)
            except StopIteration:
                self.batch = None
                return
            # Issue the host-to-device copies on a side stream so they can
            # overlap with compute on the default stream (pinned memory needed).
            with torch.cuda.stream(self.stream):
                for k in self.batch:
                    if k != 'meta':
                        self.batch[k] = self.batch[k].to(device=self.opt.device, non_blocking=True)
    
        def next(self):
            # Make the default stream wait for the async copy, then hand the
            # batch over and start prefetching the next one.
            torch.cuda.current_stream().wait_stream(self.stream)
            batch = self.batch
            self.preload()
            return batch
        
        
    class data_prefetcher():
        def __init__(self, loader):
            self.loader = iter(loader)
            self.stream = torch.cuda.Stream()
            # ImageNet mean/std, scaled to the 0-255 pixel range
            self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
            self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.mean = self.mean.half()
            #     self.std = self.std.half()
            self.preload()
    
        def preload(self):
            try:
                self.next_input, self.next_target = next(self.loader)
            except StopIteration:
                self.next_input = None
                self.next_target = None
                return
            with torch.cuda.stream(self.stream):
                self.next_input = self.next_input.cuda(non_blocking=True)
                self.next_target = self.next_target.cuda(non_blocking=True)
                # With Amp, it isn't necessary to manually convert data to half.
                # if args.fp16:
                #     self.next_input = self.next_input.half()
                # else:
                self.next_input = self.next_input.float()
                self.next_input = self.next_input.sub_(self.mean).div_(self.std)
                
        def next(self):
            torch.cuda.current_stream().wait_stream(self.stream)
            input = self.next_input
            target = self.next_target
            self.preload()
            return input, target
        
    # Data loading before adding the prefetcher:
    for iter_id, batch in enumerate(data_loader):
        if iter_id >= num_iters:
            break
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].to(device=opt.device, non_blocking=True)
        run_step()
    
    # Data loading after adding the prefetcher:
    prefetcher = DataPrefetcher(data_loader, opt)
    batch = prefetcher.next()
    iter_id = 0
    while batch is not None and iter_id < num_iters:
        run_step()
        batch = prefetcher.next()
        iter_id += 1
    
  • OCR model training tricks

    • Punctuation: when building the dataset, map Chinese punctuation such as [,。'";:] to the English equivalents (or the reverse), so the vocabulary never contains two copies of the same glyph. Both attention-based OCR and CTC models effectively read characters as shapes, so a Chinese semicolon and an English semicolon look like the same thing to the model and get confused with each other.

    • Training set: ctc_loss takes a sequence_length argument, so to keep the data distribution consistent and CTC efficient, sort the samples by label length when building the dataset. For example, the first 100 samples have labels shorter than 5 characters, the next 100 shorter than 10, the next 100 shorter than 15, and so on (see the sketch after this list).

    • Independent across batches, equal within a batch: images within one batch must share the same size, but with a fully convolutional network different batches can be sized independently, so the resize target can be chosen per batch. For example, the first batch reads labels shorter than 5 characters and resizes the images to 100×32; the second reads labels shorter than 10 and resizes to 200×32.

    • Two-tailed training set problem: for data balance, drop samples whose labels occur extremely rarely or extremely often, so that every character appears with moderate frequency.
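
    A minimal sketch of the length-sorted, per-batch-resize idea from the bullets above; `samples` is a hypothetical list of (PIL image, label) pairs, and the width-per-character heuristic is an assumption:

    def length_bucketed_batches(samples, batch_size=100, height=32, px_per_char=20):
        # Sort by label length so each batch holds labels of similar length.
        samples = sorted(samples, key=lambda s: len(s[1]))
        for i in range(0, len(samples), batch_size):
            batch = samples[i:i + batch_size]
            # One target width per batch, derived from its longest label:
            # length <= 5 -> 100x32, length <= 10 -> 200x32, ...
            width = px_per_char * max(len(label) for _, label in batch)
            yield [(img.resize((width, height)), label) for img, label in batch]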

  • Handling class imbalance in PyTorch

    • # Data-side fix: oversample rare classes with WeightedRandomSampler
      import math
      import torch
      import pandas as pd
      from torch.utils.data.dataset import random_split
      from torch.utils.data import DataLoader, WeightedRandomSampler
      from collections import Counter

      def load_data(self, sample):
          train_set, val_set = random_split(train_full, [math.floor(len(train_full)*0.8), math.ceil(len(train_full)*0.2)])

          self.train_classes = [label for _, label in train_set]
          if sample:
              # Need to get weight for every image in the dataset
              class_count = Counter(self.train_classes)
              class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values]) 
              # Can't iterate over class_count because dictionary is unordered
      
              sample_weights = [0] * len(train_set)
              for idx, (image, label) in enumerate(train_set):
                  class_weight = class_weights[label]
                  sample_weights[idx] = class_weight
      
              sampler = WeightedRandomSampler(weights=sample_weights,
                                              num_samples = len(train_set), replacement=True)  
              train_loader = DataLoader(train_set, batch_size=self.batch_size, sampler=sampler)
          else:
              train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
      
          val_loader = DataLoader(val_set, batch_size=self.batch_size)
          return train_loader, val_loader
      
      # Model-side fix: weight the training loss per class
      # (assumes torchvision, torch.nn as nn, and EfficientNet from the
      #  efficientnet_pytorch package are imported)
      def load_model(self, arch='resnet'):
          if arch == 'resnet':
              self.model = torchvision.models.resnet50(pretrained=True)
              if self.freeze_backbone:
                  for param in self.model.parameters():
                      param.requires_grad = False
              self.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=self.num_classes)
          elif arch == 'efficient-net':
              self.model = EfficientNet.from_pretrained('efficientnet-b7')
              if self.freeze_backbone:
                  for param in self.model.parameters():
                      param.requires_grad = False
              self.model._fc = nn.Linear(in_features=self.model._fc.in_features, out_features=self.num_classes)    
      
          self.model = self.model.to(self.device)
      
          self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr) 
      
          if self.loss_weights:
              class_count = Counter(self.train_classes)
              class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])
              # Can't iterate over class_count because the dictionary is unordered
              class_weights = class_weights.to(self.device)  
              self.criterion = nn.CrossEntropyLoss(class_weights)
          else:
              self.criterion = nn.CrossEntropyLoss()
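
      Both branches above use inverse-frequency weights, weight_c = N / count_c. A standalone sketch of just that computation (the labels list is hypothetical):

      from collections import Counter
      import torch

      labels = [0, 0, 0, 0, 1, 1, 2]   # hypothetical class labels
      counts = Counter(labels)         # {0: 4, 1: 2, 2: 1}
      weights = torch.tensor([len(labels) / counts[c] for c in sorted(counts)])
      # -> tensor([1.75, 3.50, 7.00]); rarer classes get larger weights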
      
  • Early stopping

    • #Callbacks
      # Early stopping
      class EarlyStopping:
        def __init__(self, patience=1, delta=0, path='checkpoint.pt'):
          self.patience = patience
          self.delta = delta
          self.path= path
          self.counter = 0
          self.best_score = None
          self.early_stop = False
      
        def __call__(self, val_loss, model):
          if self.best_score is None:
            self.best_score = val_loss
            self.save_checkpoint(model)
          elif val_loss > self.best_score - self.delta:  # no improvement of at least delta
            self.counter +=1
            if self.counter >= self.patience:
              self.early_stop = True 
          else:
            self.best_score = val_loss
            self.save_checkpoint(model)
            self.counter = 0      
      
        def save_checkpoint(self, model):
          torch.save(model.state_dict(), self.path)
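
      Hypothetical usage in a training loop (max_epochs, train_one_epoch, and evaluate are placeholders for your own setup):

      early_stopping = EarlyStopping(patience=5, path='checkpoint.pt')
      for epoch in range(max_epochs):
          train_one_epoch(model, train_loader)
          val_loss = evaluate(model, val_loader)
          early_stopping(val_loss, model)
          if early_stopping.early_stop:
              break
      model.load_state_dict(torch.load('checkpoint.pt'))  # restore the best weights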
      