Pokemon VS Digimon

原文章地址 教计算机对动漫进行分类 |by 泰瑞·史蒂文森 |中等 (medium.com),本文对此文章进行复刻,完成机器学习经典任务宝可梦和数码宝贝的识别

宝可梦数据集地址  Pokemon Images Dataset (kaggle.com)

数码宝贝数据集地址

Digimon-Generator-GAN/Digimon.zip at master · DeathReaper0965/Digimon-Generator-GAN (github.com)

代码部分

导入库

import os, os.path
import shutil
from fastai.vision.all import *
from fastai.text.all import *
from fastai.tabular.all import *
from fastai.collab import *

清点数据集数量

def getAmountOfFilesInDirectory(directory):
    return len([name for name in os.listdir(directory) if os.path.isfile(os.path.join(directory, name))])

print('amount of Pokemon images', getAmountOfFilesInDirectory('./Pokemon VS Digimon/train/pokemon'))
print('amount of Digimon images', getAmountOfFilesInDirectory('./Pokemon VS Digimon/train/Digimon'))

移动训练集百分之30到测试集

def moveFilesToTestDirectory(directory, test_base_dir):
    """This function will move 30% of the training files to the testing directory"""
    amount_of_files = getAmountOfFilesInDirectory(directory)
    amount_of_files_to_move = int(amount_of_files * 0.3)  # 计算30%的文件数
    files_moved = 0  # 计数器,用于跟踪已移动的文件数
    class_name = os.path.basename(directory)  # 获取类别名称(如 Digimon 或 Pokemon)
    
    test_directory = os.path.join(test_base_dir, class_name)
    os.makedirs(test_directory, exist_ok=True)  # 创建测试目录(如果不存在)
    
    for file_name in os.listdir(directory):
        if files_moved >= amount_of_files_to_move:
            break  # 一旦移动的文件数达到30%,停止移动
        file_path = os.path.join(directory, file_name)
        if os.path.isfile(file_path):
            shutil.move(file_path, os.path.join(test_directory, file_name))  # 移动文件
            files_moved += 1

moveFilesToTestDirectory('./Pokemon VS Digimon/train/pokemon', './Pokemon VS Digimon/test')
moveFilesToTestDirectory('./Pokemon VS Digimon/train/Digimon', './Pokemon VS Digimon/test')

print('Amount of Pokemon training images:', getAmountOfFilesInDirectory('./Pokemon VS Digimon/train/pokemon'))
print('Amount of Digimon training images:', getAmountOfFilesInDirectory('./Pokemon VS Digimon/train/Digimon'))
print('Amount of Pokemon test images:', getAmountOfFilesInDirectory('./Pokemon VS Digimon/test/pokemon'))
print('Amount of Digimon test images:', getAmountOfFilesInDirectory('./Pokemon VS Digimon/test/Digimon'))

训练模型 

PATH = "./Pokemon VS Digimon" # Path to dataset
arch = resnet34
sz = 224  # 你希望的图像尺寸

# 数据块设置
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),  # 定义数据块类型
    get_items=get_image_files,  # 获取图像文件
    splitter=GrandparentSplitter(train_name='train', valid_name='test'),  # 划分训练集和验证集
    get_y=parent_label,  # 获取标签
    item_tfms=Resize(sz),  # 图像大小调整
    batch_tfms=aug_transforms(size=sz, max_zoom=1.1)  # 批量数据增强
)

# 创建 DataLoaders
path = Path(PATH)
dls = dblock.dataloaders(path)

# 创建并训练模型
learn = cnn_learner(dls, arch, metrics=[error_rate,accuracy])
learn.fine_tune(5)

保存模型并测试原有test数据集图片

learn.export('model.pkl')
from fastai.vision.all import *

# 加载保存的模型
learn = load_learner('model.pkl')

# 读取测试图像
test_image_path = './Pokemon VS Digimon/test/pokemon/11.jpg'
test_image = PILImage.create(test_image_path)

# 进行预测
pred_class, pred_idx, outputs = learn.predict(test_image)

# 显示预测结果
print(f'预测类别: {pred_class}')
print(f'预测输出概率向量: {outputs}')

# 可选:显示预测图像
test_image.show()

测试网络图片新增宝可梦,将webp格式转换成白色背景jpg

神奇宝贝百科,关于宝可梦的百科全书 (52poke.com)  宝可梦资源网

import cv2
import numpy as np

def change_background_to_white(image_path, output_path):
    # 读取图像
    image = cv2.imread(image_path)
    
    # 转换为灰度图像
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    
    # 使用阈值分割图像
    _, mask = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY_INV)
    
    # 查找轮廓
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # 创建白色背景
    white_background = np.ones_like(image) * 255
    
    # 遍历轮廓并填充
    for contour in contours:
        cv2.drawContours(white_background, [contour], -1, (0, 0, 0), -1)
    
    # 将前景与白色背景合并
    result = cv2.bitwise_or(image, white_background, mask=mask)
    
    # 保存结果图像
    cv2.imwrite(output_path, result)

# 输入和输出路径
input_jpg_path = './1.jpg'
output_jpg_path = './2.jpg'

# 更改背景为白色
change_background_to_white(input_jpg_path, output_jpg_path)

# 显示原始和修改后的图像
original_image = cv2.imread(input_jpg_path)
modified_image = cv2.imread(output_jpg_path)

cv2.imshow("Original Image", original_image)
cv2.imshow("Image with White Background", modified_image)
cv2.waitKey(0)
cv2.destroyAllWindows()


from fastai.vision.all import *

# 加载保存的模型
learn = load_learner('model.pkl')

# 定义图像转换,包括调整大小
transform = Resize(224)

# 读取和预处理测试图像
test_image_path = './3.jpg'
test_image = PILImage.create(test_image_path)

# 应用变换以调整图像大小
test_image = transform(test_image)

# 进行预测
pred_class, pred_idx, outputs = learn.predict(test_image)

# 显示预测结果
print(f'预测类别: {pred_class}')
print(f'预测输出概率向量: {outputs}')

# 可选:显示预测图像
test_image.show()

原图webp格式 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值