一、任务背景
本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿”,共15种汉字数字。
二、python建模
1、数据读取
首先,读取jpg数据文件,可以看到总共有15000张图像数据。
import pandas as pd
import os
path = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))
我们也可以打印一张图片出来看看。
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# 定义图片路径
image_path = path+files[3]
# 加载图片
image = mpimg.imread(image_path)
# 绘制图片
plt.figure(figsize=(3, 3))
plt.imshow(image)
plt.axis('off') # 关闭坐标轴
plt.show()
2、数据集构建
加载必要的库以便后续使用,再定义一些超参数。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score
# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
这里,我们看一看数据集介绍就会知道图片名称及其含义,需要从chinese_mnist.csv文件中根据图片名称中的几个数字来确定图片对应的标签。
# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]
# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)
# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label