Reproducing the Cross-Modal Retrieval Model DADH with the PaddlePaddle Framework

This article originates from a featured project in the AI Studio community.


Reproducing DADH with the PaddlePaddle Framework

  • Reference PyTorch implementation: https://github.com/Zjut-MultimediaPlus/DADH

Introduction

  1. Overview
  • Cross-modal hashing has attracted wide attention in cross-modal retrieval tasks thanks to its high retrieval efficiency and low storage cost. However, most existing cross-modal hashing methods learn binary codes directly from multimedia data and cannot fully exploit the semantic knowledge in the data. Moreover, they fail to learn ranking-based similarity correlations for data points with multiple labels, and they usually relax the binary constraints on the hash codes, which introduces non-negligible quantization loss during optimization. The paper proposes a hashing method called Deep Adversarial Discrete Hashing (DADH) to address these problems in cross-modal retrieval. The method uses adversarial training to learn cross-modal features and to enforce distribution consistency of the feature representations across modalities. A weighted cosine triplet constraint is introduced that fully exploits the semantic knowledge carried by multi-labels to guarantee precise ranking correlations between item pairs. In addition, a discrete hashing strategy learns the binary codes without relaxation, which preserves the labels' semantic knowledge in the hash codes while minimizing quantization loss. Ablation and comparison experiments on two cross-modal databases show that DADH improves performance and outperforms several then state-of-the-art cross-modal retrieval hashing methods.
  2. Source
  • ICMR 2020 paper "Deep Adversarial Discrete Hashing for Cross-Modal Retrieval".
  3. Design highlights
  • Adversarial learning is introduced in both feature learning and hash learning, which better aligns the distributions of the feature representations and the hash codes across modalities.
  • A strategy is proposed for learning binary codes without relaxation while minimizing quantization loss, and a hash network with a weighted cosine triplet margin constraint is designed that not only preserves multi-label semantic information but also guarantees high-quality ranking similarity relations between cross-modal instances.
  • The method is evaluated on two widely used multi-label databases and outperforms several state-of-the-art methods; ablation experiments further show that the adversarial learning and the discrete hash learning strategy significantly improve retrieval performance.
  4. Network architecture
  • DADH is an end-to-end framework composed of two parts: (1) a feature learning module consisting of two separate deep networks for images and text; the two networks act as generators that try to confuse a discriminator, which attempts to distinguish the modality of each feature. (2) A hash learning module that also introduces adversarial learning: two generators produce hash codes from the two modalities, and a discriminator tries to distinguish the modality of each hash code. Finally, the hash codes are solved with a discrete hash learning strategy.
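To make the weighted cosine triplet idea concrete, here is a minimal numpy sketch of an (unweighted) cosine triplet hinge; the function name and this simplified form are illustrative and are not the repo's TripletLoss API:

```python
import numpy as np

def cosine_triplet_loss(anchor, positive, negative, margin=0.3):
    """Hinge loss on cosine similarity: the anchor should be at least
    `margin` more similar to the positive than to the negative."""
    def cos(a, b):
        return np.sum(a * b, axis=-1) / (
            np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))
    return float(np.maximum(0.0, margin - cos(anchor, positive)
                            + cos(anchor, negative)).sum())
```

In DADH the triplets are built across modalities (e.g. an image code as anchor with text codes as positive/negative), and each triplet is weighted by label overlap; the repo's TripletLoss adds that weighting on top of a hinge like this.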


Code

Preparation

  • Fix the random seeds and define the helper functions
import paddle
import os
import numpy as np
import random
from paddle.io import DataLoader
from tqdm import tqdm
from datasets.dataset import Dataset
from config import opt
from models.dis_model import DIS
from models.gen_model import GEN
from triplet_loss import *
from paddle.optimizer import Adam
from utils import calc_map_k
from datasets.data_handler import load_data, load_pretrain_model
import time
import pickle
import h5py

def load_flickr25k(path):
    file = h5py.File(path, 'r')
    images = file['images'][:].transpose(1,0)
    images = (images - images.mean()) / images.std()
    labels = file['LAll'][:].transpose(1,0)
    tags = file['YAll'][:].transpose(1,0)

    file.close()
    return images, tags, labels
    
def setup_seed(seed):
    paddle.seed(seed)
    np.random.seed(seed)
    random.seed(seed)
setup_seed(20)

def generate_img_code(model, test_dataloader, num):
    B = paddle.zeros([num, opt.bit]) 

    for i, input_data in enumerate(test_dataloader()):
        input_data = input_data[0] 
        b = model.generate_img_code(input_data)
        idx_end = min(num, (i + 1) * opt.batch_size)
        B[i * opt.batch_size: idx_end, :] = b.numpy()

    B = paddle.sign(B)
    return B

def generate_txt_code(model, test_dataloader, num):
    B = paddle.zeros([num, opt.bit]) 

    for i, input_data in enumerate(test_dataloader()):
        input_data = input_data[0]
        b = model.generate_txt_code(input_data)
        idx_end = min(num, (i + 1) * opt.batch_size)
        B[i * opt.batch_size: idx_end, :] = b.numpy()

    B = paddle.sign(B)
    return B

def load_model(model, path):
    if path is not None:
        # match save_model: parameters are stored as <path>.pdparams
        model.set_state_dict(paddle.load(path + '.pdparams'))

def save_model(model):
    path = 'checkpoint/' + opt.dataset + '_' + str(opt.bit)
    # paddle.save(optimizer.state_dict(), "adam.pdopt")
    paddle.save(model.state_dict(),path+'.pdparams')

def valid(model, x_query_dataloader, x_db_dataloader, y_query_dataloader, y_db_dataloader,
          query_labels, db_labels):
    model.eval()

    qBX = generate_img_code(model, x_query_dataloader, opt.query_size)
    qBY = generate_txt_code(model, y_query_dataloader, opt.query_size)
    rBX = generate_img_code(model, x_db_dataloader, opt.db_size)
    rBY = generate_txt_code(model, y_db_dataloader, opt.db_size)

    mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
    mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)

    model.train()
    return mapi2t.item(), mapt2i.item()

def test():

    if opt.device is None or opt.device == 'cpu':
        paddle.device.set_device("cpu")
    else:
        paddle.device.set_device("gpu")

    pretrain_model = load_pretrain_model(opt.pretrain_model_path)

    generator = GEN(opt.dropout, opt.image_dim, opt.text_dim, opt.hidden_dim, opt.bit, pretrain_model=pretrain_model) 
    

    path = 'checkpoint/' + opt.dataset + '_' + str(opt.bit)
    load_model(generator, path)

    generator.eval()

    images, tags, labels = load_data(opt.data_path, opt.dataset)

    i_query_data = Dataset(opt, images, tags, labels, test='image.query')
    i_db_data = Dataset(opt, images, tags, labels, test='image.db')
    t_query_data = Dataset(opt, images, tags, labels, test='text.query')
    t_db_data = Dataset(opt, images, tags, labels, test='text.db')

    i_query_dataloader = DataLoader(i_query_data, opt.batch_size, shuffle=False)
    i_db_dataloader = DataLoader(i_db_data, opt.batch_size, shuffle=False)
    t_query_dataloader = DataLoader(t_query_data, opt.batch_size, shuffle=False)
    t_db_dataloader = DataLoader(t_db_data, opt.batch_size, shuffle=False)

    qBX = generate_img_code(generator, i_query_dataloader, opt.query_size)
    qBY = generate_txt_code(generator, t_query_dataloader, opt.query_size)
    rBX = generate_img_code(generator, i_db_dataloader, opt.db_size)
    rBY = generate_txt_code(generator, t_db_dataloader, opt.db_size)

    query_labels, db_labels = i_query_data.get_labels()

    mapi2t = calc_map_k(qBX, rBY, query_labels, db_labels)
    mapt2i = calc_map_k(qBY, rBX, query_labels, db_labels)
    print('...test MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))
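The evaluation above relies on calc_map_k from the repo's utils. As a reference point, here is a minimal numpy sketch of mean average precision over Hamming-ranked results; it is illustrative only and ignores the top-k cutoff that calc_map_k supports:

```python
import numpy as np

def mean_average_precision(q_codes, r_codes, q_labels, r_labels):
    """MAP over all queries: rank the database by Hamming distance;
    a database item is relevant if it shares at least one label."""
    n_bits = q_codes.shape[1]
    aps = []
    for q, ql in zip(q_codes, q_labels):
        # Hamming distance between +/-1 codes: (bits - inner product) / 2
        dist = (n_bits - r_codes @ q) / 2
        order = np.argsort(dist, kind="stable")
        relevant = (r_labels[order] @ ql) > 0
        hits = np.cumsum(relevant)
        if hits[-1] == 0:
            continue  # no relevant items for this query
        ranks = np.arange(1, len(relevant) + 1)
        aps.append(np.sum((hits / ranks) * relevant) / hits[-1])
    return float(np.mean(aps))
```

This is why the codes are passed through paddle.sign before evaluation: with ±1 codes, Hamming distance reduces to a cheap inner product.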

Loading the dataset

  • The flickr25k dataset has been preprocessed with the VGG-F network and is stored in .mat format. It contains images (20015 images in total, each already converted by VGG-F into a 4096-dimensional feature vector), LAll (labels), and YAll (tags).
  • Wrap the splits into DataLoaders.
import paddle
import warnings
warnings.filterwarnings('ignore')
opt.device = 'gpu' 
opt.vis_env = None 
pretrain_model = None

opt.beta = opt.beta + 0.1

if opt.device is None or opt.device == 'cpu':
    paddle.device.set_device("cpu")
else:
    paddle.device.set_device("gpu")

opt.data_path = r"/home/aistudio/data/data117641/data.mat"
images, tags, labels = load_flickr25k(opt.data_path)

train_data = Dataset(opt, images, tags, labels)
train_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)
L = train_data.get_labels()
# test
i_query_data = Dataset(opt, images, tags, labels, test='image.query')
i_db_data = Dataset(opt, images, tags, labels, test='image.db')
t_query_data = Dataset(opt, images, tags, labels, test='text.query')
t_db_data = Dataset(opt, images, tags, labels, test='text.db')

i_query_dataloader= DataLoader(i_query_data,
                    batch_size=opt.batch_size,
                    shuffle=False,drop_last=True,
                    num_workers=2)
i_db_dataloader= DataLoader(i_db_data,
                    batch_size=opt.batch_size,
                    shuffle=False,drop_last=True,
                    num_workers=2)
t_query_dataloader= DataLoader(t_query_data,
                    batch_size=opt.batch_size,
                    shuffle=False,drop_last=True,
                    num_workers=2)
t_db_dataloader= DataLoader(t_db_data,
                    batch_size=opt.batch_size,
                    shuffle=False,drop_last=True,
                    num_workers=2)
query_labels, db_labels = i_query_data.get_labels()

generator = GEN(opt.dropout, opt.image_dim, opt.text_dim, opt.hidden_dim,
                opt.bit, pretrain_model=pretrain_model)

discriminator = DIS(opt.hidden_dim//4, opt.hidden_dim //8, opt.bit)

optimizer = Adam(parameters=generator.parameters(), learning_rate=opt.lr, weight_decay=0.0005)

optimizer_dis = {
    'feature': Adam(parameters=discriminator.feature_dis.parameters(), learning_rate=opt.lr,  beta1=0.5, beta2=0.9, weight_decay=0.0001),
    'hash': Adam(parameters=discriminator.hash_dis.parameters(), learning_rate=opt.lr, beta1=0.5, beta2=0.9, weight_decay=0.0001)
}

tri_loss = TripletLoss(opt, reduction='sum')

loss = []

max_mapi2t = 0.
max_mapt2i = 0.
max_average = 0.

mapt2i_list = []
mapi2t_list = []
train_times = []

B_i = paddle.randn((opt.training_size, opt.bit)).sign()
B_t = B_i
H_i = paddle.zeros((opt.training_size, opt.bit))
H_t = paddle.zeros((opt.training_size, opt.bit))

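B_i, H_i, and H_t initialized above feed the epoch-level discrete update performed later in the training loop: a ridge-regression solve for the label projection P, P = (LᵀL + λI)⁻¹LᵀB, followed by re-binarization B = sign(LP + 0.5μ(H_i + H_t)). A numpy sketch of that closed form (function name and shapes are illustrative):

```python
import numpy as np

def discrete_hash_update(L, B, H_i, H_t, lamb=1.0, mu=1.0):
    """One epoch-level discrete update:
    P  = (L^T L + lamb*I)^-1 L^T B   (ridge regression of codes on labels)
    B' = sign(L @ P + 0.5 * mu * (H_i + H_t))"""
    num_label = L.shape[1]
    P = np.linalg.solve(L.T @ L + lamb * np.eye(num_label), L.T @ B)
    return np.sign(L @ P + 0.5 * mu * (H_i + H_t))
```

Because the sign is applied directly, the codes stay discrete throughout training rather than being relaxed to real values, which is the point of the discrete hashing strategy.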
Training

  • DADH consists of two parts:

(1) A feature learning module composed of two separate deep networks for images and text. The two networks act as generators that try to confuse a discriminator, which attempts to distinguish the modality of each feature.

(2) A hash learning module that also introduces adversarial learning: two generators produce hash codes from the two modalities, while a discriminator tries to distinguish the modality of each hash code. Finally, the hash codes are solved with a discrete hash learning strategy.

  • Note that the adversarial learning in the hash learning stage only updates the parameters of the hash network.
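The discriminators in the training loop below are regularized with a WGAN-GP style gradient penalty evaluated at points interpolated between image and text features. For intuition, here is a numpy sketch using a linear critic D(x) = x·w, whose input gradient is w everywhere, so the penalty can be computed without autodiff (all names here are illustrative, not the repo's API):

```python
import numpy as np

def gradient_penalty_linear(w, real, fake, lam=10.0, seed=0):
    """Penalty pushing the critic's input-gradient norm toward 1 at
    points interpolated between real (image) and fake (text) features."""
    rng = np.random.default_rng(seed)
    alpha = rng.random((real.shape[0], 1))
    x_hat = alpha * real + (1 - alpha) * fake   # random interpolates
    # For D(x) = x @ w the input gradient at every x_hat is just w.
    grads = np.broadcast_to(w, x_hat.shape)
    norms = np.linalg.norm(grads, axis=1)
    return float(lam * np.mean((norms - 1.0) ** 2))
```

In the actual loop, paddle.grad computes the gradient of the nonlinear discriminator at the interpolates; the penalty formula, and the factor 10, are the same.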
from paddle.nn.functional import sigmoid
import warnings
warnings.filterwarnings('ignore')
opt.max_epoch = 150 
# opt.batch_size = 256
for epoch in range(opt.max_epoch):
    t1 = time.time()
    e_loss = 0
    for i, (ind, img, txt, label) in tqdm(enumerate(train_dataloader)):
        imgs = img
        txt = txt
        labels = label

        batch_size = len(ind)

        h_i, h_t, f_i, f_t = generator(imgs, txt)
        H_i[ind, :] = h_i.numpy()
        H_t[ind, :] = h_t.numpy()
        h_t_detach = generator.generate_txt_code(txt)

        #####
        # train feature discriminator---------------------------------------------
        #####
        D_real_feature = discriminator.dis_feature(f_i.detach())
        
        D_real_feature = -opt.gamma * \
            paddle.log(sigmoid(D_real_feature)).mean() 
        # D_real_feature = -D_real_feature.mean()
        optimizer_dis['feature'].clear_grad()
        D_real_feature.backward() 

        # train with fake
        D_fake_feature = discriminator.dis_feature(f_t.detach())
        D_fake_feature = -opt.gamma * \
            paddle.log(paddle.ones([batch_size]) -
                      sigmoid(D_fake_feature)).mean()
        # D_fake_feature = D_fake_feature.mean()
        D_fake_feature.backward()

        # train with gradient penalty
        alpha = paddle.rand((batch_size, opt.hidden_dim//4))
        
        
        interpolates = alpha * f_i.detach() + (1 - alpha) * f_t.detach()
        interpolates = paddle.to_tensor(interpolates,stop_gradient=False)
        disc_interpolates = discriminator.dis_feature(interpolates)
        gradients = paddle.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=paddle.ones(disc_interpolates.shape),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.reshape([gradients.shape[0], -1])
        # 10 is gradient penalty hyperparameter
        feature_gradient_penalty = (
            (gradients.norm(2, axis=1) - 1) ** 2).mean() * 10
        feature_gradient_penalty.backward()

        optimizer_dis['feature'].step()

        #####
        # train hash discriminator---------------------------------------------
        #####
        D_real_hash = discriminator.dis_hash(h_i.detach())
        D_real_hash = -opt.gamma * paddle.log((sigmoid(D_real_hash))).mean()
        optimizer_dis['hash'].clear_grad()
        D_real_hash.backward()

        # train with fake
        D_fake_hash = discriminator.dis_hash(h_t.detach())
        D_fake_hash = -opt.gamma * \
            paddle.log(paddle.ones([batch_size]) -
                      sigmoid(D_fake_hash)).mean()
        D_fake_hash.backward()

        # train with gradient penalty
        alpha = paddle.rand([batch_size, opt.bit])
        interpolates = alpha * h_i.detach() + (1 - alpha) * h_t.detach()
        interpolates= paddle.to_tensor(interpolates,stop_gradient=False)
        disc_interpolates = discriminator.dis_hash(interpolates)
        gradients = paddle.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=paddle.ones(disc_interpolates.shape),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.reshape([gradients.shape[0], -1])

        hash_gradient_penalty = (
            (gradients.norm(2, axis=1) - 1) ** 2).mean() * 10
        hash_gradient_penalty.backward()

        optimizer_dis['hash'].step()

        
        loss_G_txt_feature = - \
            paddle.log(sigmoid(discriminator.dis_feature(f_t))).mean()
        loss_adver_feature = loss_G_txt_feature

        loss_G_txt_hash = - \
            paddle.log(sigmoid(discriminator.dis_hash(h_t_detach))).mean()
        loss_adver_hash = loss_G_txt_hash

        tri_i2t = tri_loss(h_i, labels, target=h_t, margin=opt.margin)
        tri_t2i = tri_loss(h_t, labels, target=h_i, margin=opt.margin)
        weighted_cos_tri = tri_i2t + tri_t2i

        i_ql = paddle.sum(paddle.pow(B_i[ind] - h_i, 2))
        t_ql = paddle.sum(paddle.pow(B_i[ind] - h_t, 2))
        loss_quant = i_ql + t_ql
        err = opt.alpha * weighted_cos_tri + \
            opt.beta * loss_quant + opt.gamma * \
            (loss_adver_feature + loss_adver_hash)
        # print(err)

        optimizer.clear_grad()
        err.backward()
        optimizer.step()

        e_loss = err + e_loss

    P_i = paddle.inverse(
        L.t() @ L + opt.lamb * paddle.eye(opt.num_label)) @ L.t() @ B_i

    B_i = (L @ P_i + 0.5 * opt.mu * (H_i + H_t)).sign()
    # B_t = (L @ P_t + opt.mu * H_t).sign()
    loss.append(e_loss.item())
    print('...epoch: %3d, loss: %3.3f' % (epoch + 1, loss[-1]))
    delta_t = time.time() - t1

    if opt.vis_env:
        vis.plot('loss', loss[-1])

    # validate
    if opt.valid and (epoch + 1) % opt.valid_freq == 0:
        mapi2t, mapt2i = valid(generator, i_query_dataloader, i_db_dataloader, t_query_dataloader, t_db_dataloader,
                               query_labels, db_labels)
        print('...epoch: %3d, valid MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' %
              (epoch + 1, mapi2t, mapt2i))

        mapi2t_list.append(mapi2t)
        mapt2i_list.append(mapt2i)
        train_times.append(delta_t)

        if 0.5 * (mapi2t + mapt2i) > max_average:
            max_mapi2t = mapi2t
            max_mapt2i = mapt2i
            max_average = 0.5 * (mapi2t + mapt2i)
            save_model(generator)

        if opt.vis_env:
            vis.plot('mapi2t', mapi2t)
            vis.plot('mapt2i', mapt2i)

    if epoch % 100 == 0:
        optimizer.set_lr(max(optimizer.get_lr() * 0.8, 1e-6))
        

if not opt.valid:
    save_model(generator)

print('...training procedure finish')
if opt.valid:
    print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' %
          (max_mapi2t, max_mapt2i))
else:
    mapi2t, mapt2i = valid(generator, i_query_dataloader, i_db_dataloader, t_query_dataloader, t_db_dataloader,
                           query_labels, db_labels)
    print('   max MAP: MAP(i->t): %3.4f, MAP(t->i): %3.4f' % (mapi2t, mapt2i))

path = 'checkpoint/' + opt.dataset + '_epoch_' + str(opt.max_epoch)+'_' + str(opt.bit) 
with open(path + '_result.pkl', 'wb') as f:
    pickle.dump([train_times, mapi2t_list, mapt2i_list], f)

Training results

  • The model was trained for 150 epochs; the results are shown below.
  • The reproduced DADH meets the target on the flickr25k dataset with 64-bit hash codes: image->tag 81.79%, tag->image 80.64%.
import matplotlib.pyplot as plt
%matplotlib inline
plt.plot(mapi2t_list,label='I->T')
plt.plot(mapt2i_list,label='T->I')
plt.legend()
plt.show()


Reproduction summary

Preprocessing the images into feature vectors with a fixed-parameter VGG-F model ahead of time, and feeding those vectors to the model as input, is a good practice from the paper: it drastically reduces image-processing time.

However, this approach may discard some important image features, and it makes the model's accuracy depend on the features produced by the upstream model.
