使用 PyTorch 来实现不平衡数据集的图像分类

作者 | Marek Paulik 

编译 | ronghuaiyang

转自 | AI公园

一个非常简单和容易上手的例子。

对于教程中使用的大多数人工数据集,每个类都有相同数量的数据。然而,在实际应用中,这种情况很少发生。今天,我将给你介绍来自Kaggle的木薯叶分类,并告诉你当类频率有很大差异时该怎么做。

处理类别的不平衡

有两种方法可以解决这个问题。

  • WeightedRandomSampler

  • loss函数中的weight参数

下一步是创建一个有5个方法的CassavaClassifier类:load_data()、load_model()、fit_one_epoch()、val_one_epoch()和fit()。

在load_data()中,将构造一个train和验证数据集,并返回数据加载器以供进一步使用。

在load_model()中定义了体系结构、损失函数和优化器。

fit方法包含一些初始化和对fit_one_epoch()和val_one_epoch()的循环。

早期停止

早期停止类有助于根据验证损失跟踪最佳模型,并保存检查点。

#Callbacks
# Early stopping
class EarlyStopping:
  def __init__(self, patience=1, delta=0, path='checkpoint.pt'):
    self.patience = patience
    self.delta = delta
    self.path= path
    self.counter = 0
    self.best_score = None
    self.early_stop = False

  def __call__(self, val_loss, model):
    if self.best_score is None:
      self.best_score = val_loss
      self.save_checkpoint(model)
    elif val_loss > self.best_score:
      self.counter +=1
      if self.counter >= self.patience:
        self.early_stop = True 
    else:
      self.best_score = val_loss
      self.save_checkpoint(model)
      self.counter = 0      

  def save_checkpoint(self, model):
    torch.save(model.state_dict(), self.path)

Init

我们首先初始化CassavaClassifier类。

class CassavaClassifier():
    def __init__(self, data_dir, num_classes, device, Transform=None, sample=False, loss_weights=False, batch_size=16,
     lr=1e-4, stop_early=True, freeze_backbone=True):
    #############################################################################################################
    # data_dir - directory with images in subfolders, subfolders name are categories
    # Transform - data augmentations
    # sample - if the dataset is imbalanced set to true and RandomWeightedSampler will be used
    # loss_weights - if the dataset is imbalanced set to true and weight parameter will be passed to loss function
    # freeze_backbone - if using pretrained architecture freeze all but the classification layer
    ###############################################################################################################
        self.data_dir = data_dir
        self.num_classes = num_classes
        self.device = device
        self.sample = sample
        self.loss_weights = loss_weights
        self.batch_size = batch_size
        self.lr = lr
        self.stop_early = stop_early
        self.freeze_backbone = freeze_backbone
        self.Transform = Transform

Load Data

训练图像被组织在子文件夹中,子文件夹名称表示图像的类。这是图像分类问题的典型情况,幸运的是,不需要编写自定义数据集类。在这种情况下,可以立即使用torchvision中的ImageFolder。如果你想使用WeightedRandomSampler,你需要为数据集的每个元素指定一个权重。通常,总图像总比上类别数被用作一个权重。

def load_data(self):
    train_full = torchvision.datasets.ImageFolder(self.data_dir, transform=self.Transform)
    train_set, val_set = random_split(train_full, [math.floor(len(train_full)*0.8), math.ceil(len(train_full)*0.2)])

    self.train_classes = [label for _, label in train_set]
    if self.sample:
        # Need to get weight for every image in the dataset
        class_count = Counter(self.train_classes)
        class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values]) 
        # Can't iterate over class_count because dictionary is unordered

        sample_weights = [0] * len(train_set)
        for idx, (image, label) in enumerate(train_set):
            class_weight = class_weights[label]
            sample_weights[idx] = class_weight

        sampler = WeightedRandomSampler(weights=sample_weights,
                                        num_samples = len(train_set), replacement=True)  
        train_loader = DataLoader(train_set, batch_size=self.batch_size, sampler=sampler)
    else:
        train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)

    val_loader = DataLoader(val_set, batch_size=self.batch_size)

    return train_loader, val_loader

