使用深度学习框架PyTorch并使用预训练的ResNet模型作为基础来完成任务 对仪表盘指针识别数据集训练及应用 也可以Yolo算法进行训练识别
文章目录
仪表盘指针识别数据集
说明:
7000+张图,已标注txt格式,共4个类别
训练集验证集测试集按7155:525:59划分的
类别:
①base
②end
③start
④tip
仅供参考的建立代码,用于处理和训练仪表盘指针识别数据集。使用深度学习框架PyTorch来完成任务。
任务概述
- 目标:识别仪表盘指针的4个类别(
base
,end
,start
,tip
)。 - 数据集:
- 已标注为txt格式,包含7000+张图像。
- 数据划分:训练集(7155张)、验证集(525张)、测试集(59张)。
- 模型:使用预训练的卷积神经网络(如ResNet或YOLO)进行目标检测或关键点定位。
代码实现
1. 导入必要的库
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
2. 定义自定义数据集类
我们需要一个Dataset
类来加载图像和对应的标注文件(txt格式)。
class PointerDataset(Dataset):
def __init__(self, img_dir, label_dir, transform=None):
self.img_dir = img_dir
self.label_dir = label_dir
self.transform = transform
self.image_files = os.listdir(img_dir)
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.image_files[idx])
label_path = os.path.join(self.label_dir, self.image_files[idx].replace('.jpg', '.txt'))
# 加载图像
image = Image.open(img_path).convert("RGB")
if self.transform:
image = self.transform(image)
# 加载标签
with open(label_path, 'r') as f:
lines = f.readlines()
labels = []
for line in lines:
class_id, x, y = map(float, line.strip().split())
labels.append([class_id, x, y]) # [类别, x坐标, y坐标]
return image, torch.tensor(labels)
3. 数据预处理与加载
定义图像的预处理操作,并加载训练集、验证集和测试集。
# 图像预处理
transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整图像大小
transforms.ToTensor(), # 转换为Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
])
# 创建数据集
train_dataset = PointerDataset(img_dir="train_images", label_dir="train_labels", transform=transform)
val_dataset = PointerDataset(img_dir="val_images", label_dir="val_labels", transform=transform)
test_dataset = PointerDataset(img_dir="test_images", label_dir="test_labels", transform=transform)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
4. 定义模型
我们使用预训练的ResNet模型作为基础,并在最后添加一个全连接层来预测4个类别的位置。
class PointerModel(nn.Module):
def __init__(self):
super(PointerModel, self).__init__()
self.base_model = models.resnet18(pretrained=True)
self.base_model.fc = nn.Linear(self.base_model.fc.in_features, 4 * 3) # 4个类别,每个类别有x, y坐标
def forward(self, x):
return self.base_model(x)
5. 定义损失函数和优化器
我们使用均方误差(MSE)作为损失函数,因为它适合回归任务。
model = PointerModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
6. 训练模型
训练模型并在验证集上评估性能。
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels.view(-1, 12)) # 将标签展平为(batch_size, 12)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
# 验证模型
validate_model(model, val_loader, criterion)
def validate_model(model, val_loader, criterion):
model.eval()
val_loss = 0.0
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels.view(-1, 12))
val_loss += loss.item()
print(f"Validation Loss: {val_loss/len(val_loader):.4f}")
7. 测试模型
在测试集上评估模型性能。
def test_model(model, test_loader):
model.eval()
test_loss = 0.0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels.view(-1, 12))
test_loss += loss.item()
print(f"Test Loss: {test_loss/len(test_loader):.4f}")
8. 主程序
运行训练、验证和测试流程。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 训练模型
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)
# 测试模型
test_model(model, test_loader)
总结
以上代码展示了如何使用PyTorch处理仪表盘指针识别数据集,并训练一个基于ResNet的模型来预测指针的关键点。您可以根据实际需求调整模型结构、损失函数和超参数。