Paddle实现迁移学习

项目说明,本项目是李宏毅老师在飞桨授权课程的作业解析
课程 传送门
该项目AiStudio项目 传送门
数据集 传送门

本项目仅用于参考,提供思路和想法并非标准答案!请谨慎抄袭!

迁移学习

三岁出品,必是精品!

迁移学习:先训练一个模型然后把该模型的参数给类似的项目直接进行训练,效果好极了!
基于这个原理我们开始思考流程


1、数据处理
2、网络定义
3、训练模型
4、固定参数
5、加载模型参数
6、进行迁移学习

本项目的最后结果差的离谱,原因原始模型就不好。
提高分数建议:
1、修改原始网络
2、提高原始模型质量
3、修改训练轮数
4、数据处理更加精准

作业5-迁移学习

项目描述

本作业的任务是迁移学习中的领域对抗性训练(Domain Adversarial Training)。

也就是左下角的那一块。

Domain Adaptation是让模型可以在训练时只需要 A dataset label,不需要 B dataset label 的情况下提高 B dataset 的准确率。 (A dataset & task 接近 B dataset & task)也就是给定真实图片 & 标签以及大量的手绘图片,请设计一种方法使得模型可以预测出手绘图片的标签是什么。

数据集介绍

这次的任务是源数据: 真实照片,目标数据: 手画涂鸦。

我们必须让model看过真实照片以及标签,尝试去预测手画涂鸦的标签为何。

资料位于’data/data58171/real_or_drawing.zip’

  • Training : 5000 张真实图片 + label, 32 x 32 RGB
  • Testing : 100000 张手绘图片,28 x 28 Gray Scale
  • Label: 总共需要预测 10 个 class。
  • 资料下载下来是以 0 ~ 9 作为label

特别注意一点: **这次的源数据和目标数据的图片都是平衡的,你们可以使用这个资料做其他事情。 **

项目要求

  • 禁止手动标记label或在网上寻找label
  • 禁止使用pre-trained model

数据准备

!unzip -oq /home/aistudio/data/data75815/real_or_drawing.zip
import os
import paddle
import paddle.vision.transforms as T
import numpy as np
from PIL import Image
import paddle.nn.functional as F
import random

数据处理

data_path = '/home/aistudio/real_or_drawing/train_data'  # 设置初始文件地址
character_folders = os.listdir(data_path)  # 查看地址下文件夹
character_folders
['2', '1', '5', '3', '8', '7', '6', '9', '4', '0']
# 新建标签列表
def img_list_text(train='train'):
    data_path = f'/home/aistudio/real_or_drawing/{train}_data'  # 设置初始文件地址
    character_folders = os.listdir(data_path)  # 查看地址下文件夹

    if(os.path.exists(f'./{train}_train_imglist.txt')):  # 判断文件是否存在
        os.remove(f'./{train}_train_imglist.txt')  # 删除文件
    if(os.path.exists(f'./{train}_test_imglist.txt')):  # 判断文件是否存在
        os.remove(f'./{train}_test_imglist.txt')  # 删除文件

    with open(f'./{train}_train_imglist.txt', 'w')as f_train:
        with open(f'./{train}_test_imglist.txt', 'w')as f_test:
            img_list = []
            for character_folder in character_folders:  #  循环文件夹列表  
                character_imgs = os.listdir(os.path.join(data_path,character_folder))  # 读取文件夹下面的内容
                count = 0
                for img in character_imgs:  # 循环图片列表
                    img_list.append(os.path.join(data_path,character_folder,img) + '\
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

三岁学编程

感谢支持,更好的作品会继续努力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值