从Segmentation到Classification任务:数据整理

本文详细描述了作者如何将无人机图像数据集进行预处理,包括图像裁剪以生成适合分类任务的patch,使用sge_clas_binary函数对segmentation数据库按类别分离,并整理成train、val和test数据集的过程。
摘要由CSDN通过智能技术生成

1 前言

这篇笔记记录了我如何把 segmentation 数据库整理成适合做 classification 任务的数据库。万事开头难,好记性不如烂笔头~

2. 背景介绍

我本来想尝试图像分割(image segmentation task),老大觉得我还需要打基础,建议先从更简单的图像分类(classification task)做起。我当然听老大的。

2.1 数据库

提前收集了无人机数据集,图像尺寸是(44592,3072,3)。由于样本数量有限,我将图片进行了裁剪。

2.2 数据组织

所收集的数据保存格式如下,其中****.jpg 是图像,****_m.png是对应的mask。

------seg
-------6627.jpg
-------6627_m.png
我的数据(image 和 mask)混在了同一个文件夹里,其实不太好,后面会尝试把 image 和 mask 分开,然后再处理。

3. 上代码

3.1 图像裁剪

此处,在split_images_in_folder 函数中遍历 source_folder中所有的图像,并利用split_image 实现图片的裁剪。

from typing import Tuple
import os
import shutil
from PIL import Image
import numpy as np

def split_images_in_folder(source_folder: str, destination_folder: str, patch_size: Tuple[int, int]):
    """Cut every ``.jpg`` entry found directly in *source_folder* into patches.

    Each entry whose name ends with ``.jpg`` is handed to ``split_image``
    together with *destination_folder* and *patch_size*; entries with any
    other suffix are ignored.  The destination folder is created first
    when it is missing.

    Args:
        source_folder: Folder that is scanned (not recursively) for images.
        destination_folder: Folder that receives the generated patches.
        patch_size: Patch size forwarded unchanged to ``split_image``.

    Raises:
        AssertionError: If *source_folder* does not exist.
    """
    assert os.path.exists(source_folder), f"The source folder '{source_folder}' does not exist."

    if not os.path.exists(destination_folder):
        os.makedirs(destination_folder)

    # Collect the full paths of the JPG entries, then split them one by one.
    jpg_paths = (
        os.path.join(source_folder, entry)
        for entry in os.listdir(source_folder)
        if entry.endswith(".jpg")
    )
    for jpg_path in jpg_paths:
        split_image(jpg_path, destination_folder, patch_size)

此处的 image_path, destination_folder, patch_size都是从split_images_in_folder中传入的。

# Function to split a single image
def split_image(image_path: str, destination_folder: str, patch_size: Tuple[int, int]):
    """Tile one image into non-overlapping patches and save each as a JPG.

    The image is loaded as RGB and cut on a regular grid in which every
    patch is ``patch_size[0]`` pixels high and ``patch_size[1]`` pixels
    wide.  Only whole patches are written: leftover rows and columns at
    the bottom and right edges are dropped.  Patches are stored in
    *destination_folder* as ``<image stem>_patch_<row>_<col>.jpg``.

    Args:
        image_path: Path of the image to split.
        destination_folder: Output folder; created when missing.
        patch_size: ``(height, width)`` of one patch in pixels.

    Raises:
        AssertionError: If *image_path* does not exist.
    """
    assert os.path.exists(image_path), f"The image file '{image_path}' does not exist."

    if not os.path.exists(destination_folder):
        os.makedirs(destination_folder)

    # Decode once into an (H, W, 3) array so patches are plain slices.
    pixels = np.array(Image.open(image_path).convert("RGB"))
    height, width, _ = pixels.shape

    # Floor division: partial patches at the borders are not produced.
    rows = height // patch_size[0]
    cols = width // patch_size[1]
    patch_h, patch_w = patch_size[0], patch_size[1]

    stem = os.path.splitext(os.path.basename(image_path))[0]

    for row in range(rows):
        top = row * patch_h
        for col in range(cols):
            left = col * patch_w
            tile = pixels[top:top + patch_h, left:left + patch_w, :]
            out_path = os.path.join(destination_folder, f"{stem}_patch_{row}_{col}.jpg")
            Image.fromarray(tile).save(out_path)

