对比学习 Contrast Learning

理论

在这里插入图片描述
监督学习:技术相对成熟,但是对海量的数据进行标记需要花费大量的时间和资源

无监督学习:自主地从大量数据中学习同类数据的相同特性,并将其编码为高级表征,再根据不同任务进行微调即可,节省时间以及硬件资源。

生成式学习
生成式学习以自编码器(例如GAN,VAE等等)这类方法为代表,由数据生成数据,使之在整体或者高级语义上与训练数据相近。

对比式学习
对比式学习着重于学习同类实例之间的共同特征,区分非同类实例之间的不同之处。

与生成式学习比较,对比式学习不需要关注实例上繁琐的细节,只需要在抽象语义级别的特征空间上学会对数据的区分即可,因此模型以及其优化变得更加简单,且泛化能力更强。

在这里插入图片描述


用聚类的思想来理解:
在这里插入图片描述
d ( f ( x ) , f ( x + ) ) ≪ d ( f ( x ) , f ( x − ) ) O R s ( f ( x ) , f ( x + ) ) ≫ s ( f ( x ) , f ( x − ) ) d(f(x),f(x^+))\ll d(f(x),f(x^-)) \\ OR \\ s(f(x),f(x^+))\gg s(f(x),f(x^-)) d(f(x),f(x+))d(f(x),f(x))ORs(f(x),f(x+))s(f(x),f(x))

  • 缩小类内的距离,扩大类外的距离

丈量二者距离:欧几里得距离,余弦相似度,马氏距离 …
目标:给定锚点,通过空间变换,使得锚点与正样本间距离尽可能小,与负样本距离尽可能大,这个应该是triptloss的思想


对比损失