Load Model

在该方法中,我使用迁移学习,架构参数从预先训练的resnet50和efficientnet-b7中选择。CrossEntropyLoss和许多其他损失函数都有权重参数。这是一个手动调整参数,用于处理不平衡。在这种情况下,不需要为每个参数定义权重,只需为每个类定义权重。

def load_model(self, arch='resnet'):
    ##############################################################################################################
    # arch - choose the pretrained architecture from resnet or efficientnetb7
    ############################################################################################################## 
    if arch == 'resnet':
        self.model = torchvision.models.resnet50(pretrained=True)
        if self.freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
        self.model.fc = nn.Linear(in_features=self.model.fc.in_features, out_features=self.num_classes)
    elif arch == 'efficient-net':
        self.model = EfficientNet.from_pretrained('efficientnet-b7')
        if self.freeze_backbone:
            for param in self.model.parameters():
                param.requires_grad = False
        self.model._fc = nn.Linear(in_features=self.model._fc.in_features, out_features=self.num_classes)    

    self.model = self.model.to(self.device)

    self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr) 

    if self.loss_weights:
        class_count = Counter(self.train_classes)
        class_weights = torch.Tensor([len(self.train_classes)/c for c in pd.Series(class_count).sort_index().values])
        # Cant iterate over class_count because dictionary is unordered
        class_weights = class_weights.to(self.device)  
        self.criterion = nn.CrossEntropyLoss(class_weights)
    else:
        self.criterion = nn.CrossEntropyLoss() 

Fit One Epoch

这个方法只包含一个经典的训练循环,带有训练损失记录和tqdm进度条。

def fit_one_epoch(self, train_loader, epoch, num_epochs ): 
    step_train = 0

    train_losses = list() # Every epoch check average loss per batch 
    train_acc = list()
    self.model.train()
    for i, (images, targets) in enumerate(tqdm(train_loader)):
        images = images.to(self.device)
        targets = targets.to(self.device)

        logits = self.model(images)
        loss = self.criterion(logits, targets)

        loss.backward()
        self.optimizer.step()

        self.optimizer.zero_grad()

        train_losses.append(loss.item())

        #Calculate running train accuracy
        predictions = torch.argmax(logits, dim=1)
        num_correct = sum(predictions.eq(targets))
        running_train_acc = float(num_correct) / float(images.shape[0])
        train_acc.append(running_train_acc)
        
    train_loss = torch.tensor(train_losses).mean()    
    print(f'Epoch {epoch}/{num_epochs-1}')  
    print(f'Training loss: {train_loss:.2f}')

Validate one epoch

与上面类似,但此方法在验证数据加载器上迭代。在每一个epoch'之后,平均batch损失和准确性被打印出来。

def val_one_epoch(self, val_loader, scaler):
        val_losses = list()
        val_accs = list()
        self.model.eval()
        step_val = 0
        with torch.no_grad():
            for (images, targets) in val_loader:
                images = images.to(self.device)
                targets = targets.to(self.device)

                logits = self.model(images)
                loss = self.criterion(logits, targets)
                val_losses.append(loss.item())      
            
                predictions = torch.argmax(logits, dim=1)
                num_correct = sum(predictions.eq(targets))
                running_val_acc = float(num_correct) / float(images.shape[0])

                val_accs.append(running_val_acc)
            

            self.val_loss = torch.tensor(val_losses).mean()
            val_acc = torch.tensor(val_accs).mean() # Average acc per batch
        
            print(f'Validation loss: {self.val_loss:.2f}')  
            print(f'Validation accuracy: {val_acc:.2f}') 

Fit

Fit方法在训练和验证过程中经历了许多阶段和循环。如果预训练模型的参数在开始时被冻结,那么unfreeze_after定义了整个模型在多少个epoch之后开始训练。在此之前,只训练全连接层(分类器)。