调用:

# Example run: cut every JPG under source_folder into 512 x 512 patches.
source_folder = '/home/~/seg'
destination_folder = '/home/~/destination_folder_patches'
patch_size = (512, 512)

split_images_in_folder(
    source_folder=source_folder,
    destination_folder=destination_folder,
    patch_size=patch_size,
)

3.2 分类任务的训练数据集准备

sge_clas_binary 用来将 segmentation 数据库中的图片按照二分类任务分别保存到两个新的文件夹中:mask 中包含指定类别的图片放入一个文件夹,其余图片放入另一个文件夹。可以按照需要选择实例的种类(单个类别编号或类别编号列表)。

"""this is for a binary classification task"""
from typing import Union, List
import os
import shutil
from PIL import Image
import numpy as np


def _is_same_or_parent(candidate_parent: str, path: str) -> bool:
    """Return True when *candidate_parent* is *path* itself or one of its ancestors."""
    parent_real = os.path.realpath(candidate_parent)
    path_real = os.path.realpath(path)
    try:
        return os.path.commonpath([parent_real, path_real]) == parent_real
    except ValueError:
        # commonpath raises when the two paths are on different drives,
        # in which case one cannot contain the other.
        return False


def sge_clas_binary(source_folder: str,destination_folder:str, no_instance_folder: str, instance_labels_to_check:Union[int,List[int]]):
    """Sort the images of a segmentation dataset into two folders for binary classification.

    For every ``<name>.jpg`` directly inside *source_folder* the mask
    ``<name>_m.png`` in the same folder is loaded.  When any value of the
    RGB-converted mask array equals one of the requested labels, the image
    is copied to *destination_folder*; otherwise it is copied to
    *no_instance_folder*.  Only the images are copied, never the masks.

    Both output folders are deleted and recreated on every call, so
    previous results are discarded.  All arguments are validated before
    anything is deleted.

    Args:
        source_folder: Folder holding the images and their masks.
        destination_folder: Receives images whose mask contains a label.
        no_instance_folder: Receives all other images.
        instance_labels_to_check: One label or a list of labels to look for.

    Raises:
        AssertionError: If *source_folder* does not exist.
        ValueError: If *instance_labels_to_check* is neither an int nor a
            list, or if an output folder is the source folder or one of
            its parents (emptying it would delete the dataset).
        FileNotFoundError: If the mask of an image is missing.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``
    # while callers still see the same exception type.
    if not os.path.exists(source_folder):
        raise AssertionError(f"The dataset file '{source_folder}' does not exist.")

    # Validate the labels once, up front, before touching any folder.
    if isinstance(instance_labels_to_check, int):
        labels = [instance_labels_to_check]
    elif isinstance(instance_labels_to_check, list):
        labels = instance_labels_to_check
    else:
        raise ValueError("Invalid input format. Please provide either a single whole number or a list of whole numbers.")

    # The output folders are wiped below; refuse to wipe the dataset itself.
    for output_folder in (destination_folder, no_instance_folder):
        if _is_same_or_parent(output_folder, source_folder):
            raise ValueError(
                f"The output folder '{output_folder}' is the source folder or contains it; "
                "it would be deleted. Please choose a different output folder."
            )

    # Start from empty output folders: remove both first, then recreate both
    # (kept as two passes so nested output folders behave as before).
    for output_folder in (destination_folder, no_instance_folder):
        if os.path.exists(output_folder):
            shutil.rmtree(output_folder)
    for output_folder in (destination_folder, no_instance_folder):
        os.makedirs(output_folder, exist_ok=True)

    for filename in os.listdir(source_folder):
        if not filename.endswith(".jpg"):  # Images are expected in JPG format
            continue

        image_path = os.path.join(source_folder, filename)
        # Naming convention: <name>.jpg is paired with <name>_m.png.
        mask_path = os.path.join(source_folder, os.path.splitext(filename)[0] + "_m.png")
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"The mask file '{mask_path}' for image '{image_path}' does not exist.")

        # Only the mask has to be decoded; the image itself is copied as a file.
        # NOTE(review): convert("RGB") is kept from the original code. If the
        # masks are palette-mode PNGs this yields palette colours rather than
        # class indices -- confirm the mask format.
        with Image.open(mask_path) as mask_file:
            mask_np = np.array(mask_file.convert("RGB"))

        contains_instance = any(label in mask_np for label in labels)

        target_folder = destination_folder if contains_instance else no_instance_folder
        shutil.copyfile(image_path, os.path.join(target_folder, filename))

调用方式

# Images whose mask contains label 1 or label 4 go to y_folder, the rest to no_folder.
sge_clas_binary(
    source_folder='/home/~/seg',
    destination_folder='/home/~/y_folder',
    no_instance_folder='/home/~/no_folder',
    instance_labels_to_check=[1, 4],
)
# or 
# Only label 4 counts as a positive. Both output folders are emptied at the
# start of every call, so this run replaces the result of the previous one.
sge_clas_binary(
    source_folder='/home/~/seg',
    destination_folder='/home/~/y_folder',
    no_instance_folder='/home/~/no_folder',
    instance_labels_to_check=[4],
)

3.3 文件夹合并

import os
import shutil

def combine_and_copy_folders(folder_paths, destination_folder):
    """Copy the files of several folders into one flat destination folder.

    Only regular files located directly inside each source folder are
    copied.  Sub-directories are skipped, because ``shutil.copyfile``
    cannot copy a directory and would abort the merge half-way.  Folders
    are processed in the given order, so when two of them hold a file with
    the same name the later one overwrites the earlier copy.

    Args:
        folder_paths: Iterable of source folder paths.  A single path
            (``str`` or ``os.PathLike``) is accepted as well and treated
            as a one-element list.
        destination_folder: Folder that receives the copies; created,
            including parents, when it does not exist.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(folder_paths, (str, os.PathLike)):
        folder_paths = [folder_paths]

    # Create the destination folder if it doesn't exist
    os.makedirs(destination_folder, exist_ok=True)

    for folder_path in folder_paths:
        for filename in os.listdir(folder_path):
            source_file_path = os.path.join(folder_path, filename)
            if not os.path.isfile(source_file_path):
                continue  # skip sub-directories and other non-file entries

            destination_file_path = os.path.join(destination_folder, filename)
            shutil.copyfile(source_file_path, destination_file_path)