W :网络权重; W :网络权重; W:网络权重;
Y : L a b e l , Y :Label, YLabel
Y = { 0 , X 1 , X 2 同类 1 , X 1 , X 2 不同类 Y= \begin{cases} 0,\quad X_1,X_2同类\\ 1, \quad X_1,X_2不同类 \end{cases}\\ Y={0,X1,X2同类1,X1,X2不同类

D W :是 X 1 与 X 2 在潜变量空间的欧几里德距离。 D_W :是 X_1 与 X_2 在潜变量空间的欧几里德距离。 DW:是X1X2在潜变量空间的欧几里德距离。

i :表示第 i 组向量对。 i :表示第i组向量对。 i:表示第i组向量对。

L :研究中常常在这里做文章,定义合理的能够完成最终目标的损失函数往往就成功了大半。 L :研究中常常在这里做文章,定义合理的能够完成最终目标的损失函数往往就成功了大半。 L:研究中常常在这里做文章,定义合理的能够完成最终目标的损失函数往往就成功了大半。
L ( W , ( Y , X 1 ⃗ , X 2 ⃗ ) i ) = ( 1 − Y ) L S ( D W i ( X 1 ⃗ , X 2 ⃗ ) ) + Y L D ( D W i ( X 1 ⃗ , X 2 ⃗ ) )   L ( W ) = ∑ i = 1 P L ( W , ( Y , X 1 ⃗ , X 2 ⃗ ) i ) L(W,(Y,\vec{X_1},\vec{X_2})^i)=(1-Y)L_S(D_W^i(\vec{X_1},\vec{X_2}))+YL_D(D_W^i(\vec{X_1},\vec{X_2}))\\\ L(W)=\sum^P_{i=1}L(W,(Y,\vec{X_1},\vec{X_2})^i) \\ L(W,(Y,X1 ,X2 )i)=(1Y)LS(DWi(X1 ,X2 ))+YLD(DWi(X1 ,X2 )) L(W)=i=1PL(W,(Y,X1 ,X2 )i)

正样本
当与锚点是正样本时,由于对比思想,二者之间会逐渐靠近。
原文将它假设成一个原长 l → 0 \rightarrow 0 0 的弹簧,那么就会将正样本无限的拉近,从而完成聚类。
F ⃗ = − x ⃗ 将锚点设为势能零点: E = 0 − ∫ F ⃗ d x ⃗ = 1 2 x 2 那么 E 即可作为 L S ,且满足定义要求: L S = 1 2 D W 2 \vec{F}=-\vec{x}\\ 将锚点设为势能零点: E=0-\int\vec{F}d\vec{x}=\frac 1 2 x^2\\ 那么 E 即可作为L_S ,且满足定义要求:L_S=\frac 1 2 D_W^2\\ F =x 将锚点设为势能零点:E=0F dx =21x2那么E即可作为LS,且满足定义要求:LS=21DW2

在这里插入图片描述
负样本

当与锚点是负样本时,由于对比思想,二者之间会逐渐原理。原文将它假设成一个原长 l → m \rightarrow m m 的弹簧,那么就会将负样本至少拉至m,从而完成划分。

F ⃗ = m ⃗ − x ⃗ 将锚点设为势能零点: E = 0 − ∫ F ⃗ d x ⃗ = 1 2 ( m − x ) 2 L D = 1 2 ( m a x { 0 , m − D W } ) 2 \vec{F}=\vec{m}-\vec{x}\\ 将锚点设为势能零点: E=0-\int\vec{F}d\vec{x}=\frac 1 2 (m-x)^2\\ L_D=\frac 1 2 (max\{0,m-D_W\})^2 F =m x 将锚点设为势能零点:E=0F dx =21(mx)2LD=21(max{0,mDW})2
在这里插入图片描述
原定义:
L ( W , Y , X 1 ⃗ , X 2 ⃗ ) = ( 1 − Y ) D W 2 + Y ⋅ 1 2 ( m a x { 0 , m − D W } ) 2 L(W,Y,\vec{X_1},\vec{X_2})=(1-Y)D_W^2+Y\cdot \frac 1 2 (max\{0,m-D_W\})^2\\ L(W,Y,X1 ,X2 )=(1Y)DW2+Y21(max{0,mDW})2
{ 当 Y = 0 ,调整参数最小化​ D W ( X 1 ⃗ , X 2 ⃗ ) 当 Y = 1 ,设二者向量最大距离为 m \begin{cases} 当Y=0,调整参数最小化​ D_W(\vec{X_1},\vec{X_2}) \\ 当Y=1,设二者向量最大距离为m \end{cases}\\ {Y=0,调整参数最小化DW(X1 ,X2 )Y=1,设二者向量最大距离为m
{ 如果​ D W ( X 1 ⃗ , X 2 ⃗ ) < m , 则增大两者距离到 m ; 如果​ D W ( X 1 ⃗ , X 2 ⃗ ) ≥ m ,则不做优化。 \begin{cases}如果​ D_W(\vec{X_1},\vec{X_2})<m , 则增大两者距离到m;\\ 如果​ D_W(\vec{X_1},\vec{X_2})\geq m ,则不做优化。\end{cases} {如果DW(X1 ,X2 )<m,则增大两者距离到m如果DW(X1 ,X2 )m,则不做优化。

效果就是:
在这里插入图片描述

Paper Waitting Read

一些常使用的Constrastive Loss

Triplet Loss:
L = m a x { d ( x , x + ) − d ( x , x − ) + α , 0 } L=max\{d(x,x^+)-d(x,x^-)+\alpha,0\}\\ L=max{d(x,x+)d(x,x)+α,0}

NCE Loss:
之前从向量空间考虑,NCE从概率角度考虑【原证明为贝叶斯派的证法】,NCE是对于得分函数的估计,那也就是说,是对于你空间距离分配的合理性进行估计。

总之NCE通过对比噪声样本与含噪样本,从而推断真实分布。

InfoNCE Loss 互信息:
I ( x , c ) = ∑ x ∑ c p ( x , c ) l o g p ( x , c ) p ( x ) p ( c ) = ∑ x , c p ( x , c ) l o g p ( x ∣ c ) p ( x ) I(x,c)=\sum_x\sum_c p(x,c)log\frac{p(x,c) }{p(x)p(c) } =\sum_{x,c}p(x,c)log\frac{p(x|c)}{p(x)}\\ I(x,c)=xcp(x,c)logp(x)p(c)p(x,c)=x,cp(x,c)logp(x)p(xc)

  1. 互信息上界估计:减少互信息,即VAE的目标。
  2. 互信息下界估计:增加互信息,即对比学习(CL)的目标。【后来也有CLUB上界估计和下界估计一起使用的对比学习。】

最关键的问题:如何构建正实例对和负实例对?

Paper

CPC

很多时候,很多数据维度高、label相对少,我们并不希望浪费掉没有label的那部分data。所以在label少的时候,可以利用无监督学习帮助我们学到数据本身的高级信息,从而对下游任务有很大的帮助。

Contrastive Predictive Coding(CPC) 这篇文章就提出以下方法:

  1. 将高维数据压缩到更紧凑的隐空间中,在其中条件预测更容易建模。
  2. 用自回归模型在隐空间中预测未来步骤。
  3. 依靠NCE来计算损失函数(和学习词嵌入方式类似),从而可以对整个模型进行端到端的训练。
  4. 对于多模态的数据有可以学到高级信息。

可以利用一定窗口内的 x t x_{t} xt x t + k x_{t+k} xt+k作为正实例对,并从输入序列之中随机采样一个输入作为 x t ∗ x_{t*} xt 负实例。

  • 随机采样作为负样本,这个思想很关键!!!
    在这里插入图片描述

给定声音序列上下文 c_t ,由此我们推断预测 x_{t+k} 位置上的声音信号。题目假设,声音序列全程伴随有噪音。

为了将噪音序列与声音序列尽可能的分离编码,这里就随机采样获得 x_{t*} 代替 x_{t+k} 位置信号,作为负样本进行对比学习。

  • 意思就是,原本t+k是正常的数据,但是这是个序列,t是一个窗口,所以在序列有正常的样本,也有异常的样本,但是拿到的数据一般是正常的数据多,异常的数据(噪声)少(但又非常的关键),那这样的话正负样本比例失调(不平衡)而且也学不到正常样本的本质
    - 所以,利用噪声的思想,把正常的样本加上随机的噪声作为负样本,这样来学正样本的规律和本质。也就是说,负样本我并不关心,它只是一个参照一个背景板,让模型去学正样本的本质规律

回到这个例子:
在这里插入图片描述

首先我们在原信号上选取一些时间窗口,对每一个窗口,通过encoder g e n c g_{enc} genc ,得到表示向量 z t z_t zt

z t z_t zt 通过自回归模型: g a r g_{ar} gar ,从而生成上下文隐变量 c t c_t ct

然后通过Bi-linear:

  • 采用 c t c_t ct z t + k z_{t+k} zt+k 从而能够压缩高维数据,并且计算 c t c_t ct z t + k z_{t+k} zt+k 的未来值是否符合

f k ( x t + k , c t ) = exp ⁡ ( z t + k T ( W k c t ) ) f_k(x_{t+k},c_t)=\exp(z^T_{t+k}(W_kc_t))\\ fk(xt+k,ct)=exp(zt+kT(Wkct))


SimCLR

A Simple Framework for Contrastive Learning of Visual Representations

simCLR背后的想法非常简单:

  • 视觉表征对于同一目标不同视角的输入都应具有不变性。

simCLR对输入的图片进行数据增强,以此来模拟图片不同视角下的输入。之后采用对比损失最大化相同目标在不同数据增强下的相似度,并最小化同类目标之间的相似度。

在这里插入图片描述

simCLR的架构由两个相同的网络模块组成。对于每一个输入网络的minibatch:

  1. 对mini batch中每张输入的图片进行两次随机数据增强(随机剪裁、滤镜、颜色过滤、灰度化等)来得到图片两种不同的视角;
  2. 将得到的两个表征送入两个卷积编码器(如resnet)获得抽象表示,之后对这些表示形式应用非线性变换进行投影得到投影表示;
  3. 使用余弦相似度来度量投影的相似度

文章参考链接(综述):https://zhuanlan.zhihu.com/p/346686467

Papar: https://zhuanlan.zhihu.com/p/363900943


代码实战:main_linear

  • Resnet + Classifier + CELoss
def set_model(opt):
    model = SupConResNet(name=opt.model)
    criterion = torch.nn.CrossEntropyLoss()

    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    ckpt = torch.load(opt.ckpt, map_location='cpu')
    state_dict = ckpt['model']

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        else:
            new_state_dict = {}
            for k, v in state_dict.items():
                k = k.replace("module.", "")
                new_state_dict[k] = v
            state_dict = new_state_dict
        model = model.cuda()
        classifier = classifier.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

        model.load_state_dict(state_dict)
    else:
        raise NotImplementedError('This code requires GPU')

    return model, classifier, criterion


def train(train_loader, model, classifier, criterion, optimizer, epoch, opt):
    """one epoch training"""
    model.eval()
    classifier.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss
        with torch.no_grad():
            features = model.encoder(images)
        output = classifier(features.detach())
        loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg

main_supcon

def set_model(opt):
    model = SupConResNet(name=opt.model)
    criterion = SupConLoss(temperature=opt.temp)

    # enable synchronized Batch Normalization
    if opt.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion


def train(train_loader, model, criterion, optimizer, epoch, opt):
    """one epoch training"""
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = torch.cat([images[0], images[1]], dim=0)
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss
        features = model(images)
        # 使用torch.cat函数将切分后的两个子特征f1和f2在第一个维度上进行拼接,即将它们作为两个
        # 通道(unsqueeze(1))拼接在一起,得到最终的特征features
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        
        if opt.method == 'SupCon':
            loss = criterion(features, labels)
        elif opt.method == 'SimCLR':
            loss = criterion(features)
        else:
            raise ValueError('contrastive method not supported: {}'.
                             format(opt.method))

        # update metric
        losses.update(loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg
  • main函数
def main():
    opt = parse_option()

    # build data loader
    train_loader = set_loader(opt)

    # build model and criterion
    model, criterion = set_model(opt)

    # build optimizer
    optimizer = set_optimizer(opt, model)

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # training routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        loss = train(train_loader, model, criterion, optimizer, epoch, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('loss', loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % opt.save_freq == 0:
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, save_file)

    # save the last model
    save_file = os.path.join(
        opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, save_file)

写的很好的utils


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    if args.warm and epoch <= args.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / \
            (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
def adjust_learning_rate(args, optimizer, epoch):
    lr = args.learning_rate
    if args.cosine:
        eta_min = lr * (args.lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
                1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        if steps > 0:
            lr = lr * (args.lr_decay_rate ** steps)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def set_optimizer(opt, model):
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    return optimizer


def save_model(model, optimizer, opt, epoch, save_file):
    print('==> Saving...')
    state = {
        'opt': opt,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, save_file)
    del state

Loss

对比损失:Supervised Contrastive Loss(监督对比损失)是一种在监督对比学习中使用的损失函数。它旨在学习既具有区分性又具有对同一类别内变化具有不变性的表示。

监督对比学习的目标是最大化正样本对(同一类别的样本)的一致性,并最小化负样本对(不同类别的样本)的一致性。监督对比损失通过鼓励正样本对的表示在嵌入空间中更加接近,同时将负样本对的表示推开来实现这一目标。

  • If both labels and mask are None, it degenerates to SimCLR unsupervised loss:
 Args:
     features: hidden vector of shape [bsz, n_views, ...].
     labels: ground truth of shape [bsz].
     mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
           has the same class as sample i. Can be asymmetric.
Returns:
     A loss scalar

class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

Net

model_dict = {
    'resnet18': [resnet18, 512],
    'resnet34': [resnet34, 512],
    'resnet50': [resnet50, 2048],
    'resnet101': [resnet101, 2048],
}

class SupConResNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, name='resnet50', head='mlp', feat_dim=128):
        super(SupConResNet, self).__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        feat = self.encoder(x)
        feat = F.normalize(self.head(feat), dim=1)
        return feat


class SupCEResNet(nn.Module):
    """encoder + classifier"""
    def __init__(self, name='resnet50', num_classes=10):
        super(SupCEResNet, self).__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        self.fc = nn.Linear(dim_in, num_classes)

    def forward(self, x):
        return self.fc(self.encoder(x))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值