新年美食鉴赏——基于注意力机制CBAM的美食101分类
春节是一年当中最隆重、最盛大的节日,吃得最好的一餐可能就是年夜饭了,想想就很激动呢!随着新年钟声的敲响,一年一度的美食斗图大赛即将上演,快来看看都有啥好吃的!
一、数据预处理
本项目使用的数据集地址:101类美食数据集
# 解压数据集
!unzip -oq /home/aistudio/data/data70204/images.zip -d food
1.数据集介绍
该数据集包含完整的101种食物,为图像分析提供了比CIFAR10或MNIST更有趣的训练数据;同时数据中还包含大量缩小的图像版本,以便进行快速测试。
101种类别如下所示(索引从0开始):
{‘apple_pie’: 0, ‘baby_back_ribs’: 1, ‘baklava’: 2, ‘beef_carpaccio’: 3, ‘beef_tartare’: 4, ‘beet_salad’: 5, ‘beignets’: 6, ‘bibimbap’: 7, ‘bread_pudding’: 8, ‘breakfast_burrito’: 9, ‘bruschetta’: 10,
‘caesar_salad’: 11, ‘cannoli’: 12, ‘caprese_salad’: 13, ‘carrot_cake’: 14, ‘ceviche’: 15, ‘cheesecake’: 16, ‘cheese_plate’: 17, ‘chicken_curry’: 18, ‘chicken_quesadilla’: 19, ‘chicken_wings’: 20,
‘chocolate_cake’: 21, ‘chocolate_mousse’: 22, ‘churros’: 23, ‘clam_chowder’: 24, ‘club_sandwich’: 25, ‘crab_cakes’: 26, ‘creme_brulee’: 27, ‘croque_madame’: 28, ‘cup_cakes’: 29, ‘deviled_eggs’: 30,
‘donuts’: 31, ‘dumplings’: 32, ‘edamame’: 33, ‘eggs_benedict’: 34, ‘escargots’: 35, ‘falafel’: 36, ‘filet_mignon’: 37, ‘fish_and_chips’: 38, ‘foie_gras’: 39, ‘french_fries’: 40,
‘french_onion_soup’: 41, ‘french_toast’: 42, ‘fried_calamari’: 43, ‘fried_rice’: 44, ‘frozen_yogurt’: 45, ‘garlic_bread’: 46, ‘gnocchi’: 47, ‘greek_salad’: 48, ‘grilled_cheese_sandwich’: 49, ‘grilled_salmon’: 50,
‘guacamole’: 51, ‘gyoza’: 52, ‘hamburger’: 53, ‘hot_and_sour_soup’: 54, ‘hot_dog’: 55, ‘huevos_rancheros’: 56, ‘hummus’: 57, ‘ice_cream’: 58, ‘lasagna’: 59, ‘lobster_bisque’: 60,
‘lobster_roll_sandwich’: 61, ‘macaroni_and_cheese’: 62, ‘macarons’: 63, ‘miso_soup’: 64, ‘mussels’: 65, ‘nachos’: 66, ‘omelette’: 67, ‘onion_rings’: 68, ‘oysters’: 69, ‘pad_thai’: 70,
‘paella’: 71, ‘pancakes’: 72, ‘panna_cotta’: 73, ‘peking_duck’: 74, ‘pho’: 75, ‘pizza’: 76, ‘pork_chop’: 77, ‘poutine’: 78, ‘prime_rib’: 79, ‘pulled_pork_sandwich’: 80,
‘ramen’: 81, ‘ravioli’: 82, ‘red_velvet_cake’: 83, ‘risotto’: 84, ‘samosa’: 85, ‘sashimi’: 86, ‘scallops’: 87, ‘seaweed_salad’: 88, ‘shrimp_and_grits’: 89, ‘spaghetti_bolognese’: 90,
‘spaghetti_carbonara’: 91, ‘spring_rolls’: 92, ‘steak’: 93, ‘strawberry_shortcake’: 94, ‘sushi’: 95, ‘tacos’: 96, ‘takoyaki’: 97, ‘tiramisu’: 98, ‘tuna_tartare’: 99, ‘waffles’: 100}
2.读取标签
在做分类任务之前,要明确有几类,因为机器只认识二进制,因此要把每一类(字符串)映射到唯一的一个数字上
# Read the category names (one per line) from classes.txt and map each
# name to a unique integer label -- the network needs numeric targets.
txtpath = r"classes.txt"
arr = []
# "with" guarantees the file handle is closed even if reading fails
# (the original opened the file and closed it manually).
with open(txtpath) as fp:
    for line in fp:
        arr.append(line.replace("\n", ""))
