数据集处理方法

1、摘要:

最近在大数据集训练过程中,遇到数据集各类别的样本数差别较大的问题,现在说一说使用一些方法来处理这种类别不平衡的情况。以下是几种处理类别不平衡的常见方法:1. 欠采样:随机删除数量较多的类别的样本,使得各个类别的样本数接近平衡。但欠采样可能会导致信息丢失,并且在样本较少的类别中可能会丢失重要信息。2. 过采样:刚好相反,这种方法将数量较少的类别的样本复制或生成新的合成样本,以增加其样本数量,使得各个类别的样本数接近平衡。3. 加权损失函数:为不同类别赋予不同的损失权重,使得模型在训练时更关注样本较少的类别,以提高模型对少数类别的学习能力。4. 样本生成:通过一些数据增强技术,如旋转、缩放、平移、裁剪等,生成新的样本来增加样本数量。5. 结合采样和加权:可以结合欠采样和过采样的方法,并使用加权损失函数来处理类别不平衡问题。我们可以考虑使用上述方法之一来处理类别不平衡。

2、具体实现:

过采样:我们可以使用resample()函数进行过采样处理,将数量较少的类别的样本复制并添加到训练集中。通过设置n_samples()参数来控制过采样后的样本数量。然后,使用train_test_split()函数将过采样后的样本划分为训练集和验证集。最后,将文件复制到对应的目录中。



# 获取所有图像路径
image_list = glob.glob('data1/*/*.*')

# 设置保存路径
file_dir = 'data'
if os.path.exists(file_dir):
    shutil.rmtree(file_dir)
os.makedirs(file_dir)

# 进行过采样处理
resampled_files = []
for class_dir in set([os.path.dirname(file) for file in image_list]):
    # 获取当前类别的样本
    class_files = [file for file in image_list if os.path.dirname(file) == class_dir]
    # 对当前类别进行过采样
    resampled_class_files = resample(class_files, replace=True, n_samples=len(class_files)*5, random_state=42)
    # 将过采样后的样本添加到列表中
    resampled_files.extend(resampled_class_files)

# 划分训练集和验证集
trainval_files, val_files = train_test_split(resampled_files, test_size=0.3, random_state=42)

# 拷贝文件到对应的目录
train_dir = 'train'
val_dir = 'val'
train_root = os.path.join(file_dir, train_dir)
val_root = os.path.join(file_dir, val_dir)

# 复制训练集文件
for file in trainval_files:
    file_class = os.path.dirname(file).replace("\\", "/").split('/')[-1]
    file_name = os.path.basename(file)
    file_class = os.path.join(train_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, os.path.join(file_class, file_name))

# 复制验证集文件
for file in val_files:
    file_class = os.path.dirname(file).replace("\\", "/").split('/')[-1]
    file_name = os.path.basename(file)
    file_class = os.path.join(val_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, os.path.join(file_class, file_name))

欠采样:与过采样类似,也是使用resample()函数,可以指定对某种类别的样本进行欠采样。

# 进行欠采样处理
undersampled_files = []
for class_dir in set([os.path.dirname(file) for file in image_list]):
    # 获取当前类别的样本
    class_files = [file for file in image_list if os.path.dirname(file) == class_dir]
    # 对当前类别进行欠采样
    undersampled_class_files = resample(class_files, replace=False, n_samples=int(len(class_files)*0.5), random_state=42)
    # 将欠采样后的样本添加到列表中
    undersampled_files.extend(undersampled_class_files)

数据增强:我经常使用的是随机旋转、高斯模糊、颜色抖动等方式

transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5,5),sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor()

定义加权损失函数:例如以下实例:

class WeightedLoss(nn.Module):
    def __init__(self, weights):
        super(WeightedLoss, self).__init__()
        self.weights = weights

    def forward(self, inputs, targets):
        # 将targets转换为one-hot编码
        targets_onehot = torch.zeros_like(inputs)
        targets_onehot.scatter_(1, targets.view(-1, 1), 1)
        
        # 计算每个样本的损失
        loss_per_sample = -torch.log(torch.sum(inputs * targets_onehot, dim=1))
        
        # 根据类别为损失乘以权重
        weights = torch.Tensor(self.weights).to(inputs.device)
        weighted_loss = torch.mean(weights * loss_per_sample)
        
        return weighted_loss

# 初始化类别权重
class_weights = [1.0, 2.0, 0.5]  # 例如,给第一个类别权重1.0,第二个类别权重2.0,第三个类别权重0.5

接下来使用它:

# 使用加权损失函数
loss_function = WeightedLoss(class_weights)
loss = loss_function(inputs, targets)

以上为全部内容!

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值