How to Fine-Tune the SAM Model: A Complete Guide from Environment Setup to Training Implementation

Supplement 1:

(January 2, 2025)

Many readers have asked what format the annotated data should be in, so Supplement 1 was added to answer that.

Run the demo provided at the end of this post to generate an example of the annotation format:

python sam-data-setup.py

Under the dataset directory, place an images folder, a masks folder, and an annotations.txt file.

The images folder holds the original images; here they are randomly generated. Put your own data in this folder.


The masks folder holds the corresponding mask images, with the file suffix changed accordingly. Put the label masks that match your own data in this folder.


The annotations.txt file holds the bounding-box coordinates for each image.
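
Each line written by the data-setup script has the image filename followed by the box corners, i.e. filename,x1,y1,x2,y2; for example (the coordinate values are illustrative):

sample_0.jpg,152,201,298,347
sample_1.jpg,87,130,215,260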


Introduction

Segment Anything Model (SAM) is a powerful image segmentation model from Meta AI. Although the pretrained model performs well, fine-tuning may be needed in specific domains (such as medical imaging or industrial inspection) to get better performance. This article walks through how to fine-tune SAM, covering environment setup, data preparation, and the training implementation.

Table of Contents

  1. Environment Setup
  2. Project Structure
  3. Data Preparation
  4. Model Fine-Tuning
  5. Training Process
  6. Notes and Optimization Suggestions

1. Environment Setup

First, set up the correct Python environment and dependencies. Using a virtual environment to manage them is recommended:

# Create and activate a virtual environment
python -m venv sam_env
# Windows:
.\sam_env\Scripts\activate
# Linux/Mac:
source sam_env/bin/activate

# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install opencv-python
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install numpy matplotlib

# Download the pretrained model
# Windows PowerShell:
Invoke-WebRequest -Uri "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" -OutFile "sam_vit_b_01ec64.pth"
# Linux/Mac:
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
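
To confirm the installation worked, a minimal sketch like the one below loads the downloaded checkpoint and checks whether a GPU is visible (it only assumes the checkpoint file downloaded above is in the current directory):

import torch
from segment_anything import sam_model_registry

# Check GPU visibility and that the ViT-B checkpoint loads correctly
print("CUDA available:", torch.cuda.is_available())
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
print("Parameters (M):", sum(p.numel() for p in sam.parameters()) / 1e6)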

2. Project Structure

The recommended project layout is as follows:

project_root/
├── stamps/
│   ├── images/         # Training images
│   ├── masks/          # Segmentation masks
│   └── annotations.txt # Bounding-box annotations
├── checkpoints/        # Model checkpoints
├── setup_sam_data.py   # Data preparation script
└── sam_finetune.py     # Training script

3. Data Preparation

To train the model, we need to prepare the following data:

  • Training images
  • Segmentation masks
  • Bounding-box annotations

Below is the implementation of the data preparation script:

import os
import numpy as np
import cv2
from pathlib import Path

def create_project_structure():
    """创建项目所需的目录结构"""
    directories = [
        './stamps/images',
        './stamps/masks',
        './checkpoints'
    ]
    
    for dir_path in directories:
        Path(dir_path).mkdir(parents=True, exist_ok=True)
    
    return directories

def create_sample_data(num_samples=5):
    """创建示例训练数据"""
    annotations = []
    
    for i in range(num_samples):
        # Create a sample image
        image = np.ones((500, 500, 3), dtype=np.uint8) * 255
        center_x = np.random.randint(150, 350)
        center_y = np.random.randint(150, 350)
        radius = np.random.randint(50, 100)
        
        # Draw the object
        cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
        
        # Create the corresponding mask
        mask = np.zeros((500, 500), dtype=np.uint8)
        cv2.circle(mask, (center_x, center_y), radius, 255, -1)
        
        # Save the image and mask
        cv2.imwrite(f'./stamps/images/sample_{i}.jpg', image)
        cv2.imwrite(f'./stamps/masks/sample_{i}_mask.png', mask)
        
        # Compute the bounding box
        x1 = max(0, center_x - radius)
        y1 = max(0, center_y - radius)
        x2 = min(500, center_x + radius)
        y2 = min(500, center_y + radius)
        
        annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
    
    # Save the annotation file
    with open('./stamps/annotations.txt', 'w') as f:
        f.writelines(annotations)

4. Model Fine-Tuning

4.1 Dataset Class Implementation

First, implement the custom dataset class:

# Imports needed by this snippet (they also appear in the full script under Quick Deployment)
import os

import cv2
import numpy as np
import torch
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import Dataset


class StampDataset(Dataset):
    def __init__(self, image_dir, mask_dir, bbox_file):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = ResizeLongestSide(1024)
        
        # Load annotations
        self.annotations = []
        with open(bbox_file, 'r') as f:
            for line in f:
                img_name, x1, y1, x2, y2 = line.strip().split(',')
                self.annotations.append({
                    'image': img_name,
                    'bbox': [float(x1), float(y1), float(x2), float(y2)]
                })
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        
        # Load and preprocess the image and mask
        image = cv2.imread(os.path.join(self.image_dir, ann['image']))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(os.path.join(self.mask_dir, 
                         ann['image'].replace('.jpg', '_mask.png')), 
                         cv2.IMREAD_GRAYSCALE)
        mask = mask.astype(np.float32) / 255.0
        
        # Image processing: resize and convert to a tensor
        original_size = image.shape[:2]
        input_image = self.transform.apply_image(image)
        input_image = input_image.astype(np.float32) / 255.0
        input_image = torch.from_numpy(input_image).permute(2, 0, 1)
        
        # Normalize with ImageNet statistics
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        
        # Process the bounding box and mask
        bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
        bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
        mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
        
        return {
            'image': input_image.float(),
            'original_size': original_size,
            'bbox': bbox_torch,
            'mask': mask_torch
        }

4.2 Training Function Implementation

The core of the training function:

# Imports needed by this snippet (they also appear in the full script under Quick Deployment)
import torch
from segment_anything import sam_model_registry
from torch.utils.data import DataLoader


def train_sam(
    model_type='vit_b',
    checkpoint_path='sam_vit_b_01ec64.pth',
    num_epochs=10,
    batch_size=1,
    learning_rate=1e-5
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize the model
    sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam_model.to(device)
    
    # Prepare the data, optimizer, and loss
    dataset = StampDataset(image_dir='./stamps/images',
                          mask_dir='./stamps/masks',
                          bbox_file='./stamps/annotations.txt')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
    loss_fn = torch.nn.MSELoss()
    
    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            # Move the batch to the device
            input_image = batch['image'].to(device)
            original_size = batch['original_size']
            bbox = batch['bbox'].to(device)
            gt_mask = batch['mask'].to(device)
            
            # Forward pass (image and prompt encoders are frozen, no gradients)
            with torch.no_grad():
                image_embedding = sam_model.image_encoder(input_image)
                sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                    points=None,
                    boxes=bbox,
                    masks=None,
                )
            
            # Generate mask predictions with the trainable mask decoder
            mask_predictions, _ = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            
            # Post-process: upscale masks back to the original size
            upscaled_masks = sam_model.postprocess_masks(
                mask_predictions,
                input_size=input_image.shape[-2:],
                original_size=original_size[0]
            ).to(device)
            
            binary_masks = torch.sigmoid(upscaled_masks)
            
            # Compute the loss and optimize
            loss = loss_fn(binary_masks, gt_mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
        
        # Epoch statistics
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
        
        # Save a checkpoint
        if (epoch + 1) % 5 == 0:
            checkpoint_file = f'./checkpoints/sam_finetuned_epoch_{epoch+1}.pth'
            torch.save(sam_model.state_dict(), checkpoint_file)

5. Training Process

The complete training procedure is as follows:

  1. Prepare the environment and data:
python setup_sam_data.py


  2. Start training:
python sam_finetune.py


6. Notes and Optimization Suggestions

  1. Data preprocessing:

    • Make sure the image data type is correct (float32)
    • Apply appropriate normalization
    • Keep image sizes consistent
  2. Training optimization (see the sketch after this list):

    • Adjust batch_size according to GPU memory
    • Tune the learning rate appropriately
    • Consider using a learning-rate scheduler
    • Add evaluation on a validation set
    • Implement early stopping
  3. Possible improvements:

    • Add data augmentation
    • Try different loss functions
    • Implement multi-GPU training
    • Add training-process visualization
    • Implement model validation and testing
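
As a rough illustration of several of these suggestions, the sketch below swaps the plain MSE loss for a soft Dice loss, attaches a cosine learning-rate scheduler, and stops early when validation loss stops improving. It assumes the optimizer, num_epochs, and a validation helper validate() from your own training code; none of these additions are part of the original script.

import torch

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss on probability masks (both tensors in [0, 1])."""
    pred = pred.flatten(1)
    target = target.flatten(1)
    intersection = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    return 1 - ((2 * intersection + eps) / (union + eps)).mean()

# Cosine learning-rate schedule over the whole run
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

best_val, patience, bad_epochs = float('inf'), 3, 0
for epoch in range(num_epochs):
    # ... run the training loop from section 4.2, using dice_loss(binary_masks, gt_mask) ...
    scheduler.step()
    val_loss = validate()  # hypothetical helper that evaluates on a held-out set
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break  # early stopping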

7. Model Prediction and Visualization

After fine-tuning, we need a convenient way to run the model for prediction and visualize the results. Below is the full implementation.

7.1 Predictor Class Implementation

First, wrap a predictor class that handles model loading, image preprocessing, and prediction:

class SAMPredictor:
    def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
        self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
        self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam_model.to(self.device)
        self.transform = ResizeLongestSide(1024)

This class provides a simple interface for loading the model and running predictions. Its main responsibilities are:

  • Model loading and device configuration
  • Image preprocessing
  • Mask prediction
  • Post-processing

7.2 Visualization Function

To better present the prediction results, we implement a visualization function:

def visualize_prediction(image, mask, bbox, confidence, save_path=None):
    plt.figure(figsize=(15, 5))
    # Show the original image, the predicted mask, and an overlay
    ...

This function displays, side by side:

  • The original image (with the bounding box)
  • The predicted segmentation mask
  • An overlay of mask and image

7.3 Usage Example

A complete example of how to use these tools:

# Initialize the predictor
predictor = SAMPredictor("./checkpoints/sam_finetuned_final.pth")

# Read a test image
image = cv2.imread("test_image.jpg")
bbox = [x1, y1, x2, y2]  # bounding-box coordinates

# Predict
mask, confidence = predictor.predict(image, bbox)

# Visualize
visualize_prediction(image, mask, bbox, confidence, "result.png")


7.4 Notes

Keep the following in mind when using the predictor:

  1. Input image handling:

    • Make sure the image format is correct (RGB)
    • Keep image sizes consistent
    • Use the correct data type and value range
  2. Bounding-box format (a sanity-check sketch follows this list):

    • Use the [x1, y1, x2, y2] format
    • Make sure the coordinates lie within the image
    • Coordinate values are floats
  3. Performance optimization:

    • Batch prediction
    • GPU memory management
    • Result caching
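
A minimal sketch of such a bounding-box sanity check; the helper name clamp_bbox is an assumption and not part of the predictor class:

def clamp_bbox(bbox, image_shape):
    """Clamp an [x1, y1, x2, y2] box to the image bounds and return floats."""
    h, w = image_shape[:2]
    x1, y1, x2, y2 = [float(v) for v in bbox]
    x1, x2 = max(0.0, min(x1, x2)), min(float(w), max(x1, x2))
    y1, y2 = max(0.0, min(y1, y2)), min(float(h), max(y1, y2))
    return [x1, y1, x2, y2]

# Example: clean up a box before calling predictor.predict(image, bbox)
bbox = clamp_bbox(bbox, image.shape)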

7.5 Possible Improvements

  1. Batch processing:
def predict_batch(self, images, bboxes):
    results = []
    for image, bbox in zip(images, bboxes):
        mask, conf = self.predict(image, bbox)
        results.append((mask, conf))
    return results
  2. Multiple bounding-box support:
def predict_multiple_boxes(self, image, bboxes):
    masks = []
    for bbox in bboxes:
        mask, _ = self.predict(image, bbox)
        masks.append(mask)
    return np.stack(masks)
  3. Interactive visualization:
def interactive_visualization(image, predictor):
    def onclick(event):
        if event.button == 1:  # left mouse button
            bbox = [event.xdata-50, event.ydata-50, 
                   event.xdata+50, event.ydata+50]
            mask, _ = predictor.predict(image, bbox)
            visualize_prediction(image, mask, bbox)
    
    fig, ax = plt.subplots()
    ax.imshow(image)
    fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()

These tools and examples should help you better understand and use the fine-tuned SAM model. You can optimize and extend them further for your specific needs.

Conclusion

Through the steps above, we have implemented a fine-tuning pipeline for SAM. This implementation can serve as a baseline to be optimized and improved for your specific needs. In practice, you may need to adjust the data preprocessing, the loss function, and the training strategy for your particular task.

A few recommendations for using it (a small augmentation sketch follows this list):

  1. Ensure the quality of the training data
  2. Set the training parameters sensibly
  3. Save checkpoints regularly
  4. Monitor the training process
  5. Use data augmentation where appropriate
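
For point 5, a minimal augmentation sketch that could be applied to the image, mask, and box inside StampDataset.__getitem__ (the flip/jitter probabilities are arbitrary assumptions, and only OpenCV and NumPy are used):

import cv2
import numpy as np

def augment(image, mask, bbox):
    """Random horizontal flip plus mild brightness jitter; bbox is [x1, y1, x2, y2]."""
    h, w = image.shape[:2]
    if np.random.rand() < 0.5:
        image = cv2.flip(image, 1)        # flip left-right
        mask = cv2.flip(mask, 1)
        x1, y1, x2, y2 = bbox
        bbox = [w - x2, y1, w - x1, y2]   # mirror the box horizontally
    if np.random.rand() < 0.5:
        factor = np.random.uniform(0.8, 1.2)
        image = np.clip(image.astype(np.float32) * factor, 0, 255).astype(np.uint8)
    return image, mask, bbox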

I hope this tutorial helps with your project! Questions and discussion are welcome.


Quick Deployment:

Download the following three scripts, set up the runtime environment, and run them in order:

# sam-data-setup.py
import os
import numpy as np
import cv2
from pathlib import Path

def create_project_structure():
    """创建项目所需的目录结构"""
    # 创建主目录
    directories = [
        './stamps/images',
        './stamps/masks',
        './checkpoints'
    ]
    
    for dir_path in directories:
        Path(dir_path).mkdir(parents=True, exist_ok=True)
    
    return directories

def create_sample_data(num_samples=5):
    """创建示例训练数据"""
    # 创建示例图像和掩码
    annotations = []
    
    for i in range(num_samples):
        # Create a sample image (500x500)
        image = np.ones((500, 500, 3), dtype=np.uint8) * 255
        # Add a sample stamp (a circle at a random position)
        center_x = np.random.randint(150, 350)
        center_y = np.random.randint(150, 350)
        radius = np.random.randint(50, 100)
        
        # Draw the stamp
        cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
        
        # Create the corresponding mask
        mask = np.zeros((500, 500), dtype=np.uint8)
        cv2.circle(mask, (center_x, center_y), radius, 255, -1)
        
        # Save the image and mask
        image_path = f'./stamps/images/sample_{i}.jpg'
        mask_path = f'./stamps/masks/sample_{i}_mask.png'
        
        cv2.imwrite(image_path, image)
        cv2.imwrite(mask_path, mask)
        
        # Compute the bounding box
        x1 = max(0, center_x - radius)
        y1 = max(0, center_y - radius)
        x2 = min(500, center_x + radius)
        y2 = min(500, center_y + radius)
        
        # Append to the annotation list
        annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
    
    # Save the annotation file
    with open('./stamps/annotations.txt', 'w') as f:
        f.writelines(annotations)

def main():
    print("开始创建项目结构...")
    directories = create_project_structure()
    for dir_path in directories:
        print(f"创建目录: {dir_path}")
    
    print("\n创建示例训练数据...")
    create_sample_data()
    print("示例数据创建完成!")
    
    print("\n项目结构:")
    for root, dirs, files in os.walk('./stamps'):
        level = root.replace('./stamps', '').count(os.sep)
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        sub_indent = ' ' * 4 * (level + 1)
        for f in files:
            print(f"{sub_indent}{f}")

if __name__ == '__main__':
    main()
# sam_finetune_decoder.py
import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import Dataset, DataLoader
import cv2
import os

class StampDataset(Dataset):
    def __init__(self, image_dir, mask_dir, bbox_file, target_size=(1024, 1024)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.target_size = target_size
        self.transform = ResizeLongestSide(1024)  # SAM default size
        
        # Load bbox annotations
        self.annotations = []
        with open(bbox_file, 'r') as f:
            for line in f:
                img_name, x1, y1, x2, y2 = line.strip().split(',')
                self.annotations.append({
                    'image': img_name,
                    'bbox': [float(x1), float(y1), float(x2), float(y2)]
                })
    
    def resize_with_bbox(self, image, mask, bbox):
        """调整图像、掩码和边界框的大小"""
        h, w = image.shape[:2]
        target_h, target_w = self.target_size
        
        # Compute the scale factors
        scale_x = target_w / w
        scale_y = target_h / h
        
        # Resize the image
        resized_image = cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
        
        # Resize the mask
        if mask is not None:
            resized_mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
        else:
            resized_mask = None
        
        # Resize the bounding box
        resized_bbox = [
            bbox[0] * scale_x,  # x1
            bbox[1] * scale_y,  # y1
            bbox[2] * scale_x,  # x2
            bbox[3] * scale_y   # y2
        ]
        
        return resized_image, resized_mask, resized_bbox
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        
        # Load image
        image = cv2.imread(os.path.join(self.image_dir, ann['image']))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Load mask
        mask_name = ann['image'].replace('.jpg', '_mask.png')
        mask = cv2.imread(os.path.join(self.mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
        mask = mask.astype(np.float32) / 255.0
        
        # First resize everything to a uniform size
        image, mask, bbox = self.resize_with_bbox(image, mask, ann['bbox'])
        
        # Prepare the image
        original_size = self.target_size
        input_image = self.transform.apply_image(image)
        
        # Convert to float32 and normalize to 0-1 range
        input_image = input_image.astype(np.float32) / 255.0
        
        # Convert to tensor and normalize according to ImageNet stats
        input_image = torch.from_numpy(input_image).permute(2, 0, 1).contiguous()
        
        # Apply ImageNet normalization
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        
        # Prepare bbox
        bbox = self.transform.apply_boxes(np.array([bbox]), original_size)[0]
        bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
        
        # Prepare mask
        mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
        
        return {
            'image': input_image.float(),
            'original_size': original_size,
            'bbox': bbox_torch,
            'mask': mask_torch
        }

def train_sam(
    model_type='vit_b',
    checkpoint_path='sam_vit_b_01ec64.pth',
    image_dir='./stamps/images',
    mask_dir='./stamps/masks',
    bbox_file='./stamps/annotations.txt',
    output_dir='./checkpoints',
    num_epochs=10,
    batch_size=1,
    learning_rate=1e-5
):
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize model
    sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam_model.to(device)
    
    # Prepare dataset
    dataset = StampDataset(image_dir, mask_dir, bbox_file)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Setup optimizer
    optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
    
    # Loss function
    loss_fn = torch.nn.MSELoss()
    
    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            # Move inputs to device
            input_image = batch['image'].to(device)
            original_size = batch['original_size']
            bbox = batch['bbox'].to(device)
            gt_mask = batch['mask'].to(device)
            
            # Print shapes and types for debugging
            if batch_idx == 0 and epoch == 0:
                print(f"Input image shape: {input_image.shape}")
                print(f"Input image type: {input_image.dtype}")
                print(f"Input image range: [{input_image.min():.2f}, {input_image.max():.2f}]")
            
            # Get image embedding (without gradient)
            with torch.no_grad():
                image_embedding = sam_model.image_encoder(input_image)
                
                # Get prompt embeddings
                sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                    points=None,
                    boxes=bbox,
                    masks=None,
                )
            
            # Generate mask prediction
            mask_predictions, iou_predictions = sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            
            # Upscale masks to original size
            upscaled_masks = sam_model.postprocess_masks(
                mask_predictions,
                input_size=input_image.shape[-2:],
                original_size=original_size[0]
            ).to(device)
            
            # Convert to binary mask
            binary_masks = torch.sigmoid(upscaled_masks)
            
            # Calculate loss
            loss = loss_fn(binary_masks, gt_mask)
            
            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
        
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
        
        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            checkpoint_file = os.path.join(output_dir, f'sam_finetuned_epoch_{epoch+1}.pth')
            torch.save(sam_model.state_dict(), checkpoint_file)
            print(f'Checkpoint saved: {checkpoint_file}')
    
    # Save final model
    final_checkpoint = os.path.join(output_dir, 'sam_finetuned_final.pth')
    torch.save(sam_model.state_dict(), final_checkpoint)
    print(f'Final model saved to {final_checkpoint}')

if __name__ == '__main__':
    # Create output directory if it doesn't exist
    os.makedirs('./checkpoints', exist_ok=True)
    
    # Start training
    train_sam()
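
# sam_predict.py (assumed filename): prediction and visualization script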

import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
import cv2
from pathlib import Path

class SAMPredictor:
    def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
        """
        初始化SAM预测器
        Args:
            checkpoint_path: 模型权重路径
            model_type: 模型类型 ("vit_h", "vit_l", "vit_b")
            device: 使用设备 ("cuda" or "cpu")
        """
        self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
        print(f"Using device: {self.device}")
        
        # Load the model
        self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
        self.sam_model.to(self.device)
        
        # Create the image transform
        self.transform = ResizeLongestSide(1024)
    
    def resize_bbox(self, bbox, original_size, target_size=(1024, 1024)):
        """
        调整边界框坐标以匹配调整大小后的图像
        Args:
            bbox: 原始边界框坐标 [x1, y1, x2, y2]
            original_size: 原始图像尺寸 (height, width)
            target_size: 目标图像尺寸 (height, width)
        Returns:
            resized_bbox: 调整后的边界框坐标
        """
        orig_h, orig_w = original_size
        target_h, target_w = target_size
        
        # Compute the scale factors
        scale_x = target_w / orig_w
        scale_y = target_h / orig_h
        
        # Rescale the bounding-box coordinates
        x1, y1, x2, y2 = bbox
        resized_bbox = [
            x1 * scale_x,
            y1 * scale_y,
            x2 * scale_x,
            y2 * scale_y
        ]
        
        return resized_bbox
        
    def preprocess_image(self, image):
        """预处理输入图像"""
        # 保存原始尺寸
        original_size = image.shape[:2]
        
        # Make sure the image is in RGB format
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        elif len(image.shape) == 3 and image.shape[2] == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
        # Resize the image
        image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
        
        # Convert to float32 and scale to [0, 1]
        input_image = image.astype(np.float32) / 255.0
        
        # Convert to a tensor and add a batch dimension
        input_image = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0)
        
        # Normalize with ImageNet statistics
        mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
        input_image = (input_image - mean) / std
        
        return input_image.to(self.device), original_size, image
        
    def predict(self, image, bbox):
        """
        预测单个图像的分割掩码
        Args:
            image: numpy array 格式的图像
            bbox: [x1, y1, x2, y2] 格式的边界框
        Returns:
            binary_mask: 二值化的分割掩码
            confidence: 预测的置信度
        """
        # Preprocess the image
        input_image, original_size, resized_image = self.preprocess_image(image)
        
        # Rescale the bounding box
        resized_bbox = self.resize_bbox(bbox, original_size)
        print(resized_bbox, image.shape, resized_image.shape)
        
        # Prepare the bounding-box tensor
        bbox_torch = torch.tensor(resized_bbox, dtype=torch.float, device=self.device).unsqueeze(0)
        
        # Compute the image embedding
        with torch.no_grad():
            image_embedding = self.sam_model.image_encoder(input_image)
            
            # Compute the prompt embeddings
            sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
                points=None,
                boxes=bbox_torch,
                masks=None,
            )
            
            # Generate the mask prediction
            mask_predictions, iou_predictions = self.sam_model.mask_decoder(
                image_embeddings=image_embedding,
                image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            
            # Post-process the masks
            upscaled_masks = self.sam_model.postprocess_masks(
                mask_predictions,
                input_size=input_image.shape[-2:],
                original_size=original_size
            ).to(self.device)
            
            # Threshold to a binary mask
            binary_mask = torch.sigmoid(upscaled_masks) > 0.5
            
        return binary_mask[0, 0].cpu().numpy(), iou_predictions[0, 0].item()

def visualize_prediction(image, mask, bbox, confidence, save_path=None):
    """
    可视化预测结果
    Args:
        image: 原始图像
        mask: 预测的掩码
        bbox: 边界框坐标
        confidence: 预测置信度
        save_path: 保存路径(可选)
    """
    # Create the figure
    plt.figure(figsize=(15, 5))
    
    # Show the original image
    plt.subplot(131)
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title('Original Image')
    # Draw the bounding box
    x1, y1, x2, y2 = map(int, bbox)
    plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'r-', linewidth=2)
    plt.axis('off')
    
    # Show the predicted mask
    plt.subplot(132)
    plt.imshow(mask, cmap='gray')
    plt.title(f'Predicted Mask\nConfidence: {confidence:.2f}')
    plt.axis('off')
    
    # Show the overlay
    plt.subplot(133)
    overlay = image.copy()
    overlay[mask > 0] = overlay[mask > 0] * 0.7 + np.array([0, 255, 0], dtype=np.uint8) * 0.3
    plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    plt.title('Overlay')
    plt.axis('off')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path)
        print(f"结果已保存到: {save_path}")
    
    plt.show()

def main():
    # Configuration
    checkpoint_path = "./checkpoints/sam_finetuned_final.pth"  # use the fine-tuned model
    test_image_path = "./stamps/images/sample_0.jpg"
    output_dir = "./predictions"
    
    # Create the output directory
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    # Initialize the predictor
    predictor = SAMPredictor(checkpoint_path)
    
    # Read the test image
    image = cv2.imread(test_image_path)
    # image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
    
    # Read a bounding box (a sample box from the annotation file is used here; in practice read it from your own annotations)
    with open('./stamps/annotations.txt', 'r') as f:
        first_line = f.readline().strip()
        _, x1, y1, x2, y2 = first_line.split(',')
        bbox = [float(x1), float(y1), float(x2), float(y2)]
        print(bbox)
    
    # Run prediction
    mask, confidence = predictor.predict(image, bbox)
    
    # Visualize the result
    save_path = str(Path(output_dir) / "prediction_result.png")
    visualize_prediction(image, mask, bbox, confidence, save_path)

if __name__ == "__main__":
    main()

Running result:

The script produces a three-panel visualization (the original image with its bounding box, the predicted mask, and an overlay), saved to ./predictions/prediction_result.png.






Supplement 2:

The article above fine-tunes only the decoder; below is supplementary code for fine-tuning the encoder as well.

Notes (a partial-freeze sketch follows this list):

  • Fine-tuning the encoder requires more compute and training time
  • A larger training dataset is needed to avoid overfitting
  • Use a validation set to monitor performance and prevent model degradation
  • More training epochs may be needed before the model converges
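
If fine-tuning the full encoder is too expensive, one common compromise is to freeze most of it and unfreeze only the last few transformer blocks. The sketch below is not part of the script that follows; it assumes the model was built with sam_model_registry and that the image encoder exposes its transformer blocks as image_encoder.blocks, as in the official segment_anything implementation:

import torch

# Freeze the whole image encoder, then unfreeze only its last two blocks
for param in sam_model.image_encoder.parameters():
    param.requires_grad = False
for block in sam_model.image_encoder.blocks[-2:]:
    for param in block.parameters():
        param.requires_grad = True

# Optimize only the trainable encoder parameters plus the mask decoder
optimizer = torch.optim.Adam([
    {'params': [p for p in sam_model.image_encoder.parameters() if p.requires_grad],
     'lr': 1e-6},
    {'params': sam_model.mask_decoder.parameters(), 'lr': 1e-5},
])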

import torch
import numpy as np
from per_segment_anything import sam_model_registry, SamPredictor
from per_segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import cv2
import os
from tqdm import tqdm
import logging
import json
from datetime import datetime
from train_setimage import preprocess  # author's helper: applies ImageNet normalization and padding (not included in this post)
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class StampDataset(Dataset):
    def __init__(self, image_dir, mask_dir, bbox_file, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform if transform else ResizeLongestSide(1024)

        # Load the annotation file
        self.annotations = []
        with open(bbox_file, 'r') as f:
            for line in f:
                img_name, x1, y1, x2, y2 = line.strip().split(',')
                self.annotations.append({
                    'image': img_name,
                    'bbox': [float(x1), float(y1), float(x2), float(y2)]
                })

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        ann = self.annotations[idx]

        # Read the image
        image_path = os.path.join(self.image_dir, ann['image'])
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Read the mask
        mask_name = ann['image'].replace('.jpg', '_mask.png')
        mask_path = os.path.join(self.mask_dir, mask_name)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")
        mask = mask.astype(np.float32) / 255.0

        # Prepare the image
        original_size = image.shape[:2]
        input_image = self.transform.apply_image(image)
        input_image = input_image.astype(np.float32) / 255.0
        # Convert to a tensor (ImageNet normalization is handled by preprocess below)
        input_image = torch.from_numpy(input_image).permute(2, 0, 1)
        # Use preprocess to handle ImageNet normalization and padding
        input_image = preprocess(input_image)
        print(f"Processed image shape: {input_image.shape}")

        # Prepare the bbox
        bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
        bbox_torch = torch.tensor(bbox, dtype=torch.float)

        # Prepare the mask
        mask_torch = torch.from_numpy(mask).float()

        return {
            'image': input_image.float(),
            'original_size': original_size,
            'bbox': bbox_torch,
            'mask': mask_torch,
            'image_path': image_path  # for debugging
        }


class SAMFineTuner:
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.setup_model()
        self.setup_datasets()
        self.setup_training()

        # Create the output directory
        os.makedirs(config['output_dir'], exist_ok=True)

        # Save the config
        config_path = os.path.join(config['output_dir'], 'config.json')
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

    def setup_model(self):
        logger.info(f"Loading SAM model: {self.config['model_type']}")
        self.model = sam_model_registry[self.config['model_type']](
            checkpoint=self.config['checkpoint_path']
        )
        self.model.to(self.device)

    def setup_datasets(self):
        logger.info("Setting up datasets")
        self.train_dataset = StampDataset(
            self.config['train_image_dir'],
            self.config['train_mask_dir'],
            self.config['train_bbox_file']
        )
        # Batched loader over the training dataset
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config['batch_size'],
            shuffle=True,
            num_workers=self.config['num_workers'],
            pin_memory=True
        )
        # Optional validation set
        if self.config.get('val_bbox_file'):
            self.val_dataset = StampDataset(
                self.config['val_image_dir'],
                self.config['val_mask_dir'],
                self.config['val_bbox_file']
            )
            self.val_loader = DataLoader(
                self.val_dataset,
                batch_size=self.config['batch_size'],
                shuffle=False,
                num_workers=self.config['num_workers'],
                pin_memory=True
            )

    def setup_training(self):
        logger.info("Setting up training components")
        # Separate learning rates for the encoder and the decoder
        self.optimizer = torch.optim.Adam([
            {'params': self.model.image_encoder.parameters(),
             'lr': self.config['encoder_lr']},
            {'params': self.model.mask_decoder.parameters(),
             'lr': self.config['decoder_lr']}
        ])

        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            verbose=True
        )

        self.loss_fn = torch.nn.MSELoss()
        self.scaler = GradScaler()

        # Track the best model
        self.best_loss = float('inf')

    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0  # running loss for this epoch

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch + 1}')
        for batch_idx, batch in enumerate(pbar):
            # Move data to the GPU
            input_image = batch['image'].to(self.device)
            bbox = batch['bbox'].to(self.device)
            gt_mask = batch['mask'].to(self.device)

            self.optimizer.zero_grad()

            with autocast():
                # Forward pass
                image_embedding = self.model.image_encoder(input_image)

                with torch.no_grad():
                    sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                        points=None,
                        boxes=bbox,
                        masks=None,
                    )

                mask_predictions, _ = self.model.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=self.model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )

                upscaled_masks = self.model.postprocess_masks(
                    mask_predictions,
                    input_size=input_image.shape[-2:],
                    original_size=batch['original_size']
                ).to(self.device)

                binary_masks = torch.sigmoid(upscaled_masks)
                loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))

            # Backward pass with gradient scaling and clipping
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})

        return total_loss / len(self.train_loader)

    @torch.no_grad()
    def validate(self):
        if not hasattr(self, 'val_loader'):
            return None

        self.model.eval()
        total_loss = 0

        for batch in tqdm(self.val_loader, desc='Validating'):
            input_image = batch['image'].to(self.device)
            bbox = batch['bbox'].to(self.device)
            gt_mask = batch['mask'].to(self.device)

            with autocast():
                image_embedding = self.model.image_encoder(input_image)
                sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                    points=None,
                    boxes=bbox,
                    masks=None,
                )

                mask_predictions, _ = self.model.mask_decoder(
                    image_embeddings=image_embedding,
                    image_pe=self.model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )

                upscaled_masks = self.model.postprocess_masks(
                    mask_predictions,
                    input_size=input_image.shape[-2:],
                    original_size=batch['original_size']
                ).to(self.device)

                binary_masks = torch.sigmoid(upscaled_masks)
                loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))

                total_loss += loss.item()

        return total_loss / len(self.val_loader)


    def save_checkpoint(self, epoch, loss, is_best=False):
        # Save the full training state (for resuming training)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            'config': self.config
        }

        # Save the full checkpoint
        checkpoint_path = os.path.join(
            self.config['output_dir'],
            f'checkpoint_epoch_{epoch + 1}.pth'
        )
        torch.save(checkpoint, checkpoint_path)

        # If this is the best model, also save the weights in a SAM-compatible format
        if is_best:
            # Save the full checkpoint
            best_checkpoint_path = os.path.join(self.config['output_dir'], 'best_checkpoint.pth')
            torch.save(checkpoint, best_checkpoint_path)
            
            # Additionally save clean model weights (compatible with the original SAM format)
            best_model_path = os.path.join(self.config['output_dir'], 'best_model_sam_format.pth')
            torch.save(self.model.state_dict(), best_model_path)
            logger.info(f"Saved best model with loss: {loss:.4f}")

    def train(self):
        logger.info("Starting training")
        for epoch in range(self.config['num_epochs']):
            train_loss = self.train_epoch(epoch)
            logger.info(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}")

            val_loss = self.validate()
            if val_loss is not None:
                logger.info(f"Epoch {epoch + 1} - Val Loss: {val_loss:.4f}")
                self.scheduler.step(val_loss)
                is_best = val_loss < self.best_loss
                if is_best:
                    self.best_loss = val_loss
            else:
                is_best = False
                self.scheduler.step(train_loss)

            if (epoch + 1) % self.config['save_interval'] == 0:
                self.save_checkpoint(
                    epoch,
                    val_loss if val_loss is not None else train_loss,
                    is_best
                )
        
        # After training, save the final model in the compatible format
        final_model_path = os.path.join(self.config['output_dir'], 'final_model_sam_format.pth')
        torch.save(self.model.state_dict(), final_model_path)
        logger.info(f"Saved final model in SAM-compatible format: {final_model_path}")


def main():
    # Training configuration
    config = {
        'model_type': 'vit_b',
        'checkpoint_path': './checkpoints/sam_vit_b_01ec64.pth',
        'train_image_dir': './stamps/images',
        'train_mask_dir': './stamps/masks',
        'train_bbox_file': './stamps/annotations.txt',

        'val_image_dir': './stamps/val_images',
        'val_mask_dir': './stamps/val_masks',
        'val_bbox_file': './stamps/val_annotations.txt',
        'output_dir': f'./outputs/sam_finetune_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
        'num_epochs': 1,
        'batch_size': 1,
        'num_workers': 4,
        'encoder_lr': 1e-6,
        'decoder_lr': 1e-5,
        'save_interval': 5
    }

    # Create the trainer and start training
    trainer = SAMFineTuner(config)
    trainer.train()


if __name__ == '__main__':
    main()
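
After training, the SAM-compatible weights saved by this script are intended to be loadable through sam_model_registry, for example via the SAMPredictor class from section 7. The path below only sketches the pattern produced by the config above (the timestamp part will differ):

predictor = SAMPredictor("./outputs/sam_finetune_<timestamp>/final_model_sam_format.pth")
mask, confidence = predictor.predict(image, bbox)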
