1、摘要:
最近在大数据集训练过程中,遇到数据集各类别的样本数差别较大的问题,现在说一说使用一些方法来处理这种类别不平衡的情况。以下是几种处理类别不平衡的常见方法:1. 欠采样:随机删除数量较多的类别的样本,使得各个类别的样本数接近平衡。但欠采样可能会导致信息丢失,并且在样本较少的类别中可能会丢失重要信息。2. 过采样:刚好相反,这种方法将数量较少的类别的样本复制或生成新的合成样本,以增加其样本数量,使得各个类别的样本数接近平衡。3. 加权损失函数:为不同类别赋予不同的损失权重,使得模型在训练时更关注样本较少的类别,以提高模型对少数类别的学习能力。4. 样本生成:通过一些数据增强技术,如旋转、缩放、平移、裁剪等,生成新的样本来增加样本数量。5. 结合采样和加权:可以结合欠采样和过采样的方法,并使用加权损失函数来处理类别不平衡问题。我们可以考虑使用上述方法之一来处理类别不平衡。
2、具体实现:
过采样:我们可以使用resample()函数进行过采样处理,将数量较少的类别的样本复制并添加到训练集中。通过设置n_samples()参数来控制过采样后的样本数量。然后,使用train_test_split()函数将过采样后的样本划分为训练集和验证集。最后,将文件复制到对应的目录中。
# 获取所有图像路径
image_list = glob.glob('data1/*/*.*')
# 设置保存路径
file_dir = 'data'
if os.path.exists(file_dir):
shutil.rmtree(file_dir)
os.makedirs(file_dir)
# 进行过采样处理
resampled_files = []
for class_dir in set([os.path.dirname(file) for file in image_list]):
# 获取当前类别的样本
class_files = [file for file in image_list if os.path.dirname(file) == class_dir]
# 对当前类别进行过采样
resampled_class_files = resample(class_files, replace=True, n_samples=len(class_files)*5, random_state=42)
# 将过采样后的样本添加到列表中
resampled_files.extend(resampled_class_files)
# 划分训练集和验证集
trainval_files, val_files = train_test_split(resampled_files, test_size=0.3, random_state=42)
# 拷贝文件到对应的目录
train_dir = 'train'
val_dir = 'val'
train_root = os.path.join(file_dir, train_dir)
val_root = os.path.join(file_dir, val_dir)
# 复制训练集文件
for file in trainval_files:
file_class = os.path.dirname(file).replace("\\", "/").split('/')[-1]
file_name = os.path.basename(file)
file_class = os.path.join(train_root, file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, os.path.join(file_class, file_name))
# 复制验证集文件
for file in val_files:
file_class = os.path.dirname(file).replace("\\", "/").split('/')[-1]
file_name = os.path.basename(file)
file_class = os.path.join(val_root, file_class)
if not os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, os.path.join(file_class, file_name))
欠采样:与过采样类似,也是使用resample()函数,可以指定对某种类别的样本进行欠采样。
# 进行欠采样处理
undersampled_files = []
for class_dir in set([os.path.dirname(file) for file in image_list]):
# 获取当前类别的样本
class_files = [file for file in image_list if os.path.dirname(file) == class_dir]
# 对当前类别进行欠采样
undersampled_class_files = resample(class_files, replace=False, n_samples=int(len(class_files)*0.5), random_state=42)
# 将欠采样后的样本添加到列表中
undersampled_files.extend(undersampled_class_files)
数据增强:我经常使用的是随机旋转、高斯模糊、颜色抖动等方式
transform = transforms.Compose([
transforms.RandomRotation(10),
transforms.GaussianBlur(kernel_size=(5,5),sigma=(0.1, 3.0)),
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
transforms.Resize((224, 224)),
transforms.ToTensor()
定义加权损失函数:例如以下实例:
class WeightedLoss(nn.Module):
def __init__(self, weights):
super(WeightedLoss, self).__init__()
self.weights = weights
def forward(self, inputs, targets):
# 将targets转换为one-hot编码
targets_onehot = torch.zeros_like(inputs)
targets_onehot.scatter_(1, targets.view(-1, 1), 1)
# 计算每个样本的损失
loss_per_sample = -torch.log(torch.sum(inputs * targets_onehot, dim=1))
# 根据类别为损失乘以权重
weights = torch.Tensor(self.weights).to(inputs.device)
weighted_loss = torch.mean(weights * loss_per_sample)
return weighted_loss
# 初始化类别权重
class_weights = [1.0, 2.0, 0.5] # 例如,给第一个类别权重1.0,第二个类别权重2.0,第三个类别权重0.5
接下来使用它:
# 使用加权损失函数
loss_function = WeightedLoss(class_weights)
loss = loss_function(inputs, targets)
以上为全部内容!