如何微调SAM模型:从环境配置到训练实现的完整指南
如何微调SAM模型:从环境配置到训练实现的完整指南
补充1:
(2025年1月2日)
很多朋友来问数据标注是什么格式,因此添加补充1作解答。
运行代码末尾提供的demo,既可以生成标注格式的demo示例。
python sam-data-setup.py
数据集目录下,放images文件夹、masks文件夹、和annotations.txt,
images里放原始图片,这里随机生成的。可在这个文件夹里放入自己的数据。
images里放对应的掩码图像,并且对应更改文件后缀名,在这个文件夹里放入自己数据对应的标签掩码图像。
annotations.txt里放图片对应的检测框坐标信息。
引言
Segment Anything Model (SAM) 是 Meta AI 推出的一个强大的图像分割模型。尽管预训练模型表现优秀,但在特定领域(如医疗影像、工业检测等)可能需要进行微调以获得更好的性能。本文将详细介绍如何微调 SAM 模型,包括环境配置、数据准备和训练实现。
目录
- 环境配置
- 项目结构
- 数据准备
- 模型微调
- 训练过程
- 注意事项和优化建议
1. 环境配置
首先,我们需要配置正确的 Python 环境和依赖包。推荐使用虚拟环境来管理依赖:
# 创建并激活虚拟环境
python -m venv sam_env
# Windows:
.\sam_env\Scripts\activate
# Linux/Mac:
source sam_env/bin/activate
# 安装依赖
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install opencv-python
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install numpy matplotlib
# 下载预训练模型
# Windows PowerShell:
Invoke-WebRequest -Uri "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" -OutFile "sam_vit_b_01ec64.pth"
# Linux/Mac:
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
2. 项目结构
推荐的项目结构如下:
project_root/
├── stamps/
│ ├── images/ # 训练图像
│ ├── masks/ # 分割掩码
│ └── annotations.txt # 边界框标注
├── checkpoints/ # 模型检查点
├── setup_sam_data.py # 数据准备脚本
└── sam_finetune.py # 训练脚本
3. 数据准备
为了训练模型,我们需要准备以下数据:
- 训练图像
- 分割掩码
- 边界框标注
以下是数据准备脚本的实现:
import os
import numpy as np
import cv2
from pathlib import Path
def create_project_structure():
"""创建项目所需的目录结构"""
directories = [
'./stamps/images',
'./stamps/masks',
'./checkpoints'
]
for dir_path in directories:
Path(dir_path).mkdir(parents=True, exist_ok=True)
return directories
def create_sample_data(num_samples=5):
"""创建示例训练数据"""
annotations = []
for i in range(num_samples):
# 创建示例图像
image = np.ones((500, 500, 3), dtype=np.uint8) * 255
center_x = np.random.randint(150, 350)
center_y = np.random.randint(150, 350)
radius = np.random.randint(50, 100)
# 绘制对象
cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
# 创建掩码
mask = np.zeros((500, 500), dtype=np.uint8)
cv2.circle(mask, (center_x, center_y), radius, 255, -1)
# 保存文件
cv2.imwrite(f'./stamps/images/sample_{i}.jpg', image)
cv2.imwrite(f'./stamps/masks/sample_{i}_mask.png', mask)
# 计算边界框
x1 = max(0, center_x - radius)
y1 = max(0, center_y - radius)
x2 = min(500, center_x + radius)
y2 = min(500, center_y + radius)
annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
# 保存标注文件
with open('./stamps/annotations.txt', 'w') as f:
f.writelines(annotations)
4. 模型微调
4.1 数据集类实现
首先实现自定义数据集类:
class StampDataset(Dataset):
def __init__(self, image_dir, mask_dir, bbox_file):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = ResizeLongestSide(1024)
# 加载标注
self.annotations = []
with open(bbox_file, 'r') as f:
for line in f:
img_name, x1, y1, x2, y2 = line.strip().split(',')
self.annotations.append({
'image': img_name,
'bbox': [float(x1), float(y1), float(x2), float(y2)]
})
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
ann = self.annotations[idx]
# 加载和预处理图像
image = cv2.imread(os.path.join(self.image_dir, ann['image']))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = cv2.imread(os.path.join(self.mask_dir,
ann['image'].replace('.jpg', '_mask.png')),
cv2.IMREAD_GRAYSCALE)
mask = mask.astype(np.float32) / 255.0
# 图像处理
original_size = image.shape[:2]
input_image = self.transform.apply_image(image)
input_image = input_image.astype(np.float32) / 255.0
input_image = torch.from_numpy(input_image).permute(2, 0, 1)
# 标准化
mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
input_image = (input_image - mean) / std
# 处理边界框和掩码
bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
return {
'image': input_image.float(),
'original_size': original_size,
'bbox': bbox_torch,
'mask': mask_torch
}
4.2 训练函数实现
训练函数的核心实现:
def train_sam(
model_type='vit_b',
checkpoint_path='sam_vit_b_01ec64.pth',
num_epochs=10,
batch_size=1,
learning_rate=1e-5
):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化模型
sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam_model.to(device)
# 准备数据和优化器
dataset = StampDataset(image_dir='./stamps/images',
mask_dir='./stamps/masks',
bbox_file='./stamps/annotations.txt')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
loss_fn = torch.nn.MSELoss()
# 训练循环
for epoch in range(num_epochs):
total_loss = 0
for batch_idx, batch in enumerate(dataloader):
# 准备数据
input_image = batch['image'].to(device)
original_size = batch['original_size']
bbox = batch['bbox'].to(device)
gt_mask = batch['mask'].to(device)
# 前向传播
with torch.no_grad():
image_embedding = sam_model.image_encoder(input_image)
sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
points=None,
boxes=bbox,
masks=None,
)
# 生成预测
mask_predictions, _ = sam_model.mask_decoder(
image_embeddings=image_embedding,
image_pe=sam_model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
# 后处理
upscaled_masks = sam_model.postprocess_masks(
mask_predictions,
input_size=input_image.shape[-2:],
original_size=original_size[0]
).to(device)
binary_masks = torch.sigmoid(upscaled_masks)
# 计算损失并优化
loss = loss_fn(binary_masks, gt_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
if batch_idx % 10 == 0:
print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
# 输出epoch统计
avg_loss = total_loss / len(dataloader)
print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
# 保存检查点
if (epoch + 1) % 5 == 0:
checkpoint_file = f'./checkpoints/sam_finetuned_epoch_{epoch+1}.pth'
torch.save(sam_model.state_dict(), checkpoint_file)
5. 训练过程
完整的训练过程如下:
- 准备环境和数据:
python setup_sam_data.py
- 开始训练:
python sam_finetune.py
6. 注意事项和优化建议
-
数据预处理:
- 确保图像数据类型正确(float32)
- 进行适当的数据标准化
- 注意图像尺寸的一致性
-
训练优化:
- 根据GPU内存调整batch_size
- 适当调整学习率
- 考虑使用学习率调度器
- 添加验证集评估
- 实现早停机制
-
可能的改进:
- 添加数据增强
- 使用不同的损失函数
- 实现多GPU训练
- 添加训练过程可视化
- 实现模型验证和测试
7. 模型预测和可视化
在完成模型微调后,我们需要一个方便的方式来使用模型进行预测并可视化结果。以下是完整的实现:
7.1 预测器类实现
首先,我们封装一个预测器类,用于处理模型加载、图像预处理和预测:
class SAMPredictor:
def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
self.sam_model.to(self.device)
self.transform = ResizeLongestSide(1024)
这个类提供了简单的接口来加载模型并进行预测。主要功能包括:
- 模型加载和设备配置
- 图像预处理
- 掩码预测
- 后处理优化
7.2 可视化函数
为了better展示预测结果,我们实现了一个可视化函数:
def visualize_prediction(image, mask, bbox, confidence, save_path=None):
plt.figure(figsize=(15, 5))
# 显示原始图像、预测掩码和叠加结果
...
这个函数可以同时显示:
- 原始图像(带边界框)
- 预测的分割掩码
- 结果叠加视图
7.3 使用示例
以下是如何使用这些工具的完整示例:
# 初始化预测器
predictor = SAMPredictor("./checkpoints/sam_finetuned_final.pth")
# 读取测试图像
image = cv2.imread("test_image.jpg")
bbox = [x1, y1, x2, y2] # 边界框坐标
# 预测
mask, confidence = predictor.predict(image, bbox)
# 可视化
visualize_prediction(image, mask, bbox, confidence, "result.png")
7.4 注意事项
在使用预测器时,需要注意以下几点:
-
输入图像处理:
- 确保图像格式正确(RGB)
- 注意图像尺寸的一致性
- 正确的数据类型和范围
-
边界框格式:
- 使用 [x1, y1, x2, y2] 格式
- 确保坐标在图像范围内
- 坐标值为浮点数
-
性能优化:
- 批处理预测
- GPU 内存管理
- 结果缓存
7.5 可能的改进
- 批量处理功能:
def predict_batch(self, images, bboxes):
results = []
for image, bbox in zip(images, bboxes):
mask, conf = self.predict(image, bbox)
results.append((mask, conf))
return results
- 多边界框支持:
def predict_multiple_boxes(self, image, bboxes):
masks = []
for bbox in bboxes:
mask, _ = self.predict(image, bbox)
masks.append(mask)
return np.stack(masks)
- 交互式可视化:
def interactive_visualization(image, predictor):
def onclick(event):
if event.button == 1: # 左键点击
bbox = [event.xdata-50, event.ydata-50,
event.xdata+50, event.ydata+50]
mask, _ = predictor.predict(image, bbox)
visualize_prediction(image, mask, bbox)
fig, ax = plt.subplots()
ax.imshow(image)
fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()
这些工具和示例可以帮助你更好地理解和使用微调后的SAM模型。根据具体需求,你可以进一步优化和扩展这些功能。
结论
通过以上步骤,我们实现了SAM模型的微调过程。这个实现可以作为基础,根据具体需求进行优化和改进。在实际应用中,可能需要根据具体任务调整数据预处理、损失函数和训练策略。
建议在使用时注意以下几点:
- 确保训练数据质量
- 合理设置训练参数
- 定期保存检查点
- 监控训练过程
- 适当使用数据增强
希望这个教程对你的项目有所帮助!如果有任何问题,欢迎讨论和交流。
参考资料
- Segment Anything 官方仓库
- PyTorch 文档
- SAM 论文:Segment Anything
- torchvision 文档
快速部署:
下载这三个代码,配置好运行环境,依次运行:
# sam-data-setup.py
import os
import numpy as np
import cv2
from pathlib import Path
def create_project_structure():
"""创建项目所需的目录结构"""
# 创建主目录
directories = [
'./stamps/images',
'./stamps/masks',
'./checkpoints'
]
for dir_path in directories:
Path(dir_path).mkdir(parents=True, exist_ok=True)
return directories
def create_sample_data(num_samples=5):
"""创建示例训练数据"""
# 创建示例图像和掩码
annotations = []
for i in range(num_samples):
# 创建示例图像 (500x500)
image = np.ones((500, 500, 3), dtype=np.uint8) * 255
# 添加一个示例印章 (随机位置的圆形)
center_x = np.random.randint(150, 350)
center_y = np.random.randint(150, 350)
radius = np.random.randint(50, 100)
# 绘制印章
cv2.circle(image, (center_x, center_y), radius, (0, 0, 255), -1)
# 创建对应的掩码
mask = np.zeros((500, 500), dtype=np.uint8)
cv2.circle(mask, (center_x, center_y), radius, 255, -1)
# 保存图像和掩码
image_path = f'./stamps/images/sample_{i}.jpg'
mask_path = f'./stamps/masks/sample_{i}_mask.png'
cv2.imwrite(image_path, image)
cv2.imwrite(mask_path, mask)
# 计算边界框
x1 = max(0, center_x - radius)
y1 = max(0, center_y - radius)
x2 = min(500, center_x + radius)
y2 = min(500, center_y + radius)
# 添加到注释列表
annotations.append(f'sample_{i}.jpg,{x1},{y1},{x2},{y2}\n')
# 保存注释文件
with open('./stamps/annotations.txt', 'w') as f:
f.writelines(annotations)
def main():
print("开始创建项目结构...")
directories = create_project_structure()
for dir_path in directories:
print(f"创建目录: {dir_path}")
print("\n创建示例训练数据...")
create_sample_data()
print("示例数据创建完成!")
print("\n项目结构:")
for root, dirs, files in os.walk('./stamps'):
level = root.replace('./stamps', '').count(os.sep)
indent = ' ' * 4 * level
print(f"{indent}{os.path.basename(root)}/")
sub_indent = ' ' * 4 * (level + 1)
for f in files:
print(f"{sub_indent}{f}")
if __name__ == '__main__':
main()
# sam_finetune_decoder.py
import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import Dataset, DataLoader
import cv2
import os
class StampDataset(Dataset):
def __init__(self, image_dir, mask_dir, bbox_file, target_size=(1024, 1024)):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.target_size = target_size
self.transform = ResizeLongestSide(1024) # SAM default size
# Load bbox annotations
self.annotations = []
with open(bbox_file, 'r') as f:
for line in f:
img_name, x1, y1, x2, y2 = line.strip().split(',')
self.annotations.append({
'image': img_name,
'bbox': [float(x1), float(y1), float(x2), float(y2)]
})
def resize_with_bbox(self, image, mask, bbox):
"""调整图像、掩码和边界框的大小"""
h, w = image.shape[:2]
target_h, target_w = self.target_size
# 计算缩放比例
scale_x = target_w / w
scale_y = target_h / h
# 调整图像大小
resized_image = cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
# 调整掩码大小
if mask is not None:
resized_mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
else:
resized_mask = None
# 调整边界框
resized_bbox = [
bbox[0] * scale_x, # x1
bbox[1] * scale_y, # y1
bbox[2] * scale_x, # x2
bbox[3] * scale_y # y2
]
return resized_image, resized_mask, resized_bbox
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
ann = self.annotations[idx]
# Load image
image = cv2.imread(os.path.join(self.image_dir, ann['image']))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Load mask
mask_name = ann['image'].replace('.jpg', '_mask.png')
mask = cv2.imread(os.path.join(self.mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
mask = mask.astype(np.float32) / 255.0
# 首先将图像调整为统一大小
image, mask, bbox = self.resize_with_bbox(image, mask, ann['bbox'])
# 准备图像
original_size = self.target_size
input_image = self.transform.apply_image(image)
# Convert to float32 and normalize to 0-1 range
input_image = input_image.astype(np.float32) / 255.0
# Convert to tensor and normalize according to ImageNet stats
input_image = torch.from_numpy(input_image).permute(2, 0, 1).contiguous()
# Apply ImageNet normalization
mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
input_image = (input_image - mean) / std
# Prepare bbox
bbox = self.transform.apply_boxes(np.array([bbox]), original_size)[0]
bbox_torch = torch.tensor(bbox, dtype=torch.float).unsqueeze(0)
# Prepare mask
mask_torch = torch.from_numpy(mask).float().unsqueeze(0)
return {
'image': input_image.float(),
'original_size': original_size,
'bbox': bbox_torch,
'mask': mask_torch
}
def train_sam(
model_type='vit_b',
checkpoint_path='sam_vit_b_01ec64.pth',
image_dir='./stamps/images',
mask_dir='./stamps/masks',
bbox_file='./stamps/annotations.txt',
output_dir='./checkpoints',
num_epochs=10,
batch_size=1,
learning_rate=1e-5
):
# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Initialize model
sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam_model.to(device)
# Prepare dataset
dataset = StampDataset(image_dir, mask_dir, bbox_file)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Setup optimizer
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=learning_rate)
# Loss function
loss_fn = torch.nn.MSELoss()
# Training loop
for epoch in range(num_epochs):
total_loss = 0
for batch_idx, batch in enumerate(dataloader):
# Move inputs to device
input_image = batch['image'].to(device)
original_size = batch['original_size']
bbox = batch['bbox'].to(device)
gt_mask = batch['mask'].to(device)
# Print shapes and types for debugging
if batch_idx == 0 and epoch == 0:
print(f"Input image shape: {input_image.shape}")
print(f"Input image type: {input_image.dtype}")
print(f"Input image range: [{input_image.min():.2f}, {input_image.max():.2f}]")
# Get image embedding (without gradient)
with torch.no_grad():
image_embedding = sam_model.image_encoder(input_image)
# Get prompt embeddings
sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
points=None,
boxes=bbox,
masks=None,
)
# Generate mask prediction
mask_predictions, iou_predictions = sam_model.mask_decoder(
image_embeddings=image_embedding,
image_pe=sam_model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
# Upscale masks to original size
upscaled_masks = sam_model.postprocess_masks(
mask_predictions,
input_size=input_image.shape[-2:],
original_size=original_size[0]
).to(device)
# Convert to binary mask
binary_masks = torch.sigmoid(upscaled_masks)
# Calculate loss
loss = loss_fn(binary_masks, gt_mask)
# Optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
if batch_idx % 10 == 0:
print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
avg_loss = total_loss / len(dataloader)
print(f'Epoch {epoch} completed. Average Loss: {avg_loss:.4f}')
# Save checkpoint
if (epoch + 1) % 5 == 0:
checkpoint_file = os.path.join(output_dir, f'sam_finetuned_epoch_{epoch+1}.pth')
torch.save(sam_model.state_dict(), checkpoint_file)
print(f'Checkpoint saved: {checkpoint_file}')
# Save final model
final_checkpoint = os.path.join(output_dir, 'sam_finetuned_final.pth')
torch.save(sam_model.state_dict(), final_checkpoint)
print(f'Final model saved to {final_checkpoint}')
if __name__ == '__main__':
# Create output directory if it doesn't exist
os.makedirs('./checkpoints', exist_ok=True)
# Start training
train_sam()
import torch
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
import cv2
from pathlib import Path
class SAMPredictor:
def __init__(self, checkpoint_path, model_type="vit_b", device="cuda"):
"""
初始化SAM预测器
Args:
checkpoint_path: 模型权重路径
model_type: 模型类型 ("vit_h", "vit_l", "vit_b")
device: 使用设备 ("cuda" or "cpu")
"""
self.device = torch.device(device if torch.cuda.is_available() and device == "cuda" else "cpu")
print(f"Using device: {self.device}")
# 加载模型
self.sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path)
self.sam_model.to(self.device)
# 创建图像变换器
self.transform = ResizeLongestSide(1024)
def resize_bbox(self, bbox, original_size, target_size=(1024, 1024)):
"""
调整边界框坐标以匹配调整大小后的图像
Args:
bbox: 原始边界框坐标 [x1, y1, x2, y2]
original_size: 原始图像尺寸 (height, width)
target_size: 目标图像尺寸 (height, width)
Returns:
resized_bbox: 调整后的边界框坐标
"""
orig_h, orig_w = original_size
target_h, target_w = target_size
# 计算缩放比例
scale_x = target_w / orig_w
scale_y = target_h / orig_h
# 调整边界框坐标
x1, y1, x2, y2 = bbox
resized_bbox = [
x1 * scale_x,
y1 * scale_y,
x2 * scale_x,
y2 * scale_y
]
return resized_bbox
def preprocess_image(self, image):
"""预处理输入图像"""
# 保存原始尺寸
original_size = image.shape[:2]
# 确保图像是RGB格式
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
elif len(image.shape) == 3 and image.shape[2] == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 调整图像大小
image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
# 转换为float32并归一化
input_image = image.astype(np.float32) / 255.0
# 转换为tensor并添加batch维度
input_image = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0)
# 标准化
mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
input_image = (input_image - mean) / std
return input_image.to(self.device), original_size, image
def predict(self, image, bbox):
"""
预测单个图像的分割掩码
Args:
image: numpy array 格式的图像
bbox: [x1, y1, x2, y2] 格式的边界框
Returns:
binary_mask: 二值化的分割掩码
confidence: 预测的置信度
"""
# 预处理图像
input_image, original_size, resized_image = self.preprocess_image(image)
# 调整边界框大小
resized_bbox = self.resize_bbox(bbox, original_size)
print(resized_bbox, image.shape, resized_image.shape)
# 准备边界框
bbox_torch = torch.tensor(resized_bbox, dtype=torch.float, device=self.device).unsqueeze(0)
# 获取图像嵌入
with torch.no_grad():
image_embedding = self.sam_model.image_encoder(input_image)
# 获取提示嵌入
sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
points=None,
boxes=bbox_torch,
masks=None,
)
# 生成掩码预测
mask_predictions, iou_predictions = self.sam_model.mask_decoder(
image_embeddings=image_embedding,
image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
# 后处理掩码
upscaled_masks = self.sam_model.postprocess_masks(
mask_predictions,
input_size=input_image.shape[-2:],
original_size=original_size
).to(self.device)
# 转换为二值掩码
binary_mask = torch.sigmoid(upscaled_masks) > 0.5
return binary_mask[0, 0].cpu().numpy(), iou_predictions[0, 0].item()
def visualize_prediction(image, mask, bbox, confidence, save_path=None):
"""
可视化预测结果
Args:
image: 原始图像
mask: 预测的掩码
bbox: 边界框坐标
confidence: 预测置信度
save_path: 保存路径(可选)
"""
# 创建图形
plt.figure(figsize=(15, 5))
# 显示原始图像
plt.subplot(131)
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
plt.title('Original Image')
# 绘制边界框
x1, y1, x2, y2 = map(int, bbox)
plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'r-', linewidth=2)
plt.axis('off')
# 显示预测掩码
plt.subplot(132)
plt.imshow(mask, cmap='gray')
plt.title(f'Predicted Mask\nConfidence: {confidence:.2f}')
plt.axis('off')
# 显示叠加结果
plt.subplot(133)
overlay = image.copy()
overlay[mask > 0] = overlay[mask > 0] * 0.7 + np.array([0, 255, 0], dtype=np.uint8) * 0.3
plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
plt.title('Overlay')
plt.axis('off')
plt.tight_layout()
if save_path:
plt.savefig(save_path)
print(f"结果已保存到: {save_path}")
plt.show()
def main():
# 配置参数
checkpoint_path = "./checkpoints/sam_finetuned_final.pth" # 使用微调后的模型
test_image_path = "./stamps/images/sample_0.jpg"
output_dir = "./predictions"
# 创建输出目录
Path(output_dir).mkdir(parents=True, exist_ok=True)
# 初始化预测器
predictor = SAMPredictor(checkpoint_path)
# 读取测试图像
image = cv2.imread(test_image_path)
# image = cv2.resize(image, (1024, 1024), interpolation=cv2.INTER_LINEAR)
# 读取边界框(这里使用示例边界框,实际应用中可能需要从标注文件读取)
with open('./stamps/annotations.txt', 'r') as f:
first_line = f.readline().strip()
_, x1, y1, x2, y2 = first_line.split(',')
bbox = [float(x1), float(y1), float(x2), float(y2)]
print(bbox)
# 进行预测
mask, confidence = predictor.predict(image, bbox)
# 可视化结果
save_path = str(Path(output_dir) / "prediction_result.png")
visualize_prediction(image, mask, bbox, confidence, save_path)
if __name__ == "__main__":
main()
运行结果:
分割线
补充2:
上文提到的是微调decoder部分,下面补充微调encoder部分的代码:
注意事项:
微调encoder需要更多的计算资源和训练时间
需要更大的训练数据集以避免过拟合
建议使用验证集监控性能,防止模型退化
可能需要更多的训练轮次才能收敛
import torch
import numpy as np
from per_segment_anything import sam_model_registry, SamPredictor
from per_segment_anything.utils.transforms import ResizeLongestSide
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import cv2
import os
from tqdm import tqdm
import logging
import json
from datetime import datetime
from train_setimage import preprocess
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class StampDataset(Dataset):
def __init__(self, image_dir, mask_dir, bbox_file, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform if transform else ResizeLongestSide(1024)
# 加载标注文件
self.annotations = []
with open(bbox_file, 'r') as f:
for line in f:
img_name, x1, y1, x2, y2 = line.strip().split(',')
self.annotations.append({
'image': img_name,
'bbox': [float(x1), float(y1), float(x2), float(y2)]
})
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
ann = self.annotations[idx]
# 读取图像
image_path = os.path.join(self.image_dir, ann['image'])
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"Failed to load image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 读取mask
mask_name = ann['image'].replace('.jpg', '_mask.png')
mask_path = os.path.join(self.mask_dir, mask_name)
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise ValueError(f"Failed to load mask: {mask_path}")
mask = mask.astype(np.float32) / 255.0
# 准备图像
original_size = image.shape[:2]
input_image = self.transform.apply_image(image)
input_image = input_image.astype(np.float32) / 255.0
# 转换为tensor并进行ImageNet归一化
input_image = torch.from_numpy(input_image).permute(2, 0, 1)
# Use preprocess to handle ImageNet normalization and padding
input_image = preprocess(input_image)
print(f"Processed image shape: {input_image.shape}")
# 准备bbox
bbox = self.transform.apply_boxes(np.array([ann['bbox']]), original_size)[0]
bbox_torch = torch.tensor(bbox, dtype=torch.float)
# 准备mask
mask_torch = torch.from_numpy(mask).float()
return {
'image': input_image.float(),
'original_size': original_size,
'bbox': bbox_torch,
'mask': mask_torch,
'image_path': image_path # 用于调试
}
class SAMFineTuner:
def __init__(self, config):
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.setup_model()
self.setup_datasets()
self.setup_training()
# 创建输出目录
os.makedirs(config['output_dir'], exist_ok=True)
# 保存配置
config_path = os.path.join(config['output_dir'], 'config.json')
with open(config_path, 'w') as f:
json.dump(config, f, indent=4)
def setup_model(self):
logger.info(f"Loading SAM model: {self.config['model_type']}")
self.model = sam_model_registry[self.config['model_type']](
checkpoint=self.config['checkpoint_path']
)
self.model.to(self.device)
def setup_datasets(self):
logger.info("Setting up datasets")
self.train_dataset = StampDataset(
self.config['train_image_dir'],
self.config['train_mask_dir'],
self.config['train_bbox_file']
)
# 从训练数据集中按批次加载数据
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.config['batch_size'],
shuffle=True,
num_workers=self.config['num_workers'],
pin_memory=True
)
# 验证集
if self.config.get('val_bbox_file'):
self.val_dataset = StampDataset(
self.config['val_image_dir'],
self.config['val_mask_dir'],
self.config['val_bbox_file']
)
self.val_loader = DataLoader(
self.val_dataset,
batch_size=self.config['batch_size'],
shuffle=False,
num_workers=self.config['num_workers'],
pin_memory=True
)
def setup_training(self):
logger.info("Setting up training components")
# 分别设置encoder和decoder的学习率
self.optimizer = torch.optim.Adam([
{'params': self.model.image_encoder.parameters(),
'lr': self.config['encoder_lr']},
{'params': self.model.mask_decoder.parameters(),
'lr': self.config['decoder_lr']}
])
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
mode='min',
factor=0.5,
patience=5,
verbose=True
)
self.loss_fn = torch.nn.MSELoss()
self.scaler = GradScaler()
# 记录最佳模型
self.best_loss = float('inf')
def train_epoch(self, epoch):
self.model.train()
total_loss = 0 # 初始化总损失
pbar = tqdm(self.train_loader, desc=f'Epoch {epoch + 1}')
for batch_idx, batch in enumerate(pbar):
# 将数据移到GPU
input_image = batch['image'].to(self.device)
bbox = batch['bbox'].to(self.device)
gt_mask = batch['mask'].to(self.device)
self.optimizer.zero_grad()
with autocast():
# 前向传播
image_embedding = self.model.image_encoder(input_image)
with torch.no_grad():
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=None,
boxes=bbox,
masks=None,
)
mask_predictions, _ = self.model.mask_decoder(
image_embeddings=image_embedding,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
upscaled_masks = self.model.postprocess_masks(
mask_predictions,
input_size=input_image.shape[-2:],
original_size=batch['original_size']
).to(self.device)
binary_masks = torch.sigmoid(upscaled_masks)
loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))
# 反向传播
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
total_loss += loss.item()
pbar.set_postfix({'loss': loss.item()})
return total_loss / len(self.train_loader)
@torch.no_grad()
def validate(self):
if not hasattr(self, 'val_loader'):
return None
self.model.eval()
total_loss = 0
for batch in tqdm(self.val_loader, desc='Validating'):
input_image = batch['image'].to(self.device)
bbox = batch['bbox'].to(self.device)
gt_mask = batch['mask'].to(self.device)
with autocast():
image_embedding = self.model.image_encoder(input_image)
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=None,
boxes=bbox,
masks=None,
)
mask_predictions, _ = self.model.mask_decoder(
image_embeddings=image_embedding,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
upscaled_masks = self.model.postprocess_masks(
mask_predictions,
input_size=input_image.shape[-2:],
original_size=batch['original_size']
).to(self.device)
binary_masks = torch.sigmoid(upscaled_masks)
loss = self.loss_fn(binary_masks, gt_mask.unsqueeze(1))
total_loss += loss.item()
return total_loss / len(self.val_loader)
def save_checkpoint(self, epoch, loss, is_best=False):
# 保存完整的训练状态(用于恢复训练)
checkpoint = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'loss': loss,
'config': self.config
}
# 保存完整checkpoint
checkpoint_path = os.path.join(
self.config['output_dir'],
f'checkpoint_epoch_{epoch + 1}.pth'
)
torch.save(checkpoint, checkpoint_path)
# 如果是最佳模型,保存兼容格式的模型权重
if is_best:
# 保存完整checkpoint
best_checkpoint_path = os.path.join(self.config['output_dir'], 'best_checkpoint.pth')
torch.save(checkpoint, best_checkpoint_path)
# 额外保存一个干净的模型权重(兼容原SAM格式)
best_model_path = os.path.join(self.config['output_dir'], 'best_model_sam_format.pth')
torch.save(self.model.state_dict(), best_model_path)
logger.info(f"Saved best model with loss: {loss:.4f}")
def train(self):
logger.info("Starting training")
for epoch in range(self.config['num_epochs']):
train_loss = self.train_epoch(epoch)
logger.info(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}")
val_loss = self.validate()
if val_loss is not None:
logger.info(f"Epoch {epoch + 1} - Val Loss: {val_loss:.4f}")
self.scheduler.step(val_loss)
is_best = val_loss < self.best_loss
if is_best:
self.best_loss = val_loss
else:
is_best = False
self.scheduler.step(train_loss)
if (epoch + 1) % self.config['save_interval'] == 0:
self.save_checkpoint(
epoch,
val_loss if val_loss is not None else train_loss,
is_best
)
# 训练结束后保存最终的兼容格式模型
final_model_path = os.path.join(self.config['output_dir'], 'final_model_sam_format.pth')
torch.save(self.model.state_dict(), final_model_path)
logger.info(f"Saved final model in SAM-compatible format: {final_model_path}")
def main():
# 训练配置
config = {
'model_type': 'vit_b',
'checkpoint_path': './checkpoints/sam_vit_b_01ec64.pth',
'train_image_dir': './stamps/images',
'train_mask_dir': './stamps/masks',
'train_bbox_file': './stamps/annotations.txt',
'val_image_dir': './stamps/val_images',
'val_mask_dir': './stamps/val_masks',
'val_bbox_file': './stamps/val_annotations.txt',
'output_dir': f'./outputs/sam_finetune_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
'num_epochs': 1,
'batch_size': 1,
'num_workers': 4,
'encoder_lr': 1e-6,
'decoder_lr': 1e-5,
'save_interval': 5
}
# 创建训练器并开始训练
trainer = SAMFineTuner(config)
trainer.train()
if __name__ == '__main__':
main()