# Example usage: gather the files of both folders in a single folder.
destination_folder = '/home/~cla_data/whole'  # Replace with your desired destination folder path
folder_paths = ['/home/~/image', '/home/~/mask']  # Replace with your actual folder paths

combine_and_copy_folders(
    folder_paths=folder_paths,
    destination_folder=destination_folder,
)

3.4 获取文件夹下数据数量

import os

def count_files_in_folder(folder_path):
    """Count the regular files that sit directly inside *folder_path*.

    Sub-directories are neither counted nor descended into.  When
    *folder_path* is not a directory, a message is printed and ``None``
    is returned instead of a count.
    """
    if os.path.isdir(folder_path):
        # One tick per entry that is a regular file.
        return sum(
            1
            for entry in os.listdir(folder_path)
            if os.path.isfile(os.path.join(folder_path, entry))
        )

    print(f"{folder_path} is not a valid directory.")
    return None

# Example usage: report how many files one folder holds.
folder_path = '/path/to/your/folder'  # Replace with your actual folder path
num_files = count_files_in_folder(folder_path=folder_path)

# None means folder_path was not a directory; the function already printed why.
if num_files is not None:
    print(f"The number of files in {folder_path} is: {num_files}")

3.5 train/val/test dataset划分

其实关于数据集的划分有很多讲究,我还没有完全弄清楚。发现这篇博文不错:《训练集(train set) 验证集(validation set) 测试集(test set)问题综述》,后续可以深入学习。
我的代码中:
root_dir 你的目标文件夹
data_dir 你的数据集所在位置
0.8、0.1、0.1 分别是 train/val/test 的划分比例,目前直接写在代码里;等有时间我会把它们改成函数参数,这样就可以按需分配
加入 “random.shuffle(files)”,这样可以随机划分
我还需要学习“10折交叉验证法”,今天会去学习的。

