基于CNN的食物图像分类:最优模型测试与应用实战
在计算机视觉领域,图像分类任务一直是研究的热点和基础应用之一。本文将分享一个基于卷积神经网络(CNN)实现的食物图像分类项目,展示如何使用训练好的最优模型对测试数据进行评估,并对结果进行分析。
项目背景与目标
随着人们生活水平的提高,对饮食健康和营养的关注度日益增加。在智能餐饮、健康管理等场景中,准确识别食物种类显得尤为重要。通过构建一个食物图像分类模型,能够快速、准确地判断食物类别,为后续的热量计算、饮食分析等功能提供数据支持。本项目旨在训练一个高精度的CNN模型,并使用最优模型对测试数据进行分类,验证模型的泛化能力和准确性。
数据预处理
在数据预处理阶段,我们针对训练集和验证集采用了不同的策略。训练集数据经过一系列增强操作,包括随机旋转、中心裁剪、水平和垂直翻转、颜色抖动以及随机灰度化等,这些操作不仅扩充了数据集的多样性,还能有效增强模型的鲁棒性,避免过拟合问题。例如,transforms.RandomRotation(45)
使图像在-45到45度之间随机旋转,模拟不同拍摄角度的图像;transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1)
改变图像的亮度、对比度、饱和度和色调,增加图像的视觉变化。而验证集数据仅进行了尺寸调整、转换为张量以及归一化操作,保持数据的一致性和标准化,以便于模型的评估。
data_transforms={
'train':
transforms.Compose([
transforms.Resize([300, 300]),
transforms.RandomRotation(45),
transforms.CenterCrop(256),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.5),
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
transforms.RandomGrayscale(p=0.1),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'valid':
transforms.Compose([
transforms.Resize([256, 256]),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}
同时,我们自定义了food_dataset
类来加载数据,该类继承自Dataset
类,实现了__init__
、__len__
和__getitem__
方法,能够从指定文件路径读取图像路径和标签,并对图像进行相应的转换处理,为后续的数据加载和训练提供了便利。
class food_dataset(Dataset):
def __init__(self,file_path,transform=None):
self.file_path=file_path
self.imgs=[]
self.labels=[]
self.transform=transform
with open(self.file_path) as f:
samples=[x.strip().split(' ') for x in f.readlines()]
for img_path,label in samples:
self.imgs.append(img_path)
self.labels.append(label)
def __len__(self):
return len(self.imgs)
def __getitem__(self, idx):
image=Image.open(self.imgs[idx])
if self.transform:
image=self.transform(image)
label = self.labels[idx]
label = torch.from_numpy(np.array(label,dtype=np.int64))
return image,label
模型架构设计
本项目采用的CNN模型由多个卷积层、池化层和全连接层组成。在模型定义中,CNN
类继承自nn.Module
,通过__init__
方法构建网络结构。卷积层使用nn.Conv2d
定义,用于提取图像的特征;ReLU
激活函数为网络引入非线性,增强模型的表达能力;MaxPool2d
池化层用于下采样,减少参数数量和计算量,同时保留重要特征。
class CNN(nn.Module):
def __init__(self):
super(CNN,self).__init__()
self.conv1=nn.Sequential(
nn.Conv2d(
in_channels=3,
out_channels=16,
kernel_size=5,
stride=1,
padding=2,
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.Conv2d(32, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, 5, 1, 2),
nn.ReLU()
)
self.out = nn.Linear(64 * 64* 64, 20)
def forward(self,x):
x=self.conv1(x)
x=self.conv2(x)
x=self.conv3(x)
x=x.view(x.size(0),-1)
output=self.out(x)
return output
通过不断调整卷积层的通道数、卷积核大小、步长和填充等参数,以及增加网络层数,优化模型结构,以提高模型对食物图像的分类性能。
模型加载与测试
在训练完成后,我们得到了最优模型的参数文件best.pth
。通过以下代码加载模型参数,并将模型设置为评估模式,固定模型参数,防止在测试过程中参数更新。
model = CNN().to(device)
model.load_state_dict(torch.load('best.pth',map_location=torch.device('cpu')))
model.eval()
接下来,创建测试数据集和数据加载器,使用测试数据对模型进行评估。在测试过程中,我们定义了test_true
函数,该函数遍历测试数据加载器中的每一个批次数据,将数据传入模型进行预测,计算损失值和预测准确率,并记录预测结果和真实标签。
test_data=food_dataset(file_path='test.txt', transform=data_transforms['valid'])
test_dataloader=DataLoader(test_data,batch_size=1,shuffle=True)
def test_true(dataloader,model):
size=len(dataloader.dataset)
num_batches=len(dataloader)
model.eval()
test_loss,correct=0,0
with torch.no_grad():
for X,y in dataloader:
X,y=X.to(device),y.to(device)
pred=model.forward(X)
test_loss += loss_fn(pred, y).item()
correct+=(pred.argmax(1)==y).type(torch.float).sum().item()
results.append(pred.argmax(1).item())
labels.append(y.item())
test_loss /= num_batches
correct /= size
print(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')
test_true(test_dataloader, model)
print('预测值:\t',results)
print('真实值:\t',labels)
最终,通过输出预测值和真实值,我们可以直观地对比模型的预测结果和实际标签,评估模型的分类效果。同时,打印出测试准确率和平均损失,量化模型在测试集上的性能表现。
总结与展望
通过上述步骤,我们成功地使用训练好的最优模型对食物图像测试数据进行了分类和评估。从测试结果来看,模型在测试集上达到了一定的准确率,但可能仍存在部分误分类的情况。未来,可以进一步优化模型架构,调整超参数,或者增加数据集的规模和多样性,以提高模型的性能和泛化能力。此外,将该模型部署到实际应用中,如开发手机APP或集成到智能设备中,能够为用户提供便捷的食物识别服务,具有广阔的应用前景。
希望本文的分享能够为大家在图像分类项目的实践中提供一些参考和帮助,共同推动计算机视觉技术在实际场景中的应用和发展。