[Baidu PaddlePaddle] Transfer Learning Project Share [acc: 0.79]

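Project Description

The task in this assignment is Domain Adversarial Training, a technique within transfer learning.

That is the block in the lower-left corner.

Domain adaptation lets a model improve its accuracy on dataset B while training with labels from dataset A only, with no labels from dataset B (dataset A and its task are close to dataset B and its task). Concretely: given real photos with labels plus a large number of hand-drawn images, design a method that lets the model predict the labels of the hand-drawn images.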
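
For reference, domain adversarial training is usually written as a saddle-point problem over a feature extractor G_f, a label predictor G_y, and a domain classifier G_d. The block below is the standard textbook formulation, not something taken from this notebook; the symbols (theta_f, theta_y, theta_d, lambda, and the domain label d_i) are assumptions introduced here for illustration.

```latex
\begin{aligned}
E(\theta_f,\theta_y,\theta_d)
  &= \sum_{i \in \text{source}} L_y\big(G_y(G_f(x_i;\theta_f);\theta_y),\, y_i\big)
   \;-\; \lambda \sum_{i \in \text{source} \cup \text{target}} L_d\big(G_d(G_f(x_i;\theta_f);\theta_d),\, d_i\big) \\
(\hat\theta_f,\hat\theta_y)
  &= \arg\min_{\theta_f,\theta_y} E(\theta_f,\theta_y,\hat\theta_d),
  \qquad
  \hat\theta_d = \arg\max_{\theta_d} E(\hat\theta_f,\hat\theta_y,\theta_d)
\end{aligned}
```

The label loss L_y only needs source labels, while the domain loss L_d only needs to know whether an image is a photo or a drawing, which is why no target labels are required.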

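Dataset Introduction

In this task, the source data is real photos and the target data is hand-drawn doodles.

The model sees the real photos and their labels during training and must then try to predict the labels of the hand-drawn doodles.

The data is located at 'data/data58171/real_or_drawing.zip'

  • Training: 5000 real photos + labels, 32 x 32 RGB
  • Testing: 100000 hand-drawn images, 28 x 28 grayscale
  • Label: 10 classes to predict in total.
  • The downloaded data uses 0 to 9 as labels.

Dataset download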

Project Requirements

  • Labeling by hand or looking up labels online is forbidden
  • Using a pre-trained model is forbidden

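1. Workflow

1.1 First, read the data and generate txt files (validation:training = 2:8) for the following steps.
1.2 Next, run EDA to inspect the data distribution and decide whether a shuffle is needed (it is not).
1.3 Then apply data augmentation to address the feature gap between test-set and training-set samples.
1.4 The main tool is OpenCV's Canny method for edge extraction, together with normalization, dimension conversion, and restoration.
1.5 Define some simple random augmentations (rotation, flipping, and so on) to improve robustness during model training.
1.6 Build the datasets and train the model; save the trained weights for use in transfer learning on the test set.
1.7 Compare four models during training (100 epochs on the training set, 20 epochs on the test set after transfer).
1.8 Because early stopping had some problems in the notebook, this experiment uses a fixed number of epochs.
1.9 Load the trained model files for training, testing, and evaluation on the test set; evaluation uses the validation split of the test set (1000 images).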
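
Step 1.5 mentions simple random augmentations such as rotation and flipping. The author's actual transforms are not shown in this section, so the following is only a minimal NumPy sketch. It assumes H x W x C arrays and restricts rotation to multiples of 90 degrees so that no interpolation is needed. The function name random_flip_rotate is hypothetical.

```python
import numpy as np

def random_flip_rotate(img, rng):
    """Randomly flip an H x W x C image left-right, then rotate it by k * 90 degrees."""
    if rng.random() < 0.5:
        img = img[:, ::-1, :]              # horizontal flip
    k = int(rng.integers(0, 4))            # number of 90-degree turns: 0, 1, 2 or 3
    return np.rot90(img, k=k, axes=(0, 1)).copy()

rng = np.random.default_rng(0)             # seeded so the example is repeatable
sample = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
augmented = random_flip_rotate(sample, rng)
print(augmented.shape, augmented.dtype)
```

Flips and 90-degree turns only rearrange pixels, so the set of pixel values is unchanged; small-angle rotations would need an interpolating transform instead.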

import os
import random

import cv2
import numpy as np
import paddle
import paddle.nn.functional as F
import paddle.vision.transforms as T
import pandas as pd
from matplotlib import pyplot as plt
from PIL import Image
!unzip -d work data/data75815/real_or_drawing.zip # unzip the real_or_drawing dataset

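Reading the Data

1. Extract the filename and label of each image in the training and test sets and write them to txt files.
2. Hold out 20% of the training set as a validation set for use during model training.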

# Set the data path
data_path = './work/real_or_drawing/train_data'
character_folders = os.listdir(data_path)
# Extract the training set
with open('./train_train.txt', 'w') as f_train:
    with open('./train_val.txt', 'w') as f_test:
        img_list = []
        for character_folder in character_folders:  # loop over the class folders
            character_imgs = os.listdir(os.path.join(data_path, character_folder))  # list the folder contents
            # reset the counter
            CNT = 0
            for img in character_imgs:  # loop over the images
                img_list.append(os.path.join(data_path, character_folder, img) + '\t' + character_folder + '\n')  # path and label
                CNT += 1
            print(character_folder, CNT)  # number of images in each folder
        random.shuffle(img_list)  # shuffle the list

        # Hold out 20% of the training set as the validation set
        # reset the counter
        CNT = 0
        for img in img_list:  # loop over the list
            if CNT < int(len(img_list) * 0.8):  # the first 80% go to the training file
                f_train.write(img)
                CNT += 1
            else:  # the rest go to the validation file
                f_test.write(img)
                CNT += 1
        print(len(img_list), int(len(img_list) * 0.8))  # total count and training count

1 500
9 500
7 500
2 500
6 500
5 500
4 500
3 500
0 500
8 500
5000 4000
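# Extract the training set again, this time into test_train.txt / test_val.txt
# (data_path still points at train_data, so this is another random 80/20 split of the same images)
with open('./test_train.txt', 'w') as f_train:
    with open('./test_val.txt', 'w') as f_test:
        img_list = []
        for character_folder in character_folders:  # loop over the class folders
            character_imgs = os.listdir(os.path.join(data_path, character_folder))  # list the folder contents
            # reset the counter
            CNT = 0
            for img in character_imgs:  # loop over the images
                img_list.append(os.path.join(data_path, character_folder, img) + '\t' + character_folder + '\n')  # path and label
                CNT += 1
            print(character_folder, CNT)  # number of images in each folder
        random.shuffle(img_list)  # shuffle the list

        # Hold out 20% as the validation set
        # reset the counter
        CNT = 0
        for img in img_list:  # loop over the list
            if CNT < int(len(img_list) * 0.8):  # the first 80% go to the training file
                f_train.write(img)
                CNT += 1
            else:  # the rest go to the validation file
                f_test.write(img)
                CNT += 1
        print(len(img_list), int(len(img_list) * 0.8))  # total count and training count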

1 500
9 500
7 500
2 500
6 500
5 500
4 500
3 500
0 500
8 500
5000 4000
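
The two cells above repeat the same shuffle-and-split logic. As a sketch, that logic could be factored into one small helper. The name split_train_val and the seed parameter are assumptions added here for reproducibility; they are not part of the original notebook.

```python
import random

def split_train_val(items, train_ratio=0.8, seed=None):
    """Shuffle a copy of `items` and split it into (train, val) lists."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)      # private RNG, leaves the global one untouched
    cut = int(len(shuffled) * train_ratio)     # same cut-off rule as the cells above
    return shuffled[:cut], shuffled[cut:]

# Placeholder lines in the same 'path<TAB>label' format used by the txt files
lines = [f"img_{i}.bmp\t{i % 10}\n" for i in range(5000)]
train_lines, val_lines = split_train_val(lines, train_ratio=0.8, seed=42)
print(len(train_lines), len(val_lines))
```

Writing each returned list with f.writelines(...) would then produce the train and validation txt files.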
# # Extract the test set
# data_path1 = './work/real_or_drawing/test_data'
# character_folders = os.listdir(data_path1)
# with open('./test.txt', 'w') as f_train:
#     img_list = []
#     for character_folder in character_folders:  # loop over the folders
#         character_imgs = os.listdir(os.path.join(data_path1, character_folder))  # list the folder contents
#         # reset the counter
#         CNT = 0
#         for img in character_imgs:  # loop over the images
#             img_list.append(os.path.join(data_path1, character_folder, img) + '\t' + character_folder + '\n')  # path and label
#             CNT += 1
#         print(character_folder, CNT)  # number of images in each folder
#     random.shuffle(img_list)  # shuffle the list
#     for img in img_list:
#         f_train.write(img)
#         CNT += 1
#     print(len(img_list))

Data EDA (Exploratory Data Analysis)

(The data was already shuffled in the previous step.)

tf = "train_train.txt"
# Read the txt file
df = pd.read_table(tf, sep='\t', names=['name', 'label'])
print(df)
# Plot a histogram of the label column
d = df['label'].hist().get_figure()
d.savefig("EDA.png")

train_image_path_list = df['name'].values
label_list = df['label'].values
label_list = paddle.to_tensor(label_list, dtype='int64')
train_label_list = paddle.nn.functional.one_hot(label_list, num_classes=10)

                                              name  label
0     ./work/real_or_drawing/train_data/3/1994.bmp      3
1     ./work/real_or_drawing/train_data/3/1799.bmp      3
2     ./work/real_or_drawing/train_data/8/4406.bmp      8
3     ./work/real_or_drawing/train_data/9/4970.bmp      9
4     ./work/real_or_drawing/train_data/4/2090.bmp      4
...                                            ...    ...
3995  ./work/real_or_drawing/train_data/5/2677.bmp      5
3996  ./work/real_or_drawing/train_data/8/4048.bmp      8
3997  ./work/real_or_drawing/train_data/8/4484.bmp      8
3998  ./work/real_or_drawing/train_data/2/1201.bmp      2
3999  ./work/real_or_drawing/train_data/5/2912.bmp      5

[4000 rows x 2 columns]


[Figure: histogram of the training labels (output_11_2.png)]

# Process the validation set
# Read the txt file
vsf = pd.read_table('./train_val.txt', sep='\t', names=['name', 'label'])
print(vsf)
val_image_list = vsf

val_image_path_list = val_image_list['name'].values
val_label_list = val_image_list['label'].values
val_label_list = paddle.to_tensor(val_label_list, dtype='int64')
val_label_list = paddle.nn.functional.one_hot(val_label_list, num_classes=10)  # 10 classes, matching train_label_list
                                             name  label
0    ./work/real_or_drawing/train_data/2/1351.bmp      2
1     ./work/real_or_drawing/train_data/1/714.bmp      1
2    ./work/real_or_drawing/train_data/9/4549.bmp      9
3    ./work/real_or_drawing/train_data/9/4588.bmp      9
4    ./work/real_or_drawing/train_data/4/2105.bmp      4
..                                            ...    ...
995  ./work/real_or_drawing/train_data/4/2016.bmp      4
996  ./work/real_or_drawing/train_data/8/4231.bmp      8
997  ./work/real_or_drawing/train_data/2/1486.bmp      2
998  ./work/real_or_drawing/train_data/3/1685.bmp      3
999  ./work/real_or_drawing/train_data/2/1247.bmp      2

[1000 rows x 2 columns]
print(train_label_list[1])
print(train_label_list[2])
print(train_label_list[3])
Tensor(shape=[10], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
Tensor(shape=[10], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.])
Tensor(shape=[10], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])
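
The printed tensors are one-hot vectors: the position of the single 1 is the class index. As an illustration only, the same encoding can be sketched in NumPy. The helper one_hot below is a hypothetical stand-in, not Paddle's implementation.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Encode integer labels as rows of a (len(labels), num_classes) float32 matrix."""
    labels = np.asarray(labels, dtype=np.int64)
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("label outside [0, num_classes)")
    return np.eye(num_classes, dtype=np.float32)[labels]   # pick one identity row per label

encoded = one_hot([3, 8, 9], num_classes=10)
print(encoded.shape)
print(encoded.argmax(axis=1))   # recovers the original labels
```

Because num_classes sets the width of every vector, the training and validation labels should be encoded with the same value so that they line up with the model's output layer.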

Data Shuffle

The EDA figure shows that this step can be skipped.

EDA Results

EDA
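
The conclusion above is read off the histogram. A numeric check of the class balance could be sketched as follows, assuming lines in the same 'path<TAB>label' format as train_train.txt. The helper names and the 10% tolerance are assumptions made for illustration.

```python
from collections import Counter

def label_counts(lines):
    """Count labels in 'path<TAB>label' lines, skipping blank lines."""
    return Counter(line.rstrip("\n").split("\t")[1] for line in lines if line.strip())

def is_balanced(counts, tolerance=0.1):
    """True if every class count is within `tolerance` (relative) of the mean count."""
    mean = sum(counts.values()) / len(counts)
    return all(abs(c - mean) <= tolerance * mean for c in counts.values())

demo = ["a/3/1.bmp\t3\n", "a/8/2.bmp\t8\n", "a/3/3.bmp\t3\n", "a/8/4.bmp\t8\n"]
counts = label_counts(demo)
print(dict(counts), is_balanced(counts))
```

With the real file, passing open('train_train.txt') as lines would give the per-class counts behind the histogram.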

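Data Augmentation

As the figure below shows, the test set consists of line drawings and differs considerably from the training set, so the training set needs some processing before the model can learn features that carry over.
The plan is to use threshold segmentation plus an edge-extraction algorithm as the basis for the data processing.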
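
To make the idea of "threshold plus edge extraction" concrete, here is a deliberately simplified NumPy sketch. It is not Canny: it only thresholds the gradient magnitude, whereas Canny additionally applies smoothing, non-maximum suppression, and hysteresis thresholding. The function name simple_edge_map and the threshold value are assumptions.

```python
import numpy as np

def simple_edge_map(gray, threshold=50.0):
    """Return a uint8 edge map (0 or 255) from a 2-D grayscale array.

    Simplified stand-in for an edge detector: central-difference gradients
    followed by a single threshold on the gradient magnitude.
    """
    gray = np.asarray(gray, dtype=np.float32)
    gy, gx = np.gradient(gray)                 # derivatives along rows, then columns
    magnitude = np.hypot(gx, gy)
    return np.where(magnitude > threshold, 255, 0).astype(np.uint8)

# Toy image: dark left half, bright right half, so the only edge is vertical
toy = np.zeros((8, 8), dtype=np.uint8)
toy[:, 4:] = 200
edges = simple_edge_map(toy)
print(edges[0])
```

The result is a line-like image on a black background, which is closer in appearance to the hand-drawn test images than the original photo.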


# Sanity check
ts_img = cv2.imread('work/real_or_drawing/train_data/3/1518.bmp')
ts_img1 = cv2.cvtColor(ts_img, cv2.COLOR_BGR2RGB)
plt.imshow(ts_img1)
print(ts_img1.shape)
ts_img3 = ts_img1.transpose((2, 0, 1))
print(ts_img3.shape)

(32, 32, 3)
(3, 32, 32)

[Figure: the sample training image shown by plt.imshow (output_16_1.png)]
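
The cell above converts the image from height x width x channel (HWC) to channel x height x width (CHW) layout. Combined with the normalization and restoration mentioned in step 1.4, the round trip can be sketched as below. The helper names and the choice to scale to [0, 1] are assumptions; the notebook's own normalization is not shown in this section.

```python
import numpy as np

def to_chw_float(img_hwc):
    """uint8 H x W x C in [0, 255] -> float32 C x H x W in [0, 1]."""
    return (np.asarray(img_hwc, dtype=np.float32) / 255.0).transpose(2, 0, 1)

def to_hwc_uint8(img_chw):
    """Inverse of to_chw_float: float C x H x W in [0, 1] -> uint8 H x W x C."""
    restored = np.clip(np.asarray(img_chw) * 255.0, 0, 255)
    return np.rint(restored).astype(np.uint8).transpose(1, 2, 0)

rng = np.random.default_rng(0)
photo = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)   # placeholder image
chw = to_chw_float(photo)
print(chw.shape, to_hwc_uint8(chw).shape)
```

CHW is the layout a convolutional model typically expects as input, while plt.imshow expects HWC, which is why the restore step matters for visual checks.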

img = Image.open('work/real_or_drawing/train_data/3/1518.bmp')  # read the image
# ts_img2=cv2.cvtColor(ts_img, cv2.COLOR_BGR2GRAY)
ts_img2 = cv2.cvtColor(ts_img, cv2.COLOR_BGR2RGB)
# Canny edge detection
ts_img3 = cv2.Canny(ts_img2, 100, 150)

print(ts_img3.shape)
plt.imshow(ts_img3)
(32, 32)

[Figure: Canny edge map of the sample image (output_17_3.png)]

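# Convert back to three channels: (32, 32) -> (32, 32, 3)
ts_img4 = cv2.cvtColor(ts_img3, cv2.COLOR_GRAY2RGB)  # ts_img3 is single-channel, so GRAY2RGB replicates it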
print(ts_img4.shape)
print(ts_img4.dtype)
plt.imshow(ts_img4)
(32, 32, 3)
uint8

[Figure: the edge map after conversion back to three channels (output_18_2.png)]

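# Define a function that preprocesses an image
def MyImgPro(img):
    img = cv2.Canny(np.asarray(img), 100, 150)  # edge map, single channel
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # replicate to three channels
    img = cv2.resize(img, (28, 28))  # match the 28 x 28 test images
    return img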

# Test
img4 = cv2.imread('work/real_or_drawing/train_data/4/2012.bmp')
ff = MyImgPro(img4)
plt.imshow(ff)
# gg= np.float32(ff)
# plt.imshow(gg)

[Figure: MyImgPro output for the test image]