def fit(self, train_loader, val_loader, num_epochs=10, unfreeze_after=5, checkpoint_dir='checkpoint.pt'):
    if self.stop_early:
        early_stopping = EarlyStopping(
        patience=5, 
        path=checkpoint_dir)
  
    for epoch in range(num_epochs):
        if self.freeze_backbone:
            if epoch == unfreeze_after:  # Unfreeze after x epochs
                for param in self.model.parameters():
                    param.requires_grad = True
        self.fit_one_epoch(train_loader, scaler, epoch, num_epochs)
        self.val_one_epoch(val_loader, scaler)
        if self.stop_early:
            early_stopping(self.val_loss, self.model)
            if early_stopping.early_stop:
                print('Early Stopping')
                print(f'Best validation loss: {early_stopping.best_score}')
                break

Run

现在,可以初始化CassavaClassifier类、创建dataloaders、设置模型并运行整个过程了。

Transform = T.Compose(
                    [T.ToTensor(),
                    T.Resize((256, 256)),
                    T.RandomRotation(90),
                    T.RandomHorizontalFlip(p=0.5),
                    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
data_dir = "Data/cassava-disease/train/train"

classifier = CassavaClassifier(data_dir=data_dir, num_classes=5, device=device, sample=True, Transform=Transform)
train_loader, val_loader = classifier.load_data()
classifier.load_model()
classifier.fit(num_epochs=20, unfreeze_after=5, train_loader=train_loader, val_loader=val_loader)

Inference

使用ImageFolder加载测试数据是不可能的,因为显然没有带有类的子文件夹。因此,我创建了一个返回图像和图像id的自定义数据集。随后,加载模型检查点,通过推理循环运行它,并将预测保存到数据帧中。将数据帧导出为CSV并提交结果。

# Inference
model = torchvision.models.resnet50()
#model = EfficientNet.from_name('efficientnet-b7')
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=5)
model = model.to(device)
checkpoint = torch.load('Data/cassava-disease/sampler_checkpoint.pt')
model.load_state_dict(checkpoint)
model.eval()


# Dataset for test data
class Cassava_Test(Dataset):
  def __init__(self, dir, transform=None):
    self.dir = dir
    self.transform = transform

    self.images = os.listdir(self.dir)  

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    img = Image.open(os.path.join(self.dir, self.images[idx]))
    return self.transform(img), self.images[idx] 


test_dir = 'Data/cassava-disease/test/test/0'
test_set = Cassava_Test(test_dir, transform=Transform)
test_loader = DataLoader(test_set, batch_size=4)  

# Test loop
sub = pd.DataFrame(columns=['category', 'id'])
id_list = []
pred_list = []

model = model.to(device)

with torch.no_grad():
  for (image, image_id) in test_loader:
    image = image.to(device)

    logits = model(image)
    predicted = list(torch.argmax(logits, 1).cpu().numpy())

    for id in image_id:
      id_list.append(id)
  
    for prediction in predicted:
      pred_list.append(prediction)
sub['category'] = pred_list
sub['id'] = id_list

mapping = {0:'cbb', 1:'cbsd', 2:'cgm', 3:'cmd', 4:'healthy'}

sub['category'] = sub['category'].map(mapping)
sub = sub.sort_values(by='id')

sub.to_csv('Cassava_sub.csv', index=False)

如果在方案中包含WeightedRandomSampler或损失权值,则测试集的精度会提高2%。对于仅仅几行代码来说,这是一个很好的改进。对于这个数据集,我没有看到这两种方法在精度上的巨大差异,但WeightedRandomSampler的表现要好一些。

不同的学习速度、优化器和数据扩展肯定有自己的发展空间。然而,对于这种简单的方法来说,86%的准确率似乎足够好了。

英文原文:https://marekpaulik.medium.com/imbalanced-dataset-image-classification-with-pytorch-6de864982eb1

公众号:AI蜗牛车

保持谦逊、保持自律、保持进步



个人微信
备注:昵称+学校/公司+方向
如果没有备注不拉群!
拉你进AI蜗牛车交流群



  • 1
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值