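Project Description
The task in this assignment is Domain Adversarial Training, a transfer learning technique.
This is the block in the lower-left corner of the figure.
Domain adaptation lets a model improve its accuracy on dataset B while training with labels from dataset A only, with no labels from dataset B (dataset A and its task are assumed to be close to dataset B and its task). In other words, given real photos with labels plus a large number of unlabeled hand-drawn images, design a method that lets the model predict the labels of the hand-drawn images.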
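Dataset
In this task the source data is real photos and the target data is hand-drawn doodles.
The model sees the real photos and their labels, then has to predict the labels of the hand-drawn doodles.
The data is located at 'data/data58171/real_or_drawing.zip'.
- Training: 5000 real photos with labels, 32 x 32 RGB
- Testing: 100000 hand-drawn images, 28 x 28 grayscale
- Label: 10 classes to predict in total.
- The downloaded data uses 0 to 9 as labels.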
Project Requirements
- Do not label the data by hand or look for labels online.
- Do not use a pre-trained model.
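1. Workflow
1.1 First, read the data and generate the txt files (2:8 validation/training split) used by the following steps.
1.2 Next, run EDA to inspect the label distribution and decide whether a shuffle is needed (it is not).
1.3 Then, apply data augmentation to address the feature gap between test samples and training samples.
1.4 The main tool is OpenCV's Canny edge detector, followed by normalization, dimension conversion, and restoration.
1.5 Define a few simple random augmentations (rotation, flipping, etc.) to improve robustness during training.
1.6 Build the datasets and train the model; save the trained weights for transfer learning on the test set.
1.7 Compare four models during training (100 epochs on the training set, 20 epochs on the test set after transfer).
1.8 Because early stopping does not work reliably in the notebook, this experiment uses a fixed number of epochs.
1.9 Load the trained model files for training, testing, and evaluation on the test set; evaluation uses the validation split of the test set (1000 images).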
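Step 1.5 mentions simple random augmentations such as rotation and flipping. The sketch below shows one way to express that idea with NumPy only. The helper name `random_augment` is hypothetical, and the rotations are restricted to multiples of 90 degrees, which is an assumption made to keep the example short; it is not necessarily what the training code in this post does.

```python
import numpy as np

def random_augment(img, rng):
    """Randomly flip and rotate an H x W x C image (hypothetical helper)."""
    out = img
    if rng.random() < 0.5:
        out = out[:, ::-1, :]              # horizontal flip
    k = int(rng.integers(0, 4))            # number of 90-degree turns (0 to 3)
    out = np.rot90(out, k=k, axes=(0, 1))  # rotate in the image plane only
    return np.ascontiguousarray(out)

rng = np.random.default_rng(0)
sample = rng.integers(0, 256, size=(28, 28, 3), dtype=np.uint8)
augmented = random_augment(sample, rng)
print(augmented.shape, augmented.dtype)
```

Flips and 90-degree rotations only rearrange pixels, so the augmented image keeps the same shape, dtype, and pixel values as the input.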
import os
import random

import cv2
import numpy as np
import paddle
import paddle.nn.functional as F
import paddle.vision.transforms as T
import pandas as pd
from matplotlib import pyplot as plt
from PIL import Image
!unzip -d work data/data75815/real_or_drawing.zip  # Unzip the real_or_drawing dataset
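Reading the Data
1. Extract the filename and label of every image in the training and test sets and write them to txt files.
2. Hold out 20% of the training set as a validation set for model training.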
# Set the data path
data_path = './work/real_or_drawing/train_data'
character_folders = os.listdir(data_path)
# Build the training and validation lists
with open('./train_train.txt', 'w') as f_train:
    with open('./train_val.txt', 'w') as f_test:
        img_list = []
        for character_folder in character_folders:  # loop over the class folders
            character_imgs = os.listdir(os.path.join(data_path, character_folder))  # list the files in the folder
            # Reset the per-folder counter
            CNT = 0
            for img in character_imgs:  # loop over the images
                img_list.append(os.path.join(data_path, character_folder, img) + '\t' + character_folder + '\n')  # path and label
                CNT += 1
            print(character_folder, CNT)  # number of images in each folder
        random.shuffle(img_list)  # shuffle the list
        # Hold out 20% of the training set as the validation set
        # Reset the counter
        CNT = 0
        for img in img_list:  # loop over the shuffled list
            if CNT < int(len(img_list) * 0.8):  # the first 80% goes to the training list
                f_train.write(img)
                CNT += 1
            else:  # the rest goes to the validation list
                f_test.write(img)
                CNT += 1
        print(len(img_list), int(len(img_list) * 0.8))  # total count and training count
1 500
9 500
7 500
2 500
6 500
5 500
4 500
3 500
0 500
8 500
5000 4000
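The 80/20 split above can also be written as a small reusable function. This is a minimal sketch using only the standard library; the name `split_train_val`, the `seed` parameter, and the synthetic `lines` list are assumptions added for illustration and do not appear in the original notebook.

```python
import random

def split_train_val(items, train_ratio=0.8, seed=None):
    """Shuffle a copy of `items` and split it into (train, val) lists."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)   # a fixed seed makes the split reproducible
    cut = int(len(shuffled) * train_ratio)  # same cut-off rule as the loop above
    return shuffled[:cut], shuffled[cut:]

# Synthetic "path<TAB>label" lines standing in for the real file list
lines = [f"img_{i}.bmp\t{i % 10}\n" for i in range(5000)]
train_lines, val_lines = split_train_val(lines, seed=0)
print(len(train_lines), len(val_lines))
```

Passing a seed keeps the split identical across notebook restarts, which makes validation scores comparable between runs.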
# Build test_train.txt and test_val.txt
# (as written, this cell still reads from data_path, i.e. train_data)
with open('./test_train.txt', 'w') as f_train:
    with open('./test_val.txt', 'w') as f_test:
        img_list = []
        for character_folder in character_folders:  # loop over the class folders
            character_imgs = os.listdir(os.path.join(data_path, character_folder))  # list the files in the folder
            # Reset the per-folder counter
            CNT = 0
            for img in character_imgs:  # loop over the images
                img_list.append(os.path.join(data_path, character_folder, img) + '\t' + character_folder + '\n')  # path and label
                CNT += 1
            print(character_folder, CNT)  # number of images in each folder
        random.shuffle(img_list)  # shuffle the list
        # Hold out 20% as the validation set
        # Reset the counter
        CNT = 0
        for img in img_list:  # loop over the shuffled list
            if CNT < int(len(img_list) * 0.8):  # the first 80% goes to the training list
                f_train.write(img)
                CNT += 1
            else:  # the rest goes to the validation list
                f_test.write(img)
                CNT += 1
        print(len(img_list), int(len(img_list) * 0.8))  # total count and training count
1 500
9 500
7 500
2 500
6 500
5 500
4 500
3 500
0 500
8 500
5000 4000
# # Extract the test set
# data_path1 = './work/real_or_drawing/test_data'
# character_folders = os.listdir(data_path1)
# with open('./test.txt', 'w') as f_train:
#     img_list = []
#     for character_folder in character_folders:  # loop over the folders
#         character_imgs = os.listdir(os.path.join(data_path1, character_folder))  # list the files in the folder
#         # Reset the per-folder counter
#         CNT = 0
#         for img in character_imgs:  # loop over the images
#             img_list.append(os.path.join(data_path1, character_folder, img) + '\t' + character_folder + '\n')  # path and label
#             CNT += 1
#         print(character_folder, CNT)  # number of images in each folder
#     random.shuffle(img_list)  # shuffle the list
#     for img in img_list:
#         f_train.write(img)
#         CNT += 1
#     print(len(img_list))
Data EDA (Exploratory Data Analysis)
(The data was already shuffled in the previous step.)
tf = "train_train.txt"
# Read the txt file
df = pd.read_table(tf, sep='\t', names=['name', 'label'])
print(df)
# Plot a histogram of the label column
d = df['label'].hist().get_figure()
d.savefig("EDA.png")
train_image_path_list = df['name'].values
label_list = df['label'].values
label_list = paddle.to_tensor(label_list, dtype='int64')
train_label_list = paddle.nn.functional.one_hot(label_list, num_classes=10)
name label
0 ./work/real_or_drawing/train_data/3/1994.bmp 3
1 ./work/real_or_drawing/train_data/3/1799.bmp 3
2 ./work/real_or_drawing/train_data/8/4406.bmp 8
3 ./work/real_or_drawing/train_data/9/4970.bmp 9
4 ./work/real_or_drawing/train_data/4/2090.bmp 4
... ... ...
3995 ./work/real_or_drawing/train_data/5/2677.bmp 5
3996 ./work/real_or_drawing/train_data/8/4048.bmp 8
3997 ./work/real_or_drawing/train_data/8/4484.bmp 8
3998 ./work/real_or_drawing/train_data/2/1201.bmp 2
3999 ./work/real_or_drawing/train_data/5/2912.bmp 5
[4000 rows x 2 columns]
[Figure: histogram of the training labels (output_11_2.png)]
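The class balance that the histogram shows can also be checked numerically. The following is a small sketch using `collections.Counter` on the "path<TAB>label" lines; the helper names `label_counts` and `is_balanced`, the 10% tolerance, and the `demo` lines are all assumptions for illustration.

```python
from collections import Counter

def label_counts(lines):
    """Count labels in 'path<TAB>label' lines, skipping blank lines."""
    return Counter(line.rstrip("\n").split("\t")[1] for line in lines if line.strip())

def is_balanced(counts, tolerance=0.1):
    """True if every class count is within `tolerance` of the mean count."""
    mean = sum(counts.values()) / len(counts)
    return all(abs(c - mean) <= tolerance * mean for c in counts.values())

demo = ["a.bmp\t3\n", "b.bmp\t3\n", "c.bmp\t8\n", "d.bmp\t8\n"]
counts = label_counts(demo)
print(counts, is_balanced(counts))
```

On the real data, the lines would come from reading train_train.txt, for example with `open('train_train.txt').readlines()`.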
# Process the validation set
# Read the txt file
vsf = pd.read_table('./train_val.txt', sep='\t', names=['name', 'label'])
print(vsf)
val_image_list = vsf
val_image_path_list = val_image_list['name'].values
val_label_list = val_image_list['label'].values
val_label_list = paddle.to_tensor(val_label_list, dtype='int64')
val_label_list = paddle.nn.functional.one_hot(val_label_list, num_classes=10)  # 10 classes (labels 0-9), matching the training labels
name label
0 ./work/real_or_drawing/train_data/2/1351.bmp 2
1 ./work/real_or_drawing/train_data/1/714.bmp 1
2 ./work/real_or_drawing/train_data/9/4549.bmp 9
3 ./work/real_or_drawing/train_data/9/4588.bmp 9
4 ./work/real_or_drawing/train_data/4/2105.bmp 4
.. ... ...
995 ./work/real_or_drawing/train_data/4/2016.bmp 4
996 ./work/real_or_drawing/train_data/8/4231.bmp 8
997 ./work/real_or_drawing/train_data/2/1486.bmp 2
998 ./work/real_or_drawing/train_data/3/1685.bmp 3
999 ./work/real_or_drawing/train_data/2/1247.bmp 2
[1000 rows x 2 columns]
print(train_label_list[1])
print(train_label_list[2])
print(train_label_list[3])
Tensor(shape=[10], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
Tensor(shape=[10], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.])
Tensor(shape=[10], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])
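To make the printed tensors easier to read, here is a NumPy sketch of what one-hot encoding does. The function `one_hot` below is a hypothetical stand-in written for illustration, not the Paddle implementation. With labels 0 to 9 the encoding needs exactly 10 classes, and the training and validation labels should use the same `num_classes` so their vectors have the same length.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Return a (len(labels), num_classes) float32 one-hot matrix."""
    labels = np.asarray(labels, dtype=np.int64)
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("label outside [0, num_classes)")
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(num_classes, dtype=np.float32)[labels]

encoded = one_hot([3, 8, 9], num_classes=10)
print(encoded)
```

The three rows correspond to labels 3, 8, and 9, the same labels as the three tensors printed above.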
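Data Shuffle
The EDA histogram, together with the shuffle applied when the txt files were written, shows that this step can be skipped.
EDA result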
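Data Augmentation
As the figure below shows, the test images are made up of line strokes and differ considerably from the training images, so the training set needs some preprocessing to help the model learn features that carry over.
The plan is to build the preprocessing on threshold segmentation plus an edge extraction algorithm.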
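To illustrate the "threshold plus edge extraction" idea in isolation, the sketch below marks a pixel as an edge when its gradient magnitude exceeds a threshold. This is a deliberately simplified stand-in and not Canny, which adds smoothing, non-maximum suppression, and hysteresis thresholding. The function name `gradient_edges` and the threshold value are assumptions.

```python
import numpy as np

def gradient_edges(gray, threshold=50.0):
    """Binary edge map: 255 where the gradient magnitude exceeds `threshold`."""
    gray = np.asarray(gray, dtype=np.float32)
    gy, gx = np.gradient(gray)       # derivatives along rows and columns
    magnitude = np.hypot(gx, gy)     # per-pixel gradient strength
    return np.where(magnitude > threshold, 255, 0).astype(np.uint8)

# A bright square on a dark background: only its outline should light up
demo = np.zeros((32, 32), dtype=np.uint8)
demo[8:24, 8:24] = 200
edges = gradient_edges(demo)
print(edges.shape, np.unique(edges))
```

The result is a line drawing of the square, which is the kind of image the hand-drawn test set contains.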
# Check a sample image
ts_img = cv2.imread('work/real_or_drawing/train_data/3/1518.bmp')
ts_img1 = cv2.cvtColor(ts_img, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
plt.imshow(ts_img1)
print(ts_img1.shape)
ts_img3 = ts_img1.transpose((2, 0, 1))  # HWC -> CHW
print(ts_img3.shape)
(32, 32, 3)
(3, 32, 32)
[Figure: sample training image (output_16_1.png)]
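The workflow also mentions normalization, dimension conversion, and restoration. A minimal NumPy sketch of that round trip follows. The helper names `to_chw_float` and `to_hwc_uint8` and the choice of scaling to [0, 1] are assumptions; the original notebook may normalize differently.

```python
import numpy as np

def to_chw_float(img_hwc):
    """uint8 H x W x C image -> float32 C x H x W array scaled to [0, 1]."""
    chw = np.transpose(img_hwc, (2, 0, 1))
    return chw.astype(np.float32) / 255.0

def to_hwc_uint8(img_chw):
    """Inverse of to_chw_float: float32 C x H x W -> uint8 H x W x C."""
    hwc = np.transpose(img_chw, (1, 2, 0))
    return np.clip(np.rint(hwc * 255.0), 0, 255).astype(np.uint8)

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
chw = to_chw_float(img)
restored = to_hwc_uint8(chw)
print(chw.shape, restored.shape)
```

The channel-first layout is what the model consumes, while the restored channel-last uint8 array is what `plt.imshow` expects.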
img = Image.open('work/real_or_drawing/train_data/3/1518.bmp')  # read the image
# ts_img2 = cv2.cvtColor(ts_img, cv2.COLOR_BGR2GRAY)
ts_img2 = cv2.cvtColor(ts_img, cv2.COLOR_BGR2RGB)
# Canny edge detection (thresholds 100 and 150)
ts_img3 = cv2.Canny(ts_img2, 100, 150)
print(ts_img3.shape)
plt.imshow(ts_img3)
(32, 32)
<matplotlib.image.AxesImage at 0x7face0049050>
[Figure: Canny edge map of the sample image (output_17_3.png)]
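# Convert back to three channels, shape (32, 32, 3)
# Canny returns a single-channel image, so expand it with COLOR_GRAY2RGB
ts_img4 = cv2.cvtColor(ts_img3, cv2.COLOR_GRAY2RGB)
print(ts_img4.shape)
print(ts_img4.dtype)
plt.imshow(ts_img4)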
(32, 32, 3)
uint8
<matplotlib.image.AxesImage at 0x7fad8afcac90>
[Figure: edge map converted back to three channels (output_18_2.png)]
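# Define a function that preprocesses an image
def MyImgPro(img):
    img = cv2.Canny(np.asarray(img), 100, 150)   # edge extraction, single-channel output
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # expand back to three channels
    img = cv2.resize(img, (28, 28))              # match the 28 x 28 test images
    return img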
# Test
img4 = cv2.imread('work/real_or_drawing/train_data/4/2012.bmp')
ff = MyImgPro(img4)
plt.imshow(ff)
# gg = np.float32(ff)
# plt.imshow(gg)
<matplotlib.image.AxesImage at 0x7face0041590>
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zJQgyc