{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Task02 - 数据读取与数据扩增"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torch.nn.functional as F\n",
"import torch.utils.data as ud\n",
"from torchvision import transforms\n",
"from torchvision import datasets\n",
"from PIL import Image\n",
"import cv2\n",
"import glob\n",
"import json"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 模型数据准备与数据扩增"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Dataset definition: reads the images and pads each label to a fixed length.\n",
"\n",
"class SVHDataset(ud.Dataset):\n",
"    \"\"\"Map-style dataset yielding (image, label) pairs.\n",
"\n",
"    Args:\n",
"        img_pattern: glob pattern matching the image files.\n",
"        label_folder: path to a JSON file (not a folder) whose values are\n",
"            dicts containing a 'label' list of digits; keys are presumably\n",
"            the image file names -- TODO confirm against the data.\n",
"        transform: optional callable applied to the RGB PIL image.\n",
"    \"\"\"\n",
"    # Every label is truncated/padded to MAX_LEN positions; BLANK (10) marks \"no digit\".\n",
"    MAX_LEN = 5\n",
"    BLANK = 10\n",
"\n",
"    def __init__(self, img_pattern, label_folder, transform=None):\n",
"        self.img_path = sorted(glob.glob(img_pattern))\n",
"        # Use a context manager so the label file handle is closed.\n",
"        with open(label_folder) as f:\n",
"            labels = json.load(f)\n",
"        # Order labels by key so they line up with the sorted image paths\n",
"        # instead of depending on the key order inside the JSON file.\n",
"        self.img_label = [v['label'] for _, v in sorted(labels.items())]\n",
"        self.transform = transform\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return (image, label) for sample `index`.\n",
"\n",
"        The label is a float tensor of length MAX_LEN: the digit sequence\n",
"        truncated to MAX_LEN entries and right-padded with BLANK,\n",
"        e.g. [2, 3] -> [2, 3, 10, 10, 10].\n",
"        \"\"\"\n",
"        with Image.open(self.img_path[index]) as raw:\n",
"            img = raw.convert('RGB')\n",
"        if self.transform is not None:\n",
"            img = self.transform(img)\n",
"        # Cast with the builtin int (np.int was removed in NumPy 1.24).\n",
"        lbl = [int(d) for d in self.img_label[index]]\n",
"        lbl = lbl[:self.MAX_LEN] + [self.BLANK] * (self.MAX_LEN - len(lbl))\n",
"        return img, torch.Tensor(lbl)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.img_path)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Training pipeline applies random augmentation; validation only resizes.\n",
"# Both must produce the same spatial size so train and val inputs match.\n",
"IMG_SIZE = (60, 120)  # (height, width) fed to the model\n",
"# Standard ImageNet per-channel mean / std.\n",
"NORM_MEAN = [0.485, 0.456, 0.406]\n",
"NORM_STD = [0.229, 0.224, 0.225]\n",
"\n",
"data_transforms = {\n",
"    'train': transforms.Compose([\n",
"        # Resize slightly larger than IMG_SIZE, then randomly crop back down to it\n",
"        transforms.Resize((64, 128)),\n",
"        transforms.RandomCrop(IMG_SIZE),\n",
"        # Random brightness / contrast / saturation jitter\n",
"        transforms.ColorJitter(0.3, 0.3, 0.2),\n",
"        # Random rotation in [-10, 10] degrees\n",
"        transforms.RandomRotation(10),\n",
"        # PIL image -> float tensor in [0, 1]\n",
"        transforms.ToTensor(),\n",
"        # Per-channel normalization\n",
"        transforms.Normalize(NORM_MEAN, NORM_STD)\n",
"    ]),\n",
"    'val': transforms.Compose([\n",
"        # Resize directly to IMG_SIZE so validation tensors match the training crop size\n",
"        transforms.Resize(IMG_SIZE),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize(NORM_MEAN, NORM_STD)\n",
"    ]),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAT0AAACPCAYAAACI7gxXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO19aZAlV3Xm972tXq29aHNro8GWZWmwFmCEMNgWCNmyxiD/sD1oxljMyGiYwAHM4DACxh4wXjA2WCZmvChYxAhCGCSMZA0M1sj0ODAOoCVLWEJoYZBQo1Z3a+nu6lrfcvwjsyrPPfny1qvXVe+VneeLqKjMvJn3nrx58757vjwLRQQOh8NRFlRGLYDD4XAMEz7pORyOUsEnPYfDUSr4pOdwOEoFn/QcDkep4JOew+EoFXzSi4DkHpK/stHXknwXyY8cn3RbCyTPJblX7QvJOZK/M0q5HBsLkq8meYxkl+Sr02MfIvmmUcvWL0ox6ZF8bOUBbQWIyO+KyLon0612HwbvA/CH5tj5IvLuXieT3J1OjMfU329shCAkX0jySySfJrmhhqgk30OyZeR+wYB1/SLJr5KcJ7mnR/kFJO9Oy+8mecGA7TRI3pKOHyF5iSknyd8n+Uz69wGS7FWXiPxfEZkC8D11+A8AvJtkYxD5ho1STHqOzQPJGsldAF4J4PMDVLFdRKbSv/dtkFgtAJ8BcM0G1WfxF0rmKRH5/wPW8yyA6wG83xakE8htAD4JYAeATwC47Tgmlq8A+CUAT/UouxbAzwE4H8B5AH4WwH/qt2IR2Q/g2wBeO6BsQ0WpJz2SO0jeQfIQyefS7dPNaT9I8uskj5C8jeROdf3F6S/1YZL32V/QSLvvIfnJdLtJ8pPpL+xhkt8geUqPa24CcCaAv0pXF7++lgypiv0+kn9HcpbkX5M8ca12SZ5K8naSz5J8lOQbjey3pNceBfAGAJcBuEdEFvu5/82GiDwkIh8F8MCoZYkhXTV9BsCTPYovAVADcL2ILInIhwEQwKsGaGdZRK4Xka8A6PQ45WoAHxSRfSLyfQAfRPJc14M9AP7NemUbBUo96SG5/48DeB6SCWUBwP8w5/wygP8I4FQAbQAfBgCSpwH43wB+G8BOAL8G4FaSJ61ThqsBbANwBoATALwplSOAiLweiUrxmnR18YE+Zfh3AP4DgJMBNNJz1mr3ZgD70nv+eQC/S/JSVeeVAG4BsB3ApwD8KICH1nnfK3ic5D6SH1+ZkP8Z4DXpD8IDJP/zJrXxrwB8U0I/0W+mxzejrfvU/n0DtPMgkpXilkepJz0ReUZEbhWReRGZBfA7AH7SnHaTiNwvInMAfgPAL5KsIlEVviAiXxCRrojcCWAvgCvWKUYLyaTzQyLSEZG7ReRon9f2I8PHReRhEVlAovKt8EI92yV5BoBXAHiHiCyKyL0APgLg9arOvxeRz6dtLiCZ/GbXed9PA/jXSH5wXgxgGskEutXxGQDnADgJwBsB/CbJqzahnSkAR8yxI0j6abPbOgJgqojXK8AsknGw5VHqSY/kBMk/J/l4qqr9LYDt6aS2gifU9uMA6gBORPKy/kKqGh4meRjJZLFrnWLcBOBLAD5N8smURK73eW0/MmgOZx7JAI+1eyqAZ9MfgRU8DuA0ta/7BACewzpfRhE5JiJ7RaQtIgcA/CqAnyI5s556hg0R+ZaIPJn+UHwVwB8jWQ1vNI4BsH0xg/X/uAzS1gyAY2aVuRamARzeUKk2CaWe9AC8HcDZAF4qIjMAfiI9rn/hzlDbZyJZIT2N5MW/SUS2q79JEcmR0jGISEtE3isi5wL4MSQk8i8XnW72B5Yh0u6TAHaS1JPYmQC+H5HjmwB+eK021xIp/b+e1cVWgGBzZH4AwHlmtXUeNoen