# enumerate() replaces the original's manual index list + zip().
categorys = {name: index for index, name in enumerate(arr)}
print(categorys)
{'apple_pie': 0, 'baby_back_ribs': 1, 'baklava': 2, 'beef_carpaccio': 3, 'beef_tartare': 4, 'beet_salad': 5, 'beignets': 6, 'bibimbap': 7, 'bread_pudding': 8, 'breakfast_burrito': 9, 'bruschetta': 10, 'caesar_salad': 11, 'cannoli': 12, 'caprese_salad': 13, 'carrot_cake': 14, 'ceviche': 15, 'cheesecake': 16, 'cheese_plate': 17, 'chicken_curry': 18, 'chicken_quesadilla': 19, 'chicken_wings': 20, 'chocolate_cake': 21, 'chocolate_mousse': 22, 'churros': 23, 'clam_chowder': 24, 'club_sandwich': 25, 'crab_cakes': 26, 'creme_brulee': 27, 'croque_madame': 28, 'cup_cakes': 29, 'deviled_eggs': 30, 'donuts': 31, 'dumplings': 32, 'edamame': 33, 'eggs_benedict': 34, 'escargots': 35, 'falafel': 36, 'filet_mignon': 37, 'fish_and_chips': 38, 'foie_gras': 39, 'french_fries': 40, 'french_onion_soup': 41, 'french_toast': 42, 'fried_calamari': 43, 'fried_rice': 44, 'frozen_yogurt': 45, 'garlic_bread': 46, 'gnocchi': 47, 'greek_salad': 48, 'grilled_cheese_sandwich': 49, 'grilled_salmon': 50, 'guacamole': 51, 'gyoza': 52, 'hamburger': 53, 'hot_and_sour_soup': 54, 'hot_dog': 55, 'huevos_rancheros': 56, 'hummus': 57, 'ice_cream': 58, 'lasagna': 59, 'lobster_bisque': 60, 'lobster_roll_sandwich': 61, 'macaroni_and_cheese': 62, 'macarons': 63, 'miso_soup': 64, 'mussels': 65, 'nachos': 66, 'omelette': 67, 'onion_rings': 68, 'oysters': 69, 'pad_thai': 70, 'paella': 71, 'pancakes': 72, 'panna_cotta': 73, 'peking_duck': 74, 'pho': 75, 'pizza': 76, 'pork_chop': 77, 'poutine': 78, 'prime_rib': 79, 'pulled_pork_sandwich': 80, 'ramen': 81, 'ravioli': 82, 'red_velvet_cake': 83, 'risotto': 84, 'samosa': 85, 'sashimi': 86, 'scallops': 87, 'seaweed_salad': 88, 'shrimp_and_grits': 89, 'spaghetti_bolognese': 90, 'spaghetti_carbonara': 91, 'spring_rolls': 92, 'steak': 93, 'strawberry_shortcake': 94, 'sushi': 95, 'tacos': 96, 'takoyaki': 97, 'tiramisu': 98, 'tuna_tartare': 99, 'waffles': 100}
3.统一命名
统一命名,方便检查数据集
# Copy every image into one "temporary" folder, resized to 512x512 and
# renamed "<category><count>.jpg" so the label can later be recovered
# from the file name alone.
import os
from PIL import Image

# NOTE: the original did ``categorys = arr`` here, rebinding the
# name->index dict to a plain list and breaking the later
# ``categorys[category]`` dict lookup; iterate ``arr`` directly instead.
if not os.path.exists("temporary"):
    os.mkdir("temporary")
for category in arr:
    # Folder holding the raw images of this category.
    path = r"food/{}/".format(category)
    count = 0
    for filename in os.listdir(path):
        # Context manager closes each image's underlying file handle.
        with Image.open(path + filename) as img:
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
            # equivalent high-quality resampling filter.
            img = img.resize((512, 512), Image.LANCZOS)
            img = img.convert('RGB')  # JPEG output requires RGB mode
            img.save(r"temporary/{}{}.jpg".format(category, str(count)))
        count += 1
4.整理图片路径
整理图片路径,便于将图片送入神经网络
# Build train_list.txt: one "image_path<TAB>label" line per image, the
# label being recovered from the category prefix of the file name.
import os
import string

# Build the lookup locally from ``arr``: in the original file an earlier
# cell rebinds ``categorys`` to a plain list, so ``categorys[category]``
# raised TypeError here.
label_map = {name: index for index, name in enumerate(arr)}
paths = r'temporary/'
with open('train_list.txt', mode='w') as train_list:
    for filename in os.listdir(paths):
        # Drop ".jpg" and the trailing counter digits to get the
        # category back, e.g. "apple_pie12.jpg" -> "apple_pie".
        category = filename.replace(".jpg", "").rstrip(string.digits)
        label = label_map.get(category)
        # Only write a line for known categories -- the original wrote
        # the path unconditionally but the label only on a match, so a
        # stray file produced a malformed label-less line.
        if label is not None:
            train_list.write(paths + filename + '\t' + str(label) + '\n')
5.划分训练集与验证集
验证集用于检验模型是否过拟合,这里的划分标准是每5张图片取1张做验证数据,即训练集与验证集之比约为4:1
# Split the images into train/eval folders (every 5th image goes to
# eval, i.e. a 4:1 train/eval split), one sub-folder per class label.
import shutil

train_dir = '/home/aistudio/work/trainImages'
eval_dir = '/home/aistudio/work/evalImages'
train_list_path = '/home/aistudio/train_list.txt'
target_path = "/home/aistudio/"
# makedirs(exist_ok=True) also creates a missing parent ("work/"),
# which plain os.mkdir would fail on.
os.makedirs(train_dir, exist_ok=True)
os.makedirs(eval_dir, exist_ok=True)
with open(train_list_path, 'r') as f:
    data = f.readlines()
for i, line in enumerate(data):
    # Each line is "<relative image path>\t<class label>\n".
    img_path, class_label = line.rstrip('\n').split('\t')
    if i % 5 == 0:  # every 5th image becomes validation data
        target_dir = os.path.join(eval_dir, str(class_label))
    else:
        target_dir = os.path.join(train_dir, str(class_label))
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(os.path.join(target_path, img_path), target_dir)
print('划分训练集和验证集完成!')
划分训练集和验证集完成!
6.定义美食数据集
分类任务中有一个非常重要的点就是归一化处理,通过归一化处理,将图片的取值范围从0~255转化为0~1之间,这个对于后续的神经网络有很大的好处,如果不做归一化,那么神经网络有可能学不到任何东西,输出结果全部一样。
import os
import numpy as np
import paddle
from paddle.io import Dataset
from paddle.vision.datasets import DatasetFolder, ImageFolder
from paddle.vision.transforms import Compose, Resize, BrightnessTransform, ColorJitter, Normalize, Transpose
class FoodsDataset(Dataset):
"""
步骤一:继承paddle.io.Dataset类
"""
def __init__(self, mode='train'):
"""
步骤二:实现构造函数,定义数据读取方式,划分训练和测试数据集
"""
super(FoodsDataset, self).__init__()
train_image_dir = '/home/aistudio/work/trainImages'
eval_image_dir = '/home/aistudio/work/evalImages'
test_image_dir = '/home/aistudio/work/evalImages'
transform_train = Compose([Normalize(mean=[127.5, 127.5, 127.5],std=[127.5, 127.5, 127.5],data_format='HWC'), Transpose()])
transform_eval = Compose([Normalize(mean=[127.5, 127.5, 127.5],std=[127.5, 127.5, 127.5],data_format='HWC'), Transpose()])
train_data_folder = DatasetFolder(train_image_dir, transform=transform_train)
eval_data_folder = DatasetFolder(eval_image_dir, transform=transform_eval)
test_data_folder = DatasetFolder(test_image_dir)
self.mode = mode
if self.mode == 'train':
self.data =