创建memory bank(queue实现)工具类

# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
# This source code is licensed under the LICENSE file in the root directory of this source tree.

##Doh: Added modification to support mixup

import torch


class XBM:
    def __init__(self, args, device):
        self.K = args.xbm_per_class*args.num_classes*2 # We want to store a minimum number of samples per-class. x2 due to the augmented views
        #self.feats = torch.zeros(self.K, args.low_dim).to(device) # original
        #self.targets = torch.zeros(self.K, dtype=torch.long).to(device) # original
        self.feats = -1.0 * torch.ones(self.K, args.low_dim).to(device) # doh
        self.targets = -1.0 * torch.ones(self.K, dtype=torch.long).to(device) # doh

        self.ptr = 0

    @property
    def is_full(self):
        #return self.targets[-1].item() != 0 #original
        return self.targets[-1].item() != -1 #doh

    def get(self):
        if self.is_full:
            return self.feats, self.targets

        else:
            return self.feats[:self.ptr], self.targets[:self.ptr]


    def enqueue_dequeue(self, feats, targets):
        q_size = len(targets)

        if self.ptr + q_size > self.K:
            self.feats[-q_size:] = feats
            self.targets[-q_size:] = targets
            self.ptr = 0
        else:
            self.feats[self.ptr: self.ptr + q_size] = feats
            self.targets[self.ptr: self.ptr + q_size] = targets
            self.ptr += q_size

上面的为MOIT的实现方法,创建一下XBM工具类

使用该工具类能够生成一个memorybank,由队列实现,从而增大对比学习的正负样本数量,不会受限与batch-size。

moco的实现,通过直接在model中加入queue实现

class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """
    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值