原文章地址:教计算机对动漫进行分类 | by 泰瑞·史蒂文森 | Medium (medium.com)。本文复现了该文章的实验,完成宝可梦与数码宝贝图像分类这一经典机器学习任务。
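宝可梦数据集地址:Pokemon Images Dataset (kaggle.com)
数码宝贝数据集地址:
Digimon-Generator-GAN/Digimon.zip at master · DeathReaper0965/Digimon-Generator-GAN (github.com)
下载后,将两个数据集的图片分别放入 ./Pokemon VS Digimon/train/pokemon 和 ./Pokemon VS Digimon/train/Digimon 目录(即下文代码中使用的路径)。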
代码部分
导入库
import os, os.path
import shutil
from fastai.vision.all import *
from fastai.text.all import *
from fastai.tabular.all import *
from fastai.collab import *
统计数据集中各类别的图片数量
def getAmountOfFilesInDirectory(directory):
    """Return how many regular files sit directly inside ``directory``.

    The scan is not recursive: sub-directories (and anything inside them)
    are not counted. ``os.path.isfile`` follows symlinks, so a link to a
    file counts as a file.
    """
    candidate_paths = (os.path.join(directory, entry) for entry in os.listdir(directory))
    return sum(1 for candidate in candidate_paths if os.path.isfile(candidate))
# Report how many images each class folder holds before the train/test split.
for _caption, _class_folder in (
    ('amount of Pokemon images', './Pokemon VS Digimon/train/pokemon'),
    ('amount of Digimon images', './Pokemon VS Digimon/train/Digimon'),
):
    print(_caption, getAmountOfFilesInDirectory(_class_folder))
将训练集中 30% 的图片移动到测试集
def moveFilesToTestDirectory(directory, test_base_dir, ratio=0.3, seed=42):
    """Move a fraction (30% by default) of the files in ``directory`` into the test set.

    The destination is ``test_base_dir/<class name>``, where the class name is
    the last path component of ``directory`` (e.g. ``pokemon`` or ``Digimon``).
    This mirrors the ``train/<class>`` layout under ``test/``.

    Files are picked with a seeded random sample, not in ``os.listdir`` order.
    listdir order depends on the filesystem and is often sorted by name, which
    would send a biased subset (e.g. the lowest-numbered images) to the test set.

    Not idempotent: each call moves another ``ratio`` of whatever files are
    still left in ``directory``.

    Args:
        directory: class folder inside the training set.
        test_base_dir: root folder of the test set.
        ratio: fraction of the files to move, in [0, 1].
        seed: seed for the random sample, so the split is reproducible.

    Returns:
        The number of files moved.

    Raises:
        ValueError: if ``ratio`` is outside [0, 1].
    """
    import random

    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be in [0, 1], got {ratio!r}")
    # Sort first so the same seed selects the same files on every filesystem.
    file_names = sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )
    amount_of_files_to_move = int(len(file_names) * ratio)
    # normpath strips a trailing separator; otherwise basename() would return ''
    # and the files would be dumped straight into test_base_dir.
    class_name = os.path.basename(os.path.normpath(directory))
    test_directory = os.path.join(test_base_dir, class_name)
    os.makedirs(test_directory, exist_ok=True)
    selected = random.Random(seed).sample(file_names, amount_of_files_to_move)
    for file_name in selected:
        shutil.move(os.path.join(directory, file_name), os.path.join(test_directory, file_name))
    return len(selected)
# Split only once. moveFilesToTestDirectory is not idempotent, so re-running this
# cell would move another 30% of the *remaining* training images each time.
# Skip any class whose test folder already contains entries.
for _train_class_dir in ('./Pokemon VS Digimon/train/pokemon', './Pokemon VS Digimon/train/Digimon'):
    _test_class_dir = os.path.join('./Pokemon VS Digimon/test', os.path.basename(_train_class_dir))
    if os.path.isdir(_test_class_dir) and os.listdir(_test_class_dir):
        print('Test split already exists, skipping:', _test_class_dir)
        continue
    moveFilesToTestDirectory(_train_class_dir, './Pokemon VS Digimon/test')
# Report the resulting train/test sizes per class.
print('Amount of Pokemon training images:', getAmountOfFilesInDirectory('./Pokemon VS Digimon/train/pokemon'))
print('Amount of Digimon training images:', getAmountOfFilesInDirectory('./Pokemon VS Digimon/train/Digimon'))
print('Amount of Pokemon test images:', getAmountOfFilesInDirectory('./Pokemon VS Digimon/test/pokemon'))
print('Amount of Digimon test images:', getAmountOfFilesInDirectory('./Pokemon VS Digimon/test/Digimon'))
训练模型
# Dataset root; expected to contain train/<class>/ and test/<class>/ sub-folders
# (the layout created by the split step above).
PATH = "./Pokemon VS Digimon" # Path to dataset
# Backbone architecture handed to cnn_learner below.
arch = resnet34
sz = 224 # target image size in pixels (used by Resize and aug_transforms)
# DataBlock: describes how to turn the files under PATH into (image, label) pairs.
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock), # inputs are images, targets are single categories
    get_items=get_image_files, # collect the image files under the source path
    splitter=GrandparentSplitter(train_name='train', valid_name='test'), # grandparent folder decides the split: train/ -> training set, test/ -> validation set
    get_y=parent_label, # label = name of the parent folder (e.g. pokemon / Digimon)
    item_tfms=Resize(sz), # per-item resize to sz x sz
    batch_tfms=aug_transforms(size=sz, max_zoom=1.1) # batch-level data augmentation; max_zoom=1.1 caps the random zoom-in
)
# Build the training/validation DataLoaders from the dataset folder.
path = Path(PATH)
dls = dblock.dataloaders(path)
# Create a resnet34-based classifier that reports error rate and accuracy, then fine-tune it.
# NOTE(review): fastai 2.7+ renamed cnn_learner to vision_learner and warns on the old
# name; consider switching if the installed fastai version supports it.
learn = cnn_learner(dls, arch, metrics=[error_rate,accuracy])
# Per fastai's fine_tune: first train with the pretrained body frozen (1 epoch by
# default), then unfreeze and train all layers for 5 epochs.
learn.fine_tune(5)
保存模型,并用原有 test 数据集中的图片进行测试
# Export the trained learner (model plus its data-processing pipeline) for later inference.
learn.export('model.pkl')
# Re-import, presumably so the inference part below can also run in a fresh session.
from fastai.vision.all import *
# Reload the exported learner from disk.
# NOTE(review): fastai writes the export relative to learn.path; this assumes that
# resolves to the current working directory, so load_learner finds 'model.pkl'. Confirm.
learn = load_learner('model.pkl')
# Pick an image from the held-out test folder.
# NOTE(review): 11.jpg is only under test/ if the split step moved it there; confirm it exists.
test_image_path = './Pokemon VS Digimon/test/pokemon/11.jpg'
test_image = PILImage.create(test_image_path)
# Run inference; learn.predict returns (decoded label, class index, class probabilities).
pred_class, pred_idx, outputs = learn.predict(test_image)
# Print the prediction.
print(f'预测类别: {pred_class}')
print(f'预测输出概率向量: {outputs}')
# Optional: display the test image.
test_image.show()
使用从网络上获取的新宝可梦图片进行测试:先将 webp 格式的图片转换为白色背景的 jpg 图片
图片来源:神奇宝贝百科——关于宝可梦的百科全书 (52poke.com),宝可梦资源网
import cv2
import numpy as np
def change_background_to_white(image_path, output_path, threshold=250):
    """Replace everything outside the foreground object(s) with pure white.

    Pixels whose grayscale value is <= ``threshold`` are treated as foreground
    candidates. The outer contours of those regions are filled to build an
    object mask. Pixels inside the filled contours keep their original colour,
    every other pixel becomes white (255, 255, 255), and the result is written
    to ``output_path``.

    ``cv2.imread`` is called with its default flag, which loads a 3-channel BGR
    image and discards any alpha channel. The colour that transparent pixels of
    a webp/png take on depends on the file, so flatten such inputs first if
    needed.

    Args:
        image_path: path of the image to read.
        output_path: path to write; the output format follows its extension (e.g. ``.jpg``).
        threshold: grayscale cut-off (0-255); brighter pixels count as background.

    Raises:
        FileNotFoundError: if ``image_path`` cannot be read as an image.
        OSError: if ``cv2.imwrite`` reports that writing ``output_path`` failed.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # THRESH_BINARY_INV: 255 where gray <= threshold (non-background), 0 elsewhere.
    _, mask = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY_INV)
    # Only outermost contours, so holes inside an object are filled as part of it.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # White canvas with each filled object outline painted black.
    white_background = np.ones_like(image) * 255
    for contour in contours:
        cv2.drawContours(white_background, [contour], -1, (0, 0, 0), -1)
    # OR with 0 keeps the original pixel (object); OR with 255 yields white (background).
    # No mask argument: with a mask, OpenCV zero-fills the unmasked pixels of the newly
    # allocated output, which turned the near-white background black.
    result = cv2.bitwise_or(image, white_background)
    if not cv2.imwrite(output_path, result):
        raise OSError(f"could not write image: {output_path}")
# Source image to clean up, and where to save the white-background copy.
input_jpg_path = './1.jpg'
output_jpg_path = './2.jpg'

change_background_to_white(input_jpg_path, output_jpg_path)

# Load both files back from disk and show them for a visual check.
# waitKey(0) blocks until a key is pressed; then all windows are closed.
original_image, modified_image = cv2.imread(input_jpg_path), cv2.imread(output_jpg_path)
for window_title, frame in (("Original Image", original_image),
                            ("Image with White Background", modified_image)):
    cv2.imshow(window_title, frame)
cv2.waitKey(0)
cv2.destroyAllWindows()
from fastai.vision.all import *
# Reload the exported learner so this part can run on its own.
learn = load_learner('model.pkl')
# Same 224-pixel resize as the training pipeline's item transform (sz = 224 there).
transform = Resize(224)
# NOTE(review): the background-conversion step above writes ./2.jpg; confirm that
# ./3.jpg is the intended input here.
test_image_path = './3.jpg'
test_image = PILImage.create(test_image_path)
# Pass split_idx=1 (validation mode). Without it, fastai's Resize picks a random crop
# position, as it does for training data, so the crop and the prediction could vary
# between runs. In validation mode it centre-crops, matching the transforms that
# learn.predict applies itself.
test_image = transform(test_image, split_idx=1)
# Run inference; learn.predict returns (decoded label, class index, class probabilities).
pred_class, pred_idx, outputs = learn.predict(test_image)
# Print the prediction.
print(f'预测类别: {pred_class}')
print(f'预测输出概率向量: {outputs}')
# Optional: show the exact 224x224 crop the model classified.
test_image.show()
(原图为 webp 格式)