Universal Domain Adaptation (Part 3): Prototypical Partial Optimal Transport for Universal Domain Adaptation

Preface

Introduction

Universal domain adaptation (UniDA) requires us, before reducing the domain gap, to distinguish the "known" samples in the two domains (i.e., samples whose labels exist in both domains) from the "unknown" samples (i.e., samples whose labels exist in only one domain).
This paper considers the problem from a distribution-matching perspective: only a partial alignment of the two distributions is needed. It proposes a new method, mini-batch Prototypical Partial Optimal Transport (m-PPOT), to perform partial distribution alignment for UniDA. In the training stage, besides minimizing m-PPOT, the transport plan of m-PPOT is used to reweight the source prototypes and target samples, and a reweighted entropy loss and a reweighted cross-entropy loss are designed to distinguish "known" from "unknown" samples.
Concretely, in the prototypical partial optimal transport (PPOT) method, distribution alignment in UniDA is modeled as a partial optimal transport (POT) problem: POT aligns a fraction of the data (mainly from the common classes) between the two domains. The paper designs a prototype-based POT, in which the source data are represented by prototypes in the POT formulation; it is further formulated as a mini-batch version, called m-PPOT.
The paper proves that POT can be bounded by m-PPOT plus the distances between source samples and their corresponding prototypes, which motivates using m-PPOT as a training loss for a deep UniDA model. Meanwhile, the transport plan of m-PPOT can be viewed as a matching matrix: its row sums and column sums are used to reweight the source prototypes and target samples, respectively, in order to distinguish "known" and "unknown" samples. Based on the m-PPOT transport plan, a reweighted cross-entropy loss on labeled source data and a reweighted entropy loss on target data are further designed to learn a transferable recognition model.
(Rows are the source prototypes; columns are the target features.)
Mini-batch OT aims to reduce the computational cost, making OT better suited to deep learning. We denote the empirical distributions of $b$ randomly drawn samples $\{x_i^s\}_{i=1}^{b}$ (resp. $\{x_j^t\}_{j=1}^{b}$) as $P_b(\mu)$ (resp. $P_b(\nu)$), where $b$ is the batch size and $k$ is the number of mini-batches. Mini-batch OT is defined as:
$$\mathrm{OT}_{b}^{k}(\mu,\nu)=\frac{1}{k}\sum_{i=1}^{k}\mathrm{OT}\!\left(P_b^{i}(\mu),\,P_b^{i}(\nu)\right)$$
Partial OT aims to transport only a mass $\alpha$ ($0\le\alpha\le\min(\|\mu\|_1,\|\nu\|_1)$) between $\mu$ and $\nu$ at minimal cost. Partial OT is defined as:
$$\mathrm{POT}_{\alpha}(\mu,\nu)=\min_{\pi\in\Pi_{\alpha}(\mu,\nu)}\langle C,\pi\rangle_F,$$

$$\Pi_{\alpha}(\mu,\nu)=\left\{\pi\in\mathbb{R}_{+}^{m\times n}\ \middle|\ \pi\mathbf{1}_n\le\mu,\ \pi^{\top}\mathbf{1}_m\le\nu,\ \mathbf{1}_m^{\top}\pi\mathbf{1}_n=\alpha\right\}$$
(The row sum adds up the entries of one row of the plan, and the column sum adds up one column. Rows correspond to source prototypes; columns correspond to target features.)
Figure 1: Illustration of our model. The source and target data share the same feature extractor, which embeds the data into a feature space. PPOT matches the target features with the source prototypes (updated from the source features), and uses the row/column sums of the transport plan for reweighting. A reweighted entropy loss is designed to align the common-class features of the two domains while excluding "unknown" features.

Method

UniDA aims to label each target sample with a label from the source class set $C_s$, or to mark it as "unknown". We denote the number of source classes by $L=|C_s|$. Our deep recognition model consists of two modules: a feature extractor $f$ mapping an input $x$ to a feature $z$, and an $L$-way classification head $h$. The source and target empirical distributions in the feature space are denoted respectively as:
$$\bar p=\frac{1}{m}\sum_{i=1}^{m}\delta_{f(x_i^s)},\qquad \bar q=\frac{1}{n}\sum_{j=1}^{n}\delta_{f(x_j^t)}$$
With a slight abuse of notation, we also write the data mass vectors as $\bar p$ and $\bar q$:
$$\bar p=\frac{1}{m}\mathbf{1}_m,\qquad \bar q=\frac{1}{n}\mathbf{1}_n$$
Modeling UniDA as Partial OT
Directly aligning the source distribution $\bar p$ and the target distribution $\bar q$ would cause data mismatch, since "private" samples exist in both domains. For the UniDA task, we first decompose $\bar p$ and $\bar q$ as
$$\bar p=\beta\,p^{c}+(1-\beta)\,p^{p},\qquad \bar q=\alpha\,q^{c}+(1-\alpha)\,q^{p}$$
where $p^{p}$ (resp. $q^{p}$) denotes the distribution of source (resp. target) private-class data in the feature space, $p^{c}$ and $q^{c}$ denote the common-class data distributions of the source and target domains, and $\beta$ and $\alpha$ denote the proportions of common-class samples in the source and target domains, respectively. Our goal is to minimize the discrepancy between $p^{c}$ and $q^{c}$, which is an OT problem:
($\alpha$ is the proportion of common-class samples in the target domain.)
$$\mathrm{OT}(p^{c},q^{c})=\min_{\pi\in\Pi(p^{c},\,q^{c})}\langle C,\pi\rangle_F\qquad(4)$$
Partial transport between the two domains will preferentially transport their common-class samples. We therefore approximately solve Eqn. (4) by optimizing Eqn. (5):
$$\mathrm{POT}_{\alpha}\!\left(\tfrac{\alpha}{\beta}\,\bar p,\ \bar q\right)\qquad(5)$$
where the coefficient $\alpha/\beta$ ensures that the common-class samples in $\bar p$ and $\bar q$ carry equal mass. We denote $(\alpha/\beta)\cdot\bar p$ and $\bar q$ by $p$ and $q$, respectively.
Prototypical Partial Optimal Transport
We have thus turned the distribution alignment between $p^{c}$ and $q^{c}$ into the partial OT problem of Eqn. (5).
[equation images omitted: the PPOT formulation, in which the source distribution is replaced by a weighted sum of Dirac masses on the $L$ source prototypes]
(Rows are the source prototypes; columns are the target samples.)
[equation image omitted]
m-PPOT is defined as:
$$\text{m-PPOT}_{\alpha}(p,q)=\frac{1}{k}\sum_{i=1}^{k}\mathrm{POT}_{\alpha}\!\left(c,\ q_{\Gamma_i}\right)\qquad(7)$$
where $\Gamma$ is a subset of $\{1,2,\dots,n\}$, i.e., an index set of the target data.
Theorem 1. Consider two distributions $p$ and $q$. Denote by $d_i=d(f(x_i^s),c_{y_i})$ the distance between $f(x_i^s)$ and its corresponding prototype $c_{y_i}$, and by $w=(w_1,w_2,\dots,w_L)^{\top}$ the row sums of the optimal transport plan of $\mathrm{PPOT}_{\alpha}(p,q)$. Then we have
[equation image omitted: the bound of Theorem 1, which bounds POT by m-PPOT plus a weighted sum of the source-to-prototype distances $d_i$]

UniDA Based on m-PPOT

Our motivation is to minimize the discrepancy between the source and target common-class data distributions, while separating the "known" and "unknown" data of the two domains during training. We design the following training losses.
**m-PPOT Loss.** Based on Theorem 1, to minimize the discrepancy between $p^{c}$ and $q^{c}$, we first design the m-PPOT loss to minimize the second term of the bound in Theorem 1. We introduce $\text{m-PPOT}_{\alpha}^{B}(p,q)$ as the loss:
$$\mathcal{L}_{ot}=\text{m-PPOT}_{\alpha}(p,q)$$
With mini-batch based optimization, by Eqn. (7) this term can be approximated by the partial OT problem $\mathrm{POT}_{\alpha}(c,q_{B_i})$ on each mini-batch. The prototype set $c$ is updated by an exponential moving average.
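The exponential-moving-average prototype update can be sketched as below (a minimal re-implementation consistent with the training code later in this post; `tau` is the update ratio):

```python
import torch
import torch.nn.functional as F

def ema_update_prototypes(protos: torch.Tensor, feats: torch.Tensor,
                          labels: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    # EMA of per-class feature means; classes absent from the batch keep their old prototype
    new = protos.clone()
    for c in range(protos.size(0)):
        mask = labels == c
        if mask.any():
            new[c] = (1 - tau) * protos[c] + tau * feats[mask].mean(dim=0)
    # keep prototypes on the unit sphere, since the features are L2-normalized
    return F.normalize(new, p=2, dim=-1)

protos = torch.eye(2)                # two toy prototypes
feats = torch.tensor([[0.0, 1.0]])   # one sample of class 1
updated = ema_update_prototypes(protos, feats, torch.tensor([1]))
```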
**Reweighted Entropy Loss.** We further design an entropy-based loss on the target data to improve prediction certainty. The solution $\pi^{*}$ of $\text{m-PPOT}_{\alpha}^{B}(p,q)$ is a matrix measuring the matching between source prototypes and target features. Since the more easily a prototype (feature) is transported, the more likely it belongs to a common class (a "known" sample), we use the row/column sums of $\pi^{*}$ as indicators for identifying unknown samples. Specifically, we first take the column sums of $\pi^{*}$ and multiply them by the constant $n/\alpha$, so that $w^{t}\in\mathbb{R}^{n}$ satisfies $\lVert w^{t}\rVert_1=n$. The reweighted entropy loss is:
$$\mathcal{L}_{ent}=\frac{1}{n}\sum_{j=1}^{n} w_j^{t}\,H\!\left(h(f(x_j^{t}))\right)$$

where $H(\cdot)$ denotes the Shannon entropy of the predicted class distribution.
We use this loss to increase the prediction confidence of the target samples regarded as "known".
In addition, we suppress overconfident predictions on target "unknown" samples through the loss:
$$\mathcal{L}_{ne}=\frac{1}{n}\sum_{j=1}^{n} w_j^{u}\,H\!\left(h(f(x_j^{t}))\right),\qquad w_j^{u}=1-w_j^{t}$$
A higher $w_j^{u}$ for a sample means higher confidence that it is "unknown". We therefore use $\mathcal{L}_{ne}$ to reduce the prediction confidence of samples that are likely "unknown".
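A minimal sketch of the two reweighted entropy terms (the weights here are random stand-ins for the plan's column sums; the paper's coefficients and normalizations are omitted):

```python
import torch

def weighted_entropy(probs: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # probs: (n, L) softmax outputs for target samples; weight: (n,) per-sample weights
    ent = -(probs * torch.log(probs + 1e-8)).sum(dim=1)  # Shannon entropy per sample
    return (weight * ent).mean()

probs = torch.softmax(torch.randn(8, 5), dim=1)
w_t = torch.rand(8)   # stand-in for "known" weights from the plan's column sums
w_u = 1.0 - w_t       # stand-in for "unknown" weights
loss = weighted_entropy(probs, w_t) - weighted_entropy(probs, w_u)  # L_ent - L_ne, up to coefficients
```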

**Reweighted Cross-Entropy Loss.** This loss is a classification loss defined on the source domain, based on cross-entropy with the source labels. Unlike the standard classification loss, we use the row sums of $\pi^{*}$ to compute weights $w^{s}\in\mathbb{R}^{L}$ measuring the confidence that each source prototype is "known". The reweighted cross-entropy loss is then:
$$\mathcal{L}_{ce}=\frac{1}{m}\sum_{i=1}^{m} w_{y_i}^{s}\,\ell_{ce}\!\left(h(f(x_i^{s})),\,y_i\right)$$
The weights satisfy $\lVert w^{s}\rVert_1=L$, and each weight represents the likelihood that the corresponding class is a common class. The cross-entropy based loss minimizes the distance from features to class prototypes. This means the reweighted cross-entropy loss approximately minimizes the first term of the bound in Theorem 1, where we use the row sums $w^{s}$ of m-PPOT to approximate the row sums $w$ of PPOT, and adopt class-balanced sampling in the implementation to enforce that all $r_j$ are equal.
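The reweighted cross-entropy loss maps directly onto the `weight` argument of `F.cross_entropy`; a sketch with made-up weights:

```python
import torch
import torch.nn.functional as F

L = 5                                 # number of source classes
logits = torch.randn(8, L)            # toy source predictions
labels = torch.randint(0, L, (8,))    # toy source labels

w_s = torch.rand(L)                   # stand-in for the plan's row sums
w_s = w_s * L / w_s.sum()             # normalize so that ||w_s||_1 = L

# classes that received little transport mass contribute less to the loss
loss = F.cross_entropy(logits, labels, weight=w_s)
```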
[equation images omitted]
The training procedure is shown in Figure 1.
(1) First, we map the data of both domains into the feature space through the shared feature extractor, and update the source prototypes from the source features of each batch.
(2) We then compute m-PPOT between the source prototypes and the empirical distribution of target samples, and use the row and column sums of the corresponding transport plan to reweight the losses.
$\mathcal{L}_{ot}$ aims to narrow the gap between the distributions of "known" samples in the two domains. **$\mathcal{L}_{ent}$ strengthens the prediction confidence of "known" target samples by decreasing their entropy, while $\mathcal{L}_{ne}$ increases the entropy of "unknown" target samples to suppress overconfident predictions.** Since the classifier is learned on source data, this aligns "known" target data with the source data distribution while pushing "unknown" target data away from it.
Hyperparameters
In practice, $\alpha$ and $\beta$ can hardly be computed exactly, so we propose a method to approximate them.
We introduce two scalars $\tau_1$ and $\tau_2$, with $\tau_1\in(0,1]$ and $\tau_2>0$. To simplify notation, we write $s(x)=\max\sigma(h\circ f(x))$ for the prediction confidence of $x$. We define $\alpha$ and $\beta$ as:
$$\alpha=\frac{\left|\{\,j:\ s(x_j^{t})>\tau_1\,\}\right|}{n},\qquad \beta=\frac{\left|\{\,l:\ w_l^{s}>\tau_2\,\}\right|}{L}\qquad(15)$$
The motivation is that the fraction of high-confidence samples estimates the proportion of "known" samples in the target domain, and similarly the fraction of high-weight classes approximates the proportion of common classes in the source domain.
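The $\alpha$ estimate of Eqn. (15) is just the fraction of target samples whose top softmax score clears $\tau_1$; a minimal sketch:

```python
import torch

def estimate_alpha(probs: torch.Tensor, tau1: float = 0.9) -> float:
    # fraction of target samples whose top softmax confidence reaches tau1
    conf = probs.max(dim=1).values
    return (conf >= tau1).float().mean().item()

# toy predictions: 2 of the 4 samples are confident
probs = torch.tensor([[0.95, 0.05], [0.6, 0.4], [0.91, 0.09], [0.5, 0.5]])
alpha = estimate_alpha(probs)  # -> 0.5
```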
In the experiments we set $\tau_1=0.9$ and $\tau_2=1$. At the $i$-th iteration of the training stage, we first compute an estimate $\hat\alpha_i$ with Eqn. (15) and update $\alpha$ by an exponential moving average:
$$\alpha_i=(1-\lambda_1)\,\alpha_{i-1}+\lambda_1\,\hat\alpha_i$$
We then use $\alpha_i$ as the transport ratio and $\alpha_i/\beta_{i-1}$ as the coefficient in Eqn. (5) to compute $\mathcal{L}_{ot}$ and, as a by-product, $w^{s}$. Next we compute an estimate $\hat\beta_i$ with Eqn. (15) and update $\beta_i$ in the same way as $\alpha_i$:
($\beta$ is the proportion of common-class samples in the source domain.)
$$\beta_i=(1-\lambda_2)\,\beta_{i-1}+\lambda_2\,\hat\beta_i\qquad(16)$$
where $\lambda_1,\lambda_2\in[0,1]$ are set to 0.001 in our experiments.
Moreover, to reduce the possible error of recognizing "known" samples as "unknown", we keep only the fraction of $\{w_j^{u}\}_{j=1}^{n}$ with the largest values and set the rest to 0. The fraction is set to 25% in all tasks.
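Keeping only the largest 25% of the unknown weights can be sketched as follows (`mask_unknown_weights` is an illustrative helper, not a function from the repository):

```python
import torch

def mask_unknown_weights(w_u: torch.Tensor, frac: float = 0.25) -> torch.Tensor:
    # keep only the largest `frac` of the unknown weights, zero out the rest
    k = max(1, round(frac * w_u.numel()))
    masked = torch.zeros_like(w_u)
    idx = w_u.topk(k).indices
    masked[idx] = w_u[idx]
    return masked

w = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.05, 0.3, 0.7, 0.15])
kept = mask_unknown_weights(w)  # only the top-2 of 8 weights survive
```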

Experiment

Hardware: an RTX A6000 GPU (48 GB memory).
Following previous works (Saito and Saenko 2021; Chen et al. 2022), we use ResNet50 (He et al. 2016) without the last fully connected layer as our feature extractor. A 256-dimensional bottleneck layer and the prediction head h are appended after the feature extractor.
We contrastively pre-train our feature extractor with MoCo v2 for 100 epochs, with batch size 256 and learning rate 0.03.
In the training stage, we optimize the model with Nesterov momentum SGD, with momentum 0.9 and weight decay 5×10⁻⁴. Following (Ganin and Lempitsky 2015), the learning rate decays by the factor (1 + αt)^(−β), where t changes linearly from 0 to 1 during training, and we set α = 10 (lr_gamma) and β = 0.75 (lr_decay).
The batch size is set to 72 in all experiments, except for the DomainNet tasks, where it is changed to 256. We train the model for 5 epochs (1000 iterations each) and fully update the source prototypes and α before each epoch.
The initial learning rate is set to 1×10⁻⁴ on Office-31, 5×10⁻⁴ on Office-Home and VisDA, and 0.01 on DomainNet.
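The learning-rate schedule above can be sketched as a plain decay-factor function (here with the α = 10, β = 0.75 from the text; the training code wraps an equivalent lambda in `LambdaLR`):

```python
def lr_factor(step: int, total_steps: int, a: float = 10.0, b: float = 0.75) -> float:
    # multiplicative decay (1 + a*t)^(-b), with t growing linearly from 0 to 1
    t = step / total_steps
    return (1.0 + a * t) ** (-b)

start = lr_factor(0, 5000)    # 1.0 at the beginning of training
end = lr_factor(5000, 5000)   # (1 + 10)^(-0.75) at the end
```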

code

train.py

import argparse
import copy
import os
import random
import sys
import time
import ot
import numpy as np
import torch
import torch.multiprocessing
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transform
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from utils import datasets
from utils.meter import AverageMeter, ProgressMeter
from utils.sampler import BalancedBatchSampler
from utils.logger import TextLogger
from utils import h_score, entropy_loss, get_prototypes
from modules.resnet import Res50
from moco.moco import train_moco

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.multiprocessing.set_sharing_strategy('file_system')


def main(args: argparse.Namespace):
    # create logger
    now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
    filename = os.path.join('log/', "{}2{}-{}.txt".format(args.source, args.target, now))
    logger = TextLogger(filename)
    sys.stdout = logger
    sys.stderr = logger

    if args.task == 'office31':
        args.common_class = 10
        args.source_private_class = 10
        args.target_private_class = 11
        args.moco_k = 1024
    if args.task == 'VisDA2017':
        args.common_class = 6
        args.source_private_class = 3
        args.target_private_class = 3
        args.moco_k = 65536
    if args.task == 'officehome':
        args.common_class = 10
        args.source_private_class = 5
        args.target_private_class = 50
        args.moco_k = 3072
    if args.task == 'DomainNet':
        args.batch_size = 256
        args.common_class = 150
        args.source_private_class = 50
        args.target_private_class = 145
        args.moco_k = 65536
    print(args)

    # create seed
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.deterministic = True
    cudnn.enabled = True

    # create transform
    train_transform = transform.Compose([
        transform.Resize(256),
        transform.RandomResizedCrop(224),
        transform.RandomHorizontalFlip(),
        transform.ToTensor(),
        transform.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transform.Compose([
        transform.Resize(256),
        transform.CenterCrop(224),
        transform.ToTensor(),
        transform.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # create data set
    source_dataset = datasets.Datasets(args, source=True, transform=train_transform)
    if args.balanced:
        source_label = source_dataset.create_label_set()
        train_batch_sampler = BalancedBatchSampler(source_label, batch_size=args.batch_size)
        source_loader = DataLoader(source_dataset, batch_sampler=train_batch_sampler, num_workers=args.num_workers)
    else:
        source_loader = DataLoader(dataset=source_dataset, batch_size=args.source_batch_size, shuffle=True,
                                   num_workers=args.num_workers, drop_last=True)

    target_dataset = datasets.Datasets(args, source=False, transform=train_transform)
    target_loader = DataLoader(dataset=target_dataset, batch_size=args.batch_size, shuffle=True,
                               num_workers=args.num_workers, drop_last=True)

    val_dataset = datasets.Datasets(args, source=False, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # create model
    num_class = args.common_class + args.source_private_class
    if args.checkpoint:
        model = Res50(num_class=num_class, checkpoint=args.checkpoint).to(device)
        print("use checkpoint")
    elif args.no_ssl:
        model = Res50(num_class=num_class).to(device)
    else:
        train_moco(args)
        model = Res50(num_class=num_class,
                      checkpoint='checkpoint/{}2{}_{:04d}.pth.tar'.format(args.source, args.target, args.moco_epochs)
                      ).to(device)

    # create optimizer
    all_parameters = model.get_parameters()
    optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=10 * args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # train source domain 1000 iterations to fine-tune the model
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.3f')
    progress = ProgressMeter(
        args.pre_step,
        [batch_time, data_time, losses, top1],
        prefix="Epoch: [{}]".format(-1))

    model.train()
    if args.new_opt:
        # use an independent optimizer to fine-tune (optional)
        new_optimizer = SGD(all_parameters, args.pre_lr, momentum=args.momentum, weight_decay=args.weight_decay,
                            nesterov=True)
        new_lr_scheduler = LambdaLR(new_optimizer,
                                    lambda x: args.pre_lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    source_iter = iter(source_loader)
    end = time.time()
    for step in range(args.pre_step):
        try:
            source_data = next(source_iter)
        except StopIteration:
            source_iter = iter(source_loader)
            source_data = next(source_iter)
        s_img, s_label = source_data
        s_img = s_img.to(device)
        s_label = s_label.to(device)
        data_time.update(time.time() - end)
        s_prediction, s_feature = model(s_img)

        loss = F.cross_entropy(s_prediction, s_label)

        if args.new_opt:
            new_optimizer.zero_grad()
            loss.backward()
            new_optimizer.step()
            new_lr_scheduler.step()
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

        batch_time.update(time.time() - end)
        end = time.time()

        with torch.no_grad():
            pred = torch.argmax(s_prediction, dim=1)
            correct = pred.eq(s_label).sum().item()
            total_correct = correct
            total_samples = s_label.size(0)
            acc = total_correct / total_samples
            top1.update(acc, s_img.size(0))
            losses.update(loss, s_img.size(0))
            if (step + 1) % 100 == 0:
                progress.display(step + 1)

    # start training
    best_h_score = 0.
    best_unknown_acc = 0.
    best_known_acc = 0.

    validate(val_loader, model, epoch=-1)

    # init source prototypes, alpha, beta and class weight
    s_feat, s_label = get_features(source_loader, model)
    s_protos = get_prototypes(s_feat, s_label, args)
    alpha = update_alpha(target_loader, model)
    beta = alpha
    class_weight = torch.ones(num_class)

    for epoch in range(args.epochs):
        # train one epoch
        class_weight, beta = train(source_loader, target_loader, s_protos, class_weight, model, optimizer,
                                   lr_scheduler, alpha, beta, epoch)

        # update source prototypes and alpha
        s_feat, s_label = get_features(source_loader, model)
        s_protos = get_prototypes(s_feat, s_label, args)
        alpha = update_alpha(target_loader, model)

        # evaluate on validation set
        h_scores, k_acc, u_acc = validate(val_loader, model, epoch)

        # remember best acc@1 and save checkpoint
        if h_scores * 100 > best_h_score:
            best_h_score = h_scores * 100
            best_unknown_acc = u_acc * 100
            best_known_acc = k_acc * 100

    print("best H-score = {:3.1f}".format(best_h_score))
    print("best unknown accuracy = {:3.1f}".format(best_unknown_acc))
    print("best known accuracy = {:3.1f}".format(best_known_acc))

def train(source_loader: DataLoader, target_loader: DataLoader, source_prototype: torch.Tensor,
          class_weight: torch.Tensor, model: Res50, optimizer: SGD, lr_scheduler: LambdaLR,
          alpha: np.ndarray, beta: np.ndarray, epoch: int) -> [torch.Tensor, np.ndarray]:

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    cls_losses = AverageMeter('C_Loss', ':.2e')
    ot_losses = AverageMeter('OT_Loss', ':.2e')
    progress = ProgressMeter(
        args.step,
        [batch_time, cls_losses, ot_losses],
        prefix="Epoch: [{}]".format(epoch))
    num_classes = args.common_class + args.source_private_class
    model.train()

    source_data_iter = iter(source_loader)
    target_data_iter = iter(target_loader)
    end = time.time()
    for step in range(args.step):
        try:
            source_data = next(source_data_iter)
        except StopIteration:
            source_data_iter = iter(source_loader)
            source_data = next(source_data_iter)
        try:
            target_data = next(target_data_iter)
        except StopIteration:
            target_data_iter = iter(target_loader)
            target_data = next(target_data_iter)

        s_img, s_label = source_data
        t_img, _ = target_data
        s_img, t_img = s_img.to(device), t_img.to(device)
        s_label = s_label.to(device)
        class_weight = class_weight.to(device)
        data_time.update(time.time() - end)
        s_prediction, s_feature = model(s_img)
        s_feature = F.normalize(s_feature, p=2, dim=-1)
        _, t_feature = model(t_img)

        # freeze head parameters
        head = copy.deepcopy(model.head)
        for params in head.parameters():
            params.requires_grad = False
        t_prediction = F.softmax(head(t_feature), dim=1)
        conf, pred = t_prediction.max(dim=1)
        t_feature = F.normalize(t_feature, p=2, dim=-1)
        batch_size = t_feature.shape[0]

        # update alpha by moving average
        alpha = (1 - args.alpha) * alpha + args.alpha * (conf >= args.tau1).sum().item() / conf.size(0)

        # get alpha / beta
        match = alpha / beta

        # update source prototype by moving average
        source_prototype = source_prototype.data.to(device)
        batch_source_prototype = torch.zeros_like(source_prototype).to(device)
        for i in range(num_classes):
            if (s_label == i).sum().item() > 0:
                batch_source_prototype[i] = (s_feature[s_label == i].mean(dim=0))
            else:
                batch_source_prototype[i] = (source_prototype[i])
        source_prototype = (1 - args.tau) * source_prototype + args.tau * batch_source_prototype
        source_prototype = F.normalize(source_prototype, p=2, dim=-1)

        # get ot loss
        a, b = match * ot.unif(num_classes), ot.unif(batch_size)
        m = torch.cdist(source_prototype, t_feature) ** 2
        m_max = m.max().detach()
        m = m / m_max
        pi, log = ot.partial.entropic_partial_wasserstein(a, b, m.detach().cpu().numpy(), reg=args.reg, m=alpha,
                                                          stopThr=1e-10, log=True)
        pi = torch.from_numpy(pi).float().to(device)
        ot_loss = torch.sqrt(torch.sum(pi * m) * m_max)
        loss = args.ot * ot_loss

        # update class weight and target weight by plan pi
        plan = pi * batch_size
        k = round(args.neg * batch_size)
        min_dist, _ = torch.min(m, dim=0)
        _, indicate = min_dist.topk(k=k, dim=0)
        batch_class_weight = torch.tensor([plan[i, :].sum() for i in range(num_classes)]).to(device)
        class_weight = args.tau * batch_class_weight + (1 - args.tau) * class_weight
        class_weight = class_weight * num_classes / class_weight.sum()
        k_weight = torch.tensor([plan[:, i].sum() for i in range(batch_size)]).to(device)
        k_weight /= alpha
        u_weight = torch.zeros(batch_size).to(device)
        u_weight[indicate] = 1 - k_weight[indicate]

        # update beta
        beta = args.beta * (class_weight > args.tau2).sum().item() / num_classes + (1 - args.beta) * beta

        # get classification loss
        cls_loss = F.cross_entropy(s_prediction, s_label, weight=class_weight.float())
        loss += cls_loss

        # get entropy loss
        p_ent_loss = args.p_entropy * entropy_loss(t_prediction, k_weight)
        n_ent_loss = args.n_entropy * entropy_loss(t_prediction, u_weight)
        ent_loss = p_ent_loss - n_ent_loss
        loss += ent_loss

        # compute gradient
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - end)
        end = time.time()

        with torch.no_grad():
            cls_losses.update(cls_loss, batch_size)
            ot_losses.update(args.ot * ot_loss, batch_size)

            if (step + 1) % args.interval == 0:
                progress.display(step + 1)
    return class_weight, beta

def validate(val_loader: DataLoader, model: Res50, epoch: int) -> [np.ndarray]:
    data_time = AverageMeter('Time', ':6.3f')
    k_top1 = AverageMeter('Known', ':6.3f')
    u_top1 = AverageMeter('Unknown', ':6.3f')
    h_sco = AverageMeter('H-score', ':6.3f')
    progress1 = ProgressMeter(
        len(val_loader),
        [data_time, k_top1, u_top1, h_sco],
        prefix="Epoch: [{}]".format(epoch))

    model.eval()
    total_correct = 0
    total_samples = 0
    total_unknown_correct = 0
    total_unknown_samples = 0
    unknown_label = args.common_class + args.source_private_class
    end = time.time()
    with torch.no_grad():
        for step, (image, label) in enumerate(val_loader):
            image = image.to(device)
            label = label.to(device)
            output, feature = model(image)
            softmax = nn.Softmax(dim=1)
            output = softmax(output)

            # discriminate unknown sample by confidence
            confidence, pred = output.max(dim=1)
            pred[confidence < args.threshold] = unknown_label
            correct = pred.eq(label).sum().item()
            unknown_correct = pred[confidence < args.threshold].eq(label[confidence < args.threshold]).sum().item()
            unknown_sample = (label == unknown_label).sum().item()
            total_correct += correct
            total_unknown_correct += unknown_correct
            total_samples += label.shape[0]
            total_unknown_samples += unknown_sample

        # compute H-score and accuracy
        known_acc = (total_correct - total_unknown_correct) / (total_samples - total_unknown_samples)
        unknown_acc = total_unknown_correct / total_unknown_samples
        h_scores = h_score(known_acc, unknown_acc)
        data_time.update(time.time() - end)
        k_top1.update(known_acc)
        u_top1.update(unknown_acc)
        h_sco.update(h_scores)
        progress1.display(len(val_loader))
    return h_scores, known_acc, unknown_acc

def update_alpha(target_loader: DataLoader, model: Res50) -> np.ndarray:
    num_conf, num_sample = 0, 0
    model.eval()

    with torch.no_grad():
        for _, (img, _) in enumerate(target_loader):
            img = img.to(device)
            output, _ = model(img)
            output = F.softmax(output, dim=1)
            conf, _ = output.max(dim=1)
            num_conf += torch.sum(conf > args.tau1).item()
            num_sample += output.shape[0]

        alpha = num_conf / num_sample
        alpha = np.around(alpha, decimals=2)
    return alpha

def get_features(data_loader: DataLoader, model: Res50) -> [torch.Tensor, torch.Tensor]:
    feature_set = []
    label_set = []
    model.eval()
    with torch.no_grad():
        for _, (img, gt) in enumerate(data_loader):
            img = img.to(device)
            _, feature = model(img)
            feature_set.append(feature)
            label_set.append(gt)
        feature_set = torch.cat(feature_set, dim=0)
        feature_set = F.normalize(feature_set, p=2, dim=-1)
        label_set = torch.cat(label_set, dim=0)
    return feature_set, label_set

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PPOT for Universal Domain Adaptation')
    parser.add_argument('--root', default='/path/to/your/dataset/',
                        help='root of data file')
    parser.add_argument('--task', default='officehome',
                        help='task name')
    parser.add_argument('-s', '--source', default='Art',
                        help='source domain')
    parser.add_argument('-t', '--target', default='Clipart',
                        help='target domain')
    parser.add_argument('--common-class', default=10, type=int,
                        help='number of common class')
    parser.add_argument('--source-private-class', default=5, type=int,
                        help='number of source private class')
    parser.add_argument('--target-private-class', default=50, type=int,
                        help='number of target private class')
    parser.add_argument('-b', '--batch-size', default=72, type=int,
                        help='mini-batch size')
    parser.add_argument('-n', '--num-workers', default=4, type=int,
                        help='number of data loading workers')
    parser.add_argument('--lr', default=0.001, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('-wd', '--weight-decay', default=0.001, type=float,
                        help='weight decay')
    parser.add_argument('--lr-gamma', default=0.001, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--reg', default=0.01, type=float,
                        help='regularization term of partial entropy optimal transport')
    parser.add_argument('--epochs', default=5, type=int,
                        help='number of training epochs')
    parser.add_argument('--step', default=1000, type=int,
                        help='number of iterations per epoch')
    parser.add_argument('--interval', default=100, type=int,
                        help='print frequency')
    parser.add_argument('--threshold', default=0.75, type=float,
                        help='confidence threshold of known samples')
    parser.add_argument('--pre-step', default=1000, type=int,
                        help='number of iterations in fine-tune step')
    parser.add_argument('--pre-lr', default=0.0005, type=float,
                        help='initial learning rate in fine-tune step, only work when use --new-opt')
    parser.add_argument('--p-entropy', default=0.01, type=float,
                        help='hyper-parameter of positive entropy loss')
    parser.add_argument('--n-entropy', default=2, type=float,
                        help='hyper-parameter of negative entropy loss')
    parser.add_argument('--ot', default=5, type=float,
                        help='hyper-parameter of ot loss')
    parser.add_argument('--neg', default=0.25, type=float,
                        help='ratio of samples in target domain to compute negative entropy loss')
    parser.add_argument('--seed', default=1024, type=int,
                        help='seed for initializing training.')
    parser.add_argument('--checkpoint', default='',
                        help='root of network checkpoint')
    parser.add_argument('--tau1', default=0.9, type=float,
                        help='threshold of high confidence in updating alpha')
    parser.add_argument('--tau2', default=1, type=float,
                        help='threshold of known class in updating beta')
    parser.add_argument('--tau', default=0.1, type=float,
                        help='update ratio of source prototype')
    parser.add_argument('--alpha', default=0.001, type=float,
                        help='update ratio of alpha')
    parser.add_argument('--beta', default=0.01, type=float,
                        help='update ratio of beta')
    parser.add_argument('--new-opt', action='store_true',
                        help='use new optimizer in fine-tune step')
    parser.add_argument('--balanced', action='store_true',
                        help='use balanced batch sampler in our experiment')
    parser.add_argument('--no-ssl', action='store_true',
                        help='if you do not want to pre-train network by self-supervised learning')

    # moco configs
    parser.add_argument('--moco-epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--mlr', '--moco-learning-rate', default=0.03, type=float,
                        metavar='LR', help='initial learning rate', dest='mlr')
    parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
                        help='learning rate schedule (when to drop lr by 10x)')
    parser.add_argument('--mwd', '--moco-weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--print-freq', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')
    parser.add_argument('--multiprocessing-distributed', action='store_true', default='True',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')

    # moco specific configs:
    parser.add_argument('--moco-dim', default=128, type=int,
                        help='feature dimension (default: 128)')
    parser.add_argument('--moco-k', default=2304, type=int,
                        help='queue size; number of negative keys (default: 65536)')
    parser.add_argument('--moco-m', default=0.999, type=float,
                        help='moco momentum of updating key encoder (default: 0.999)')
    parser.add_argument('--moco-t', default=0.2, type=float,
                        help='softmax temperature (default: 0.07)')

    # options for moco v2
    parser.add_argument('--mlp', action='store_true', default='True',
                        help='use mlp head')
    parser.add_argument('--aug-plus', action='store_true', default='True',
                        help='use moco v2 data augmentation')
    parser.add_argument('--cos', action='store_true', default='True',
                        help='use cosine lr schedule')
    parser.add_argument('--freq', default=200, type=int,
                        metavar='N', help='save frequency (default: 10)')
    parser.add_argument('--moco-batch-size', default=256, type=int)

    args = parser.parse_args()
    main(args)

utils/__init__.py

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from modules.resnet import Res50
import numpy as np

def get_prototypes(feature_set: torch.Tensor, label_set: torch.Tensor, args) -> torch.Tensor:
    class_set = [i for i in range(args.common_class + args.source_private_class)]
    source_prototype = torch.zeros(len(class_set), 256)
    for i in class_set:
        source_prototype[i] = feature_set[label_set == i].sum(0) / feature_set[label_set == i].size(0)
    return source_prototype


def h_score(acc_known: float, acc_unknown: float) -> float:
    h_scores = 2 * acc_known * acc_unknown / (acc_known + acc_unknown)
    return h_scores


def entropy_loss(prediction: torch.Tensor, weight=torch.zeros(1)):
    if weight.size(0) == 1:
        entropy = torch.sum(-prediction * torch.log(prediction + 1e-8), 1)
        entropy = torch.mean(entropy)
    else:
        entropy = torch.sum(-prediction * torch.log(prediction + 1e-8), 1)
        entropy = torch.mean(weight * entropy)
    return entropy