import os
import shutil
import random

def split_dataset(data_dir, root_dir, train_ratio=0.8, val_ratio=0.1, seed=None):
    """Move a class-per-folder dataset into train/val/test sub-folders.

    *data_dir* is expected to hold one sub-folder per class.  For every
    class the entries of its folder are shuffled and then MOVED (not
    copied) to ``<root_dir>/train/<class>``, ``<root_dir>/val/<class>``
    and ``<root_dir>/test/<class>``.  The train and validation counts are
    ``int(n * ratio)``; the test split receives whatever remains.  Entries
    of *data_dir* that are not directories are skipped.

    Args:
        data_dir: Folder containing one sub-folder per class.
        root_dir: Folder under which ``train``, ``val`` and ``test`` are created.
        train_ratio: Fraction of each class assigned to the training split.
        val_ratio: Fraction of each class assigned to the validation split.
        seed: Optional seed for a reproducible split.  With ``None`` the
            module-level ``random.shuffle`` is used.

    Returns:
        The ``(train_dir, val_dir, test_dir)`` paths under *root_dir*.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more than 1.
    """
    # Small tolerance so float sums such as 0.7 + 0.3 are not rejected.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio and val_ratio must be non-negative and sum to at most 1.")

    # Create directories for train, validation, and test sets
    split_dirs = {name: os.path.join(root_dir, name) for name in ("train", "val", "test")}
    for split_dir in split_dirs.values():
        os.makedirs(split_dir, exist_ok=True)

    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle

    class_names = os.listdir(data_dir)
    if seed is not None:
        # listdir order is arbitrary; sort so a seed gives the same split every run.
        class_names.sort()

    for class_name in class_names:
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # stray files next to the class folders are not classes

        # Create directories for each class in train, validation, and test sets
        for split_dir in split_dirs.values():
            os.makedirs(os.path.join(split_dir, class_name), exist_ok=True)

        files = os.listdir(class_dir)
        if seed is not None:
            files.sort()
        shuffle(files)

        num_train = int(len(files) * train_ratio)
        num_val = int(len(files) * val_ratio)
        assignments = {
            "train": files[:num_train],
            "val": files[num_train:num_train + num_val],
            "test": files[num_train + num_val:],  # remainder
        }

        # Move files to the corresponding directories
        for split_name, split_files in assignments.items():
            for file in split_files:
                src = os.path.join(class_dir, file)
                dst = os.path.join(split_dirs[split_name], class_name, file)
                shutil.move(src, dst)

    return split_dirs["train"], split_dirs["val"], split_dirs["test"]


# Define the root directory of your dataset
root_dir =  "/home/dxj/code/TORCH/mobilenetv2/multiple_class/waste/dataset" 
data_dir = "/home/dxj/code/TORCH/mobilenetv2/multiple_class/waste/train"

# Split with the default 0.8 / 0.1 / 0.1 ratios and an unseeded random shuffle
train_dir, val_dir, test_dir = split_dataset(data_dir, root_dir)

print("Dataset split completed successfully!")

如果要处理多分类任务,等我整理后再编辑。

总结

  • 11
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值