李宏毅机器学习作业11——Transfer Learning,Domain Adversarial Training

本文介绍了李宏毅机器学习课程中关于Transfer Learning的Domain Adversarial Training(DaNN)方法。任务是使用source data训练模型,并应用于target data。通过Canny Edge Detection处理source data,采用VGG-like的Feature Extractor和Domain Classifier,利用Gradient Reversal Layer进行训练。实验中,lambda值的调整影响模型性能,过高可能影响Label Predictor。文章还探讨了动态调整lambda的策略以及利用DANN模型生成伪标签进行进一步学习。
摘要由CSDN通过智能技术生成

Domain Adversarial Training见:

​李宏毅机器学习——领域适应Domain Adaptation_iwill323的博客-CSDN博客_领域适应

迁移学习参见2022CS231n PPT笔记 - 迁移学习_iwill323的博客-CSDN博客_cs231n ppt

目录

任务和数据集

任务

数据集

方法论:DaNN

导包

数据处理

显示图片

Canny Edge Detection

transforms

dataset和数据加载

模型

训练

训练函数

Gradient Reversal Layer

一个问题

lambda

推断

可视化

解答

lambda = 0.1

lambda = 0.7

adaptive lambda

利用DANN模型生成伪标签

任务和数据集

任务

這份作業的任務是Transfer Learning中的Domain Adversarial Training。也就是左下角的那一塊。

现在拥有标注的source data和未标注的target data,其中source data可能和target data有一定的联系。我们想仅使用 source data 训练一个模型,然后用在target data的推断。

数据集

数据来源 here

使用photos和the labels进行训练,预测hand-drawn graffiti的类别

● Label: 10 classes (numbered from 0 to 9).
● Training (source data): 5000 (32, 32) RGB 真实照片(with label).
● Testing (target data): 100000 (28, 28) gray scale 手写图像

 文件结构如下图:

方法论:DaNN

DaNN的核心:讓Soucre Data和Target Data經過Feature Extractor映射在同個Distribution上,这样在source domain上训练的classifier可以用于target domain

如何讓前半段的模型輸入兩種不同分布的資料,輸出卻是同個分布呢?

最簡單的方法就是像 GAN 一樣導入一個discriminator,这里叫Domain Classifier,讓它判斷經過Feature Extractor後的Feature是源自於哪個domain,讓Feature Extractor學習如何產生Feature以騙過Domain Classifier。 持久下來,通常Feature Extractor都會打贏Domain Classifier,因為Domain Classifier的Input來自於Feature Extractor,而且對Feature Extractor來說domain classification和label classification的任務並沒有衝突。

如此一來,我們就可以確信不管是哪一個Domain,Feature Extractor都會把它產生在同一個Feature Distribution上。

导包

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
 
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import os
from d2l import torch as d2l

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

数据处理

显示图片

def no_axis_show(img, title='', cmap=None):
  # imshow, and set the interpolation mode to be "nearest"。
  fig = plt.imshow(img, interpolation='nearest', cmap=cmap)
  # do not show the axes in the images.
  fig.axes.get_xaxis().set_visible(False)
  fig.axes.get_yaxis().set_visible(False)
  plt.title(title)

titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))
for i in range(10):
  plt.subplot(1, 10, i+1)
  fig = no_axis_show(plt.imread(f'/kaggle/input/ml2022-spring-hw11/real_or_drawing/train_data/{i}/{500*i}.bmp'), title=titles[i])

plt.figure(figsize=(18, 18))
for i in range(10):
  plt.subplot(1, 10, i+1)
  fig = no_axis_show(plt.imread(f'/kaggle/input/ml2022-spring-hw11/real_or_drawing/test_data/0/' + str(i).rjust(5, '0') + '.bmp'))

 

Canny Edge Detection

对于这个任务,我们有一个special domain knowledge:塗鴉的時候通常只會畫輪廓,我們可以根據這點將source data做點邊緣偵測處理,讓source data更像target data一點。

用cv2.Canny做Canny Edge Detection非常方便,只需要兩個參數: low_threshold, high_threshold。

cv2.Canny(image, low_threshold, high_threshold)

簡單來說就是當邊緣值超過high_threshold,就確定它是edge。如果只有超過low_threshold,那就先判斷一下再決定是不是edge。下面对source data做Canny Edge Detection

titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(12, 12))

original_img = plt.imread(f'/kaggle/input/ml2022-spring-hw11/real_or_drawing/train_data/0/0.bmp')
plt.subplot(1, 5, 1)
no_axis_show(original_img, title='original')

gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')

canny_50100 = cv2.Canny(gray_img, 50, 100)
plt.subplot(1, 5, 3)
no_axis_show(canny_50100, title='Canny(50, 100)', cmap='gray')

canny_150200 = cv2.Canny(gray_img, 150, 200)
plt.subplot(1, 5, 4)
no_axis_show(canny_150200, title='Canny(150, 200)', cmap='gray')

canny_250300 = cv2.Canny(gray_img, 250, 300)
plt.subplot(1, 5, 5)
no_axis_show(canny_250300, title='Canny(250, 300)', cmap='gray')
  

 low_threshold, high_threshold大小可调,他们越大,轮廓曲线就越少

transforms

 这里使用了lambda表达式

source_transform = transforms.Compose([
    # Turn RGB to grayscale. (Bacause Canny do not support RGB images.)
    transforms.Grayscale(),
    # cv2 do not support skimage.Image, so we transform it to np.array, 
    # and then adopt cv2.Canny algorithm.
    transforms.Lambda(lambda x: cv2.Canny(np.array(x), 170, 300)),
    # Transform np.array back to the skimage.Image.
    transforms.ToPILImage(),
    # 50% Horizontal Flip. (For Augmentation)
    transforms.RandomHorizontalFlip(),
    # Rotate +- 15 degrees. (For Augmentation), and filled with zero 
    # if there's empty pixel after rotation.
    transforms.RandomRotation(15, fill=(0,)),
    # Transform to tensor for model inputs.
    transforms.ToTensor(),
])

target_transform = transforms.Compose([
    # Turn RGB to grayscale.
    transforms.Grayscale(),
    # Resize: size of source data is 32x32, thus we need to enlarge the size of target data from 28x28 to 32x32。
    transforms.Resize((32, 32)),
    # 50% Horizontal Flip. (For Augmentation)
    transforms.RandomHorizontalFlip(),
    # Rotate +- 15 degrees. (For Augmentation), and filled with zero 
    # if there's empty pixe
  • 3
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值