【深度域适配】二、利用DANN实现MNIST和MNIST-M数据集迁移训练

知乎专栏链接:https://zhuanlan.zhihu.com/p/109057360

CSDN链接:https://daipuweiai.blog.csdn.net/article/details/104495520

前言

在前一篇文章【深度域适配】一、DANN与梯度反转层(GRL)详解中,我们主要讲解了DANN的网络架构与梯度反转层(GRL)的基本原理,接下来这篇文章中我们将主要复现DANN论文:Unsupervised Domain Adaptation by Backpropagation(文章链接:https://arxiv.org/abs/1409.7495)中MNIST和MNIST-M数据集的迁移训练实验。

该项目的github地址为:https://github.com/Daipuwei/DANN-MNIST

一、MNIST和MNIST-M介绍

为了利用DANN实现MNIST和MNIST-M数据集的迁移训练,我们首先需要获取到MNIST和MNIST-M数据集。其中MNIST数据集很容易获取,官网下载链接为:MNSIT。需要下载的文件如下图所示蓝色的4个文件。


同时MNSIT数据集的加载,tensorflow框架已经给出相关的读取接口,因此我们不需要自行编写,读取MNIST数据集的代码如下:
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets(os.path.abspath('./dataset/mnist'), one_hot=True)
# Process MNIST
mnist_train = (mnist.train.images > 0).reshape(55000, 28, 28, 1).astype(np.uint8) * 255
mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)
mnist_test = (mnist.test.images > 0).reshape(10000, 28, 28, 1).astype(np.uint8) * 255
mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3)

MNIST-M数据集由MNIST数字与BSDS500数据集中的随机色块混合而成。那么要生成MNIST-M数据集,请首先下载BSDS500数据集。BSDS500数据集的官方下载地址为:BSDS500。以下是BSDS500数据集官方网址相关截图,点击下图中蓝框的连接即可下载数据。


下载好BSDS500数据集后,我们必须根据MNIST和BSDS500数据集来生成MNIST-M数据集,生成数据集的脚本create_mnistm.py如下:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tarfile
import os
import pickle as pkl
import numpy as np
import skimage
import skimage.io
import skimage.transform
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./dataset/mnist')

BST_PATH = os.path.abspath('./dataset/BSR_bsds500.tgz')

rand = np.random.RandomState(42)

f = tarfile.open(BST_PATH)
train_files = []
for name in f.getnames():
    if name.startswith('BSR/BSDS500/data/images/train/'):
        train_files.append(name)

print('Loading BSR training images')
background_data = []
for name in train_files:
    try:
        fp = f.extractfile(name)
        bg_img = skimage.io.imread(fp)
        background_data.append(bg_img)
    except:
        continue


def compose_image(digit, background):
    """Difference-blend a digit and a random patch from a background image."""
    w, h, _ = background.shape
    dw, dh, _ = digit.shape
    x = np.random.randint(0, w - dw)
    y = np.random.randint(0, h - dh)

    bg = background[x:x+dw, y:y+dh]
    return np.abs(bg - digit).astype(np.uint8)


def mnist_to_img(x):
    """Binarize MNIST digit and convert to RGB."""
    x = (x > 0).astype(np.float32)
    d = x.reshape([28, 28, 1]) * 255
    return np.concatenate([d, d, d], 2)


def create_mnistm(X):
    """
    Give an array of MNIST digits, blend random background patches to
    build the MNIST-M dataset as described in
    http://jmlr.org/papers/volume17/15-239/15-239.pdf
    """
    X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8)
    for i in range(X.shape[0]):

        if i % 1000 == 0:
            print('Processing example', i)

        bg_img = rand.choice(background_data)

        d = mnist_to_img(X[i])
        d = compose_image(d, bg_img)
        X_[i] = d

    return X_


print('Building train set...')
train = create_mnistm(mnist.train.images)
print('Building test set...')
test = create_mnistm(mnist.test.images)
print('Building validation set...')
valid = create_mnistm(mnist.validation.images)

# Save dataset as pickle
mnistm_dir = os.path.abspath("./dataset/mnistm")
if not os.path.exists(mnistm_dir):
    os.mkdir(mnistm_dir)
with open(os.path.join(mnistm_dir,'mnistm_data.pkl'), 'wb') as f:
    pkl.dump({ 'train': train, 'test': test, 'valid': valid }, f, pkl.HIGHEST_PROTOCOL)

二、参数配置类config

由于整个DANN-MNIST网络的训练过程中涉及到很多超参数,因此为了整个项目的编程方便,我们利用面向对象的思想将所有的超参数放置到一个类中,即参数配置类config。这个参数配置类config的代码如下:

# -*- coding: utf-8 -*-
# @Time    : 2020/2/15 15:05
# @Author  : Dai PuWei
# @Email   : 771830171@qq.com
# @File    : config.py
# @Software: PyCharm

import os

class config(object):

    __defualt_dict__ = {
        "pre_model_path":None,
        "checkpoints_dir":os.path.abspath("./checkpoints"),
        "logs_dir":os.path.abspath("./logs"),
        "config_dir":os.path.abspath("./config"),
        "dataset_dir": os.path.abspath("./dataset"),
        #"dataset_dir": os.path.abspath("/input0"),
        "result_dir": os.path.abspath("./result"),
        "image_input_shape":(28,28,3),
        "image_size":28,
        "init_learning_rate": 1e-2,
        "momentum_rate": 0.9,
        "batch_size":64,
        "epoch":500,
    }

    def __init__(self,**kwargs):
        """
        这是参数配置类的初始化函数
        :param kwargs: 参数字典
        """
        # 初始化相关配置参数
        self.__dict__.update(self. __defualt_dict__)
        # 根据相关传入参数进行参数更新
        self.__dict__.update(kwargs)

        if not os.path.exists(self.checkpoints_dir):
            os.mkdir(self.checkpoints_dir)

        if not os.path.exists(self.logs_dir):
            os.mkdir(self.logs_dir)

        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)

    def set(self,**kwargs):
        """
        这是参数配置的设置函数
        :param kwargs: 参数字典
        :return:
        """
        # 根据相关传入参数进行参数更新
        self.__dict__.update(kwargs)

    def save_config(self,time):
        """
        这是保存参数配置类的函数
        :param time: 时间点字符串
        :return:
        """
        # 更新相关目录
        self.checkpoints_dir = os.path.join(self.checkpoints_dir,time)
        self.logs_dir = os.path.join(self.logs_dir,time)
        self.config_dir = os.path.join(self.config_dir,time)
        self.result_dir = os.path.join(self.result_dir,time)

        if not os.path.exists(self.config_dir):
            os.mkdir(self.config_dir)
        if not os.path.exists(self.checkpoints_dir):
            os.mkdir(self.checkpoints_dir)
        if not os.path.exists(self.logs_dir):
            os.mkdir(self.logs_dir)
        if not os.path.exists(self.result_dir):
            os.mkdir(self.result_dir)

        config_txt_path = os.path.join(self.config_dir,"config.txt")
        with open(config_txt_path,'a') as f:
            for key,value in self.__dict__.items():
                if key in ["checkpoints_dir","logs_dir","config_dir"]:
                    value = os.path.join(value,time)
                    s = key+": "+value+"\n"
                    f.write(s)


三、梯度反转层(GradientReversalLayer)

在DANN中比较重要的模块就是梯度反转层(Gradient Reversal Layer, GRL)的实现。GRL的tf1.0代码实现如下:

# -*- coding: utf-8 -*-
# @Time    : 2020/2/14 20:59
# @Author  : Dai PuWei
# @Email   : 771830171@qq.com
# @File    : GRL.py
# @Software: PyCharm

import tensorflow as tf
from tensorflow.python.framework import ops

class GradientReversalLayer(object):
    def __init__(self):
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        grad_name = "FlipGradient%d" % self.num_calls

        @ops.RegisterGradient(grad_name)
        def _flip_gradients(op, grad):
            return [tf.negative(grad) * l]

        g = tf.get_default_graph()
        with g.gradient_override_map({"Identity": grad_name}):
            y = tf.identity(x)

        self.num_calls += 1
        return y

在上述代码中@ops.RegisterGradient(grad_name)修饰 _flip_gradients(op, grad)函数,即自定义该层的梯度取反。同时gradient_override_map函数主要用于解决使用自己定义的函数方式来求梯度的问题,gradient_override_map函数的参数值为一个字典。即字典中value表示使用该值表示的函数代替key表示的函数进行梯度运算。

四、 DANN类代码

DANN论文Unsupervised Domain Adaptation by Backpropagation(文章链接为:https://arxiv.org/abs/1409.7495)中给出MNIST和MNIST-M数据集的迁移训练实验的网络,网络架构图如下图所示。

接下来,我们将利用tensorflow1.14.0来搭建整个DANN-MNIST网络,并在使用面向对象思想进行编程。DANN-MNIST类代码如下:
# -*- coding: utf-8 -*-
# @Time    : 2020/2/14 20:27
# @Author  : Dai PuWei
# @Email   : 771830171@qq.com
# @File    : MNIST2MNIST_M.py
# @Software: PyCharm

import os
import cv2
import datetime
import numpy as np
import tensorflow as tf

from tensorflow import keras as K
from tensorflow.train import MomentumOptimizer

from utils.utils import plot_loss
from utils.utils import plot_accuracy
from utils.utils import AverageMeter
from utils.utils import make_summary
from utils.utils import grl_lambda_schedule
from utils.utils import learning_rate_schedule

from model.GRL import GradientReversalLayer as GRL

class MNIST2MNIST_M_DANN(object):

    def __init__(self,config):
        """
        这是MNINST与MNIST_M域适配网络的初始化函数
        :param config: 参数配置类
        """
        # 初始化参数类
        self.cfg = config

        # 定义相关占位符
        self.grl_lambd = tf.placeholder(tf.float32, [])                         # GRL层参数
        self.learning_rate = tf.placeholder(tf.float32, [])                     # 学习率
        self.source_image_labels = tf.placeholder(tf.float32, shape=(None, 10))
        self.domain_labels = tf.placeholder(tf.float32, shape=(None, 2))

        # 搭建深度域适配网络
        self.build_DANN()

        # 定义损失
        self.image_cls_loss =  tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.source_image_labels,
                                                                          logits=self.image_cls))
        self.domain_cls_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=self.domain_labels,
                                                                        logits=self.domain_cls))
        self.loss = self.image_cls_loss+self.domain_cls_loss

        # 定义精度
        correct_label_pred = tf.equal(tf.argmax(self.source_image_labels, 1), tf.argmax(self.image_cls, 1))
        self.acc = tf.reduce_mean(tf.cast(correct_label_pred, tf.float32))

        # 定义模型保存类与加载类
        self.saver_save = tf.train.Saver(max_to_keep=100)  # 设置最大保存检测点个数为周期数

        # 初始化优化器
        self.global_step = tf.Variable(tf.constant(0), trainable=False)
        self.optimizer = MomentumOptimizer(self.learning_rate, momentum=self.cfg.momentum_rate)
        self.train_op = self.optimizer.minimize(self.loss,global_step=self.global_step)


    def featur_extractor(self,image_input,name):
        """
        这是特征提取子网络的构建函数
        :param image_input: 图像输入张量
        :param name: 输出特征名称
        :return:
        """
        x = K.layers.Conv2D(filters=32,kernel_size=5,kernel_initializer=K.initializers.TruncatedNormal(stddev=0.1),
                                bias_initializer = K.initializers.Constant(value=0.1), activation='relu')(image_input)
        x = K.layers.MaxPool2D(pool_size=(2,2),strides=2)(x)
        x = K.layers.Conv2D(filters=48, kernel_size=5, kernel_initializer=K.initializers.TruncatedNormal(stddev=0.1),
                                bias_initializer = K.initializers.Constant(value=0.1), activation='relu')(x)
        x = K.layers.MaxPool2D(pool_size=(2, 2),strides=2,name=name)(x)
        return x

    def build_image_classify_model(self,image_classify_feature):
        """
        这是搭建图像分类器模型的函数
        :param image_classify_feature: 图像分类特征张量
        :return:
        """
        # 搭建图像分类器
        x = K.layers.Lambda(lambda x:x,name="image_classify_feature")(image_classify_feature)
        x = K.layers.Flatten()(x)
        x = K.layers.Dense(100,kernel_initializer=K.initializers.TruncatedNormal(stddev=0.1),
                                bias_initializer = K.initializers.Constant(value=0.1), activation='relu')(x)
        #x = K.layers.Dropout(0.5)(x)
        x = K.layers.Dense(10,kernel_initializer=K.initializers.TruncatedNormal(stddev=0.1),
                                bias_initializer = K.initializers.Constant(value=0.1), activation='softmax',
                           name = "image_classify_pred")(x)
        return x

    def build_domain_classify_model(self,domain_classify_feature):
        """
        这是搭建域分类器的函数
        :param domain_classify_feature: 域分类特征张量
        :return:
        """
        # 搭建域分类器
        x = GRL(domain_classify_feature,self.grl_lambd)
        x = K.layers.Flatten()(x)
        x = K.layers.Dense(100,kernel_initializer=K.initializers.TruncatedNormal(stddev=0.01),
                                bias_initializer = K.initializers.Constant(value=0.1), activation='relu')(x)
        #x = K.layers.Dropout(0.5)(x)
        x = K.layers.Dense(2,kernel_initializer=K.initializers.TruncatedNormal(stddev=0.01),
                                bias_initializer = K.initializers.Constant(value=0.1), activation='softmax'
                           ,name="domain_classify_pred")(x)
        return x

    def build_DANN(self):
        """
        这是搭建域适配网络的函数
        :return:
        """
        # 定义源域、目标域的图像输入和DANN模型图像输入
        self.source_image_input = K.layers.Input(shape=self.cfg.image_input_shape,name="source_image_input")
        self.target_image_input = K.layers.Input(shape=self.cfg.image_input_shape,name="target_image_input")
        self.image_input = K.layers.Concatenate(axis=0,name="image_input")([self.source_image_input,self.target_image_input])
        self.image_input = (self.image_input - self.cfg.pixel_mean) / 255.0

        # 域分类器与图像分类器的共享特征
        share_feature = self.featur_extractor(self.image_input,"image_feature")

        # 均等划分共享特征为源域数据特征与目标域数据特征
        source_feature,target_feature = \
            K.layers.Lambda(tf.split, arguments={'axis': 0, 'num_or_size_splits': 2})(share_feature)
        source_feature = K.layers.Lambda(lambda x:x,name="source_feature")(source_feature)

        # 获取图像分类结果和域分类结果张量
        self.image_cls = self.build_image_classify_model(source_feature)
        self.domain_cls = self.build_domain_classify_model(share_feature)

    def eval_on_val_dataset(self,sess,val_datagen,val_batch_num,ep):
        """
        这是评估模型在验证集上的性能的函数
        :param val_datagen: 验证集数据集生成器
        :param val_batch_num: 验证集数据集批量个数
        """
        epoch_loss_avg = AverageMeter()
        epoch_image_cls_loss_avg = AverageMeter()
        epoch_domain_cls_loss_avg = AverageMeter()
        epoch_accuracy = AverageMeter()
        for i in np.arange(1, val_batch_num + 1):
            # 获取小批量数据集及其图像标签与域标签
            batch_mnist_m_image_data, batch_mnist_m_labels = val_datagen.__next__()#val_datagen.next_batch()
            batch_domain_labels = np.tile([0., 1.], [self.cfg.batch_size * 2, 1])

            # 在验证阶段只利用目标域数据及其标签进行测试,计算模型在验证集上相关指标的值
            val_loss, val_image_cls_loss, val_domain_cls_loss, val_acc = \
                sess.run([self.loss, self.image_cls_loss, self.domain_cls_loss, self.acc],
                        feed_dict={self.source_image_input: batch_mnist_m_image_data,
                                        self.target_image_input: batch_mnist_m_image_data,
                                        self.source_image_labels: batch_mnist_m_labels,
                                        self.domain_labels: batch_domain_labels})
            # 更新损失与精度的平均值
            epoch_loss_avg.update(val_loss, 1)
            epoch_image_cls_loss_avg.update(val_image_cls_loss, 1)
            epoch_domain_cls_loss_avg.update(val_domain_cls_loss, 1)
            epoch_accuracy.update(val_acc, 1)

        self.writer.add_summary(make_summary('val/val_loss', epoch_loss_avg.average),global_step=ep)
        self.writer.add_summary(make_summary('val/val_image_cls_loss', epoch_image_cls_loss_avg.average),global_step=ep)
        self.writer.add_summary(make_summary('val/val_domain_cls_loss', epoch_domain_cls_loss_avg.average),global_step=ep)
        self.writer.add_summary(make_summary('accuracy/val_accuracy', epoch_accuracy.average),global_step=ep)

        return epoch_loss_avg.average,epoch_image_cls_loss_avg.average,\
                   epoch_domain_cls_loss_avg.average,epoch_accuracy.average

    def train(self,train_source_datagen,train_target_datagen,val_datagen,pixel_mean,interval,
              train_iter_num,val_iter_num,pre_model_path=None):
        """
        这是DANN的训练函数
        :param train_source_datagen: 源域训练数据集生成器
        :param train_target_datagen: 目标域训练数据集生成器
        :param val_datagen: 验证数据集生成器
        :param interval: 验证间隔
        :param train_iter_num: 每个epoch的训练次数
        :param val_iter_num: 每次验证过程的验证次数
        :param pre_model_path: 预训练模型地址,与训练模型为ckpt文件,注意文件路径只需到.ckpt即可。
        """
        # 初始化相关文件目录路径
        time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        checkpoint_dir = os.path.join(self.cfg.checkpoints_dir,time)
        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)

        log_dir = os.path.join(self.cfg.logs_dir, time)
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)

        result_dir = os.path.join(self.cfg.result_dir, time)
        if not os.path.exists(result_dir):
            os.mkdir(result_dir)

        self.cfg.save_config(time)

        # 初始化训练损失和精度数组
        train_loss_results = []                     # 保存训练loss值
        train_image_cls_loss_results = []           # 保存训练图像分类loss值
        train_domain_cls_loss_results = []          # 保存训练域分类loss值
        train_accuracy_results = []                 # 保存训练accuracy值

        # 初始化验证损失和精度数组,验证最大精度
        val_ep = []
        val_loss_results = []                     # 保存验证loss值
        val_image_cls_loss_results = []           # 保存验证图像分类loss值
        val_domain_cls_loss_results = []          # 保存验证域分类loss值
        val_accuracy_results = []                 # 保存验证accuracy值
        val_acc_max = 0                           # 最大验证精度

        with tf.Session() as sess:
            # 初始化变量
            sess.run(tf.global_variables_initializer())

            # 加载预训练模型
            if pre_model_path is not None:              # pre_model_path的地址写到.ckpt
                saver_restore = tf.train.import_meta_graph(pre_model_path+".meta")
                saver_restore.restore(sess,pre_model_path)
                print("restore model from : %s" % (pre_model_path))

            self.merged = tf.summary.merge_all()
            self.writer = tf.summary.FileWriter(log_dir, sess.graph)

            print('\n----------- start to train -----------\n')

            total_global_step = self.cfg.epoch * train_iter_num
            for ep in np.arange(self.cfg.epoch):
                # 初始化每次迭代的训练损失与精度平均指标类
                epoch_loss_avg = AverageMeter()
                epoch_image_cls_loss_avg = AverageMeter()
                epoch_domain_cls_loss_avg = AverageMeter()
                epoch_accuracy = AverageMeter()

                # 初始化精度条
                progbar = K.utils.Progbar(train_iter_num)
                print('Epoch {}/{}'.format(ep+1, self.cfg.epoch))
                batch_domain_labels = np.vstack([np.tile([1., 0.], [self.cfg.batch_size // 2, 1]),
                                           np.tile([0., 1.], [self.cfg.batch_size // 2, 1])])
                for i in np.arange(1,train_iter_num+1):
                    # 获取小批量数据集及其图像标签与域标签
                    batch_mnist_image_data, batch_mnist_labels = train_source_datagen.__next__()#train_source_datagen.next_batch()
                    batch_mnist_m_image_data, batch_mnist_m_labels = train_target_datagen.__next__()#train_target_datagen.next_batch()

                    # 计算学习率和GRL层的参数lambda
                    global_step = (ep-1)*train_iter_num + i
                    process = global_step * 1.0 / total_global_step
                    leanring_rate = learning_rate_schedule(process,self.cfg.init_learning_rate)
                    grl_lambda = grl_lambda_schedule(process)

                    # 前向传播,计算损失及其梯度
                    op,train_loss,train_image_cls_loss,train_domain_cls_loss,train_acc = \
                        sess.run([self.train_op,self.loss,self.image_cls_loss,self.domain_cls_loss,self.acc],
                                  feed_dict={self.source_image_input:batch_mnist_image_data,
                                             self.target_image_input:batch_mnist_m_image_data,
                                             self.source_image_labels:batch_mnist_labels,
                                             self.domain_labels:batch_domain_labels,
                                             self.learning_rate:leanring_rate,
                                             self.grl_lambd:grl_lambda})
                    self.writer.add_summary(make_summary('learning_rate', leanring_rate),global_step=global_step)
                    self.writer1.add_summary(make_summary('learning_rate', leanring_rate), global_step=global_step)

                    # 更新训练损失与训练精度
                    epoch_loss_avg.update(train_loss,1)
                    epoch_image_cls_loss_avg.update(train_image_cls_loss,1)
                    epoch_domain_cls_loss_avg.update(train_domain_cls_loss,1)
                    epoch_accuracy.update(train_acc,1)

                    # 更新进度条
                    progbar.update(i, [('train_image_cls_loss', train_image_cls_loss),
                                       ('train_domain_cls_loss', train_domain_cls_loss),
                                       ('train_loss', train_loss),
                                       ("train_acc",train_acc)])

                # 保存相关损失与精度值,可用于可视化
                train_loss_results.append(epoch_loss_avg.average)
                train_image_cls_loss_results.append(epoch_image_cls_loss_avg.average)
                train_domain_cls_loss_results.append(epoch_domain_cls_loss_avg.average)
                train_accuracy_results.append(epoch_accuracy.average)

                self.writer.add_summary(make_summary('train/train_loss', epoch_loss_avg.average),global_step=ep+1)
                self.writer.add_summary(make_summary('train/train_image_cls_loss', epoch_image_cls_loss_avg.average),
                                   global_step=ep+1)
                self.writer.add_summary(make_summary('train/train_domain_cls_loss', epoch_domain_cls_loss_avg.average),
                                   global_step=ep+1)
                self.writer.add_summary(make_summary('accuracy/train_accuracy', epoch_accuracy.average),global_step=ep+1)

                if (ep+1) % interval == 0:
                    # 评估模型在验证集上的性能
                    val_ep.append(ep)
                    val_loss, val_image_cls_loss,val_domain_cls_loss, \
                        val_accuracy = self.eval_on_val_dataset(sess,val_datagen,val_iter_num,ep+1)
                    val_loss_results.append(val_loss)
                    val_image_cls_loss_results.append(val_image_cls_loss)
                    val_domain_cls_loss_results.append(val_domain_cls_loss)
                    val_accuracy_results.append(val_accuracy)
                    str =  "Epoch {:03d}: val_image_cls_loss: {:.3f}, val_domain_cls_loss: {:.3f}, val_loss: {:.3f}" \
                           ", val_accuracy: {:.3%}".format(ep+1,val_image_cls_loss,val_domain_cls_loss,val_loss,val_accuracy)
                    print(str)

                    if val_accuracy > val_acc_max:              # 验证精度达到当前最大,保存模型
                        val_acc_max = val_accuracy
                        self.saver_save.save(sess,os.path.join(checkpoint_dir,str+".ckpt"))

            # 保存训练与验证结果
            path = os.path.join(result_dir, "train_loss.jpg")
            plot_loss(np.arange(1,len(train_loss_results)+1), [np.array(train_loss_results),
                                np.array(train_image_cls_loss_results),np.array(train_domain_cls_loss_results)],
                                path, "train")
            path = os.path.join(result_dir, "val_loss.jpg")
            plot_loss(np.array(val_ep)+1, [np.array(val_loss_results),
                                np.array(val_image_cls_loss_results),np.array(val_domain_cls_loss_results)],
                               path, "val")
            train_acc = np.array(train_accuracy_results)[np.array(val_ep)]
            path = os.path.join(result_dir, "accuracy.jpg")
            plot_accuracy(np.array(val_ep)+1, [train_acc, val_accuracy_results], path)

            # 保存最终的模型
            model_path = os.path.join(checkpoint_dir,"trained_model.ckpt")
            self.saver_save.save(sess,model_path)
            print("Train model finshed. The model is saved in : ", model_path)
            print('\n----------- end to train -----------\n')

    def test_image(self,image_path,model_path):
        """
        这是测试一张图像的函数
        :param image_path: 图像路径
        :param model_path: 模型路径
        :return:
        """
        # 读取图像数据,并进行数组维度扩充
        image = cv2.imread(image_path)
        image = np.expand_dims(image,axis=0)
        image = (image - self.cfg.val_image_mean) / 255.0

        with tf.Session() as sess:
            # 初始化变量
            sess.run(tf.global_variables_initializer())

            # 加载预训练模型
            saver_restore = tf.train.import_meta_graph(model_path+".meta")
            saver_restore.restore(sess, model_path)

            # 进行测试
            img_cls_pred = sess.run([self.image_cls],feed_dict={self.source_image_input: image})
            pred_label = np.argmax(img_cls_pred[0])+1
            print("%s is %d" %(image_path,pred_label))

    def test_batch_images(self, image_paths, model_path):
        """
        这是测试一张图像的函数
        :param image_paths: 图像路径数组
        :param model_path: 模型路径
        :return:
        """
        # 批量读取图像数据
        images = np.array([cv2.imread(image_path) for image_path in image_paths])
        images = (images - self.cfg.val_image_mean) / 255.0

        with tf.Session() as sess:
            # 初始化变量
            sess.run(tf.global_variables_initializer())

            # 加载预训练模型
            saver_restore = tf.train.import_meta_graph(model_path+".meta")
            saver_restore.restore(sess, model_path)

            # 进行测试
            img_cls_pred = sess.run([self.image_cls], feed_dict={self.source_image_input: images})
            pred_label = np.argmax(img_cls_pred,axis=0) + 1
            for i,image_path in enumerate(image_paths):
                print("%s is %d" % (image_path, pred_label[i]))

五、工具脚本utilis

在训练过程中,需要各种小工具函数来辅助训练过程。例如学习率、GRL参数是根据迭代进程变化,数据集生成器的定义和各种结果绘制函数。工具脚本utilis.py如下:

# -*- coding: utf-8 -*-
# @Time    : 2020/2/15 16:10
# @Author  : Dai PuWei
# @Email   : 771830171@qq.com
# @File    : utils.py
# @Software: PyCharm

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.core.framework import summary_pb2

class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.average = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.average = self.sum / float(self.count)

def make_summary(name, val):
    return summary_pb2.Summary(value=[summary_pb2.Summary.Value(tag=name, simple_value=val)])

def plot_accuracy(x,y,path):
    """
    这是绘制精度的函数
    :param x: x坐标数组
    :param y: y坐标数组
    :param path: 结果保存地址
    :param mode: 模式,“train”代表训练损失,“val”为验证损失
    """
    lengend_array = ["train_acc", "val_acc"]
    train_accuracy,val_accuracy = y
    plt.plot(x, train_accuracy, 'r-')
    plt.plot(x, val_accuracy, 'b--')
    plt.grid(True)
    plt.xlim(0, x[-1]+2)
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.legend(lengend_array,loc="best")
    plt.savefig(path)
    plt.close()

def plot_loss(x,y,path,mode="train"):
    """
    这是绘制损失的函数
    :param x: x坐标数组
    :param y: y坐标数组
    :param path: 结果保存地址
    :param mode: 模式,“train”代表训练损失,“val”为验证损失
    """
    if mode == "train":
        lengend_array = ["train_loss","train_image_cls_loss","train_domain_cls_loss"]
    else:
        lengend_array = ["val_loss", "val_image_cls_loss", "val_domain_cls_loss"]
    loss_results,image_cls_loss_results,domain_cls_loss_results = y
    loss_results_min = np.max([np.min(loss_results) - 0.1,0])
    image_cls_loss_results_min = np.max([np.min(image_cls_loss_results) - 0.1,0])
    domain_cls_loss_results_min =np.max([np.min(domain_cls_loss_results) - 0.1,0])
    y_min = np.min([loss_results_min,image_cls_loss_results_min,domain_cls_loss_results_min])
    plt.plot(x, loss_results, 'r-')
    plt.plot(x, image_cls_loss_results, 'b--')
    plt.plot(x, domain_cls_loss_results, 'g-.')
    plt.grid(True)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.xlim(0,x[-1]+2)
    plt.ylim(ymin=y_min)
    plt.legend(lengend_array,loc="best")
    plt.savefig(path)
    plt.close()

def shuffle_aligned_list(data):
    """
    这是是随机打乱数据的函数
    :param data: 输入数据
    :return:
    """
    num = data[0].shape[0]
    p = np.random.permutation(num)
    return [d[p] for d in data]

def batch_generator(data, batch_size, shuffle=True):
    """
    这是构造数据生成器的函数
    :param data: 输入
    :param batch_size: 小批量大小
    :param shuffle: 是否打乱随机数据集的标志
    :return:
    """
    if shuffle:             # 随机打乱数据集标志为True,则随机打乱数据集
        data = shuffle_aligned_list(data)

    batch_count = 0         # 小批量数据集批次计数器
    while True:
        # 遍历完整个数据集,全部重置
        if batch_count * batch_size + batch_size >= len(data[0]):
            batch_count = 0

            if shuffle:          # 随机打乱数据集标志为True,则随机打乱数据集
                data = shuffle_aligned_list(data)

        # 构造小批量数据集
        start = batch_count * batch_size
        end = start + batch_size
        batch_count += 1
        yield [d[start:end] for d in data]          # 构造数据生成器

def learning_rate_schedule(process,init_learning_rate = 0.01,alpha = 10.0 , beta = 0.75):
    """
    这个学习率的变换函数
    :param process: 训练进程比率,值在0-1之间
    :param init_learning_rate: 初始学习率,默认为0.01
    :param alpha: 参数alpha,默认为10
    :param beta: 参数beta,默认为0.75
    """
    return init_learning_rate /(1.0 + alpha * process)**beta

def grl_lambda_schedule(process,gamma=10.0):
    """
    这是GRL的参数lambda的变换函数
    :param process: 训练进程比率,值在0-1之间
    :param gamma: 参数gamma,默认为10
    """
    return 2.0 / (1.0+np.exp(-gamma*process)) - 1.0

六、训练过程与实验结果

最后是训练DANN的脚本train.py,代码如下:

# -*- coding: utf-8 -*-
# @Time    : 2020/2/15 16:36
# @Author  : Dai PuWei
# @Email   : 771830171@qq.com
# @File    : train.py
# @Software: PyCharm

import os
import numpy as np
import pickle as pkl

from config.config import config
from model.MNIST2MNIST_M import MNIST2MNIST_M_DANN
from tensorflow.examples.tutorials.mnist import input_data
from utils.utils import batch_generator

def run_main():
    """
       这是主函数
    """
    # 初始化参数配置类
    cfg = config()

    mnist = input_data.read_data_sets(os.path.abspath('./dataset/mnist'), one_hot=True)
    # Process MNIST
    mnist_train = (mnist.train.images > 0).reshape(55000, 28, 28, 1).astype(np.uint8) * 255
    mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)
    mnist_test = (mnist.test.images > 0).reshape(10000, 28, 28, 1).astype(np.uint8) * 255
    mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3)

    # Load MNIST-M
    mnistm = pkl.load(open(os.path.abspath('./dataset/mnistm/mnistm_data.pkl'), 'rb'))
    mnistm_train = mnistm['train']
    mnistm_test = mnistm['test']
    mnistm_valid = mnistm['valid']

    # Compute pixel mean for normalizing data
    pixel_mean = np.vstack([mnist_train, mnistm_train]).mean((0, 1, 2))
    cfg.set(pixel_mean = pixel_mean)

    # 构造数据生成器
    train_source_datagen = batch_generator([mnist_train,mnist.train.labels],cfg.batch_size // 2)
    train_target_datagen = batch_generator([mnistm_train,mnist.train.labels],cfg.batch_size // 2)
    val_datagen = batch_generator([mnistm_test,mnist.test.labels],cfg.batch_size)

    # 初始化每个epoch的训练次数和每次验证过程的验证次数
    train_source_batch_num = int(len(mnist_train) // (cfg.batch_size // 2))
    train_target_batch_num = int(len(mnistm_train) // (cfg.batch_size // 2))
    train_iter_num = int(np.max([train_source_batch_num,train_target_batch_num]))
    val_iter_num = int(len(mnistm_test) / cfg.batch_size)

    # 初始化相关参数
    interval = 2  # 验证间隔
    train_num = len(mnist_train) +  len(mnistm_train)# 训练集样本数
    val_num = len(mnistm_test)     # 验证集样本数
    print("train on %d training samples with batch_size %d ,validation on %d val samples"
          % (train_num, cfg.batch_size, val_num))

    # 初始化DANN,并进行训练
    dann = MNIST2MNIST_M_DANN(cfg)
    #pre_model_path = os.path.abspath("./pre_model/trained_model.ckpt")
    pre_model_path = None
    dann.train(train_source_datagen,train_target_datagen,val_datagen,pixel_mean,
               interval,train_iter_num,val_iter_num,pre_model_path)

if __name__ == '__main__':
    run_main()

下面是训练过程中的相关tensorboard的相关指标在训练过程中的走势图。首先是训练误差的走势图,主要包括训练域分类误差、训练图像分类误差和训练总误差。


接下来是验证误差的走势图,主要包括验证域分类误差、验证图像分类误差和验证总误差。

然后是训练过程中学习率的走势图

最后是精度走势图,主要包括训练精度和测试精度。 其中训练精度是在源域数据集即MNIST数据集上的统计结果,验证精度是在目标域数据集即MNIST-M数据集上的统计结果。从图中可以看出,DANN在训练MNIST-M数据集时没有使用对应的标签,MNSIT-M数据集上的精度最终收敛到75.4%,效果相比于81.49%还有一定距离,但鉴于没有使用任何数据增强和dropout,这个结果可以接受。

公众号近期荐读:

GAN整整6年了!是时候要来捋捋了! 

新手指南综述 | GAN模型太多,不知道选哪儿个?

数百篇GAN论文已下载好!搭配一份生成对抗网络最新综述!

有点夸张、有点扭曲!速览这些GAN如何夸张漫画化人脸!

天降斯雨,于我却无!GAN用于去雨如何?

脸部转正!GAN能否让侧颜杀手、小猪佩奇真容无处遁形?

容颜渐失!GAN来预测?

强数据所难!SSL(半监督学习)结合GAN如何?

弱水三千,只取你标!AL(主动学习)结合GAN如何?

异常检测,GAN如何gan ?

虚拟换衣!速览这几篇最新论文咋做的!

脸部妆容迁移!速览几篇用GAN来做的论文

【1】GAN在医学图像上的生成,今如何?

01-GAN公式简明原理之铁甲小宝篇


GAN&CV交流群,无论小白还是大佬,诚挚邀您加入!

一起讨论交流!长按备注【进群】加入:

更多分享、长按关注本公众号:

  • 6
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值