# OIM loss, PyTorch 1.7

from __future__ import absolute_import

import torch
import torch.nn.functional as F
from torch import nn, autograd


# class OIM(autograd.Function):
#     def __init__(self, lut, momentum=0.5):
#         super(OIM, self).__init__()
#         self.lut = lut  # torch.Size([625, 128])
#         self.momentum = momentum
#
#     def forward(self, inputs, targets):
#         self.save_for_backward(inputs, targets)   # inputs: torch.Size([64, 128])
#         outputs = inputs.mm(self.lut.t())  # (64, 128) * (128, 625)
#         return outputs  # torch.Size([64, 625])
#
#     def backward(self, grad_outputs):
#         inputs, targets = self.saved_tensors
#         grad_inputs = None
#         if self.needs_input_grad[0]:
#             grad_inputs = grad_outputs.mm(self.lut)
#         for x, y in zip(inputs, targets):
#             self.lut[y] = self.momentum * self.lut[y] + (1. - self.momentum) * x
#             self.lut[y] /= self.lut[y].norm()
#         return grad_inputs, None

class OIM(autograd.Function):
    """Online Instance Matching (OIM) autograd function.

    Forward computes similarity scores between input features and a lookup
    table (LUT) of per-class centroids. Backward propagates the gradient to
    the inputs and, as a deliberate side effect, updates the LUT entries of
    the classes seen in the batch with a momentum moving average.
    """

    @staticmethod
    def forward(ctx, inputs, targets, lut, momentum=0.5):
        """Return ``inputs @ lut.T``.

        Args:
            inputs: feature matrix, shape (batch, num_features),
                e.g. torch.Size([64, 128]).
            targets: class index per row of ``inputs``, shape (batch,).
            lut: class-centroid table, shape (num_classes, num_features),
                e.g. torch.Size([625, 128]); mutated in place during backward.
            momentum: moving-average factor for the LUT update.

        Returns:
            Score matrix of shape (batch, num_classes).
        """
        ctx.lut = lut
        ctx.momentum = momentum
        ctx.save_for_backward(inputs, targets)
        return inputs.mm(ctx.lut.t())  # (batch, F) @ (F, C) -> (batch, C)

    @staticmethod
    def backward(ctx, grad_outputs):
        inputs, targets = ctx.saved_tensors
        grad_inputs = None
        if ctx.needs_input_grad[0]:
            grad_inputs = grad_outputs.mm(ctx.lut)
        # Momentum update of the LUT rows for the labels in this batch,
        # followed by L2 re-normalization (intentional side effect).
        for x, y in zip(inputs, targets):
            ctx.lut[y] = ctx.momentum * ctx.lut[y] + (1. - ctx.momentum) * x
            ctx.lut[y] /= ctx.lut[y].norm()
        # forward() took 4 inputs (inputs, targets, lut, momentum), so exactly
        # 4 gradients must be returned. The original returned 5, which raises
        # a RuntimeError the first time .backward() runs.
        return grad_inputs, None, None, None


def oim(inputs, targets, lut, momentum=0.5):
    """Functional wrapper around :class:`OIM`.

    ``autograd.Function.apply`` does not accept keyword arguments, so
    ``momentum`` is forwarded positionally. It is passed as a plain float:
    the original wrapped it in ``torch.Tensor([momentum]).to(inputs.device)``,
    which allocates a tensor and forces a device transfer on every call for
    no benefit — ``apply`` accepts non-tensor scalars directly.
    """
    return OIM.apply(inputs, targets, lut, momentum)


class OIMLoss(nn.Module):
    """OIM loss: scaled cross-entropy over similarities to a class LUT.

    ``self.lut`` is a non-trainable buffer holding one feature centroid per
    class; it is updated in place by the OIM autograd function during the
    backward pass rather than by the optimizer.
    """

    def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
                 weight=None, size_average=True):
        """
        Args:
            num_features: feature dimension of each LUT row (e.g. 512).
            num_classes: number of identities / LUT rows (e.g. 625).
            scalar: temperature multiplier applied to the logits (e.g. 30).
            momentum: moving-average factor for the in-place LUT update.
            weight: optional per-class weight passed to ``F.cross_entropy``.
            size_average: kept for API compatibility; not used below.
        """
        super(OIMLoss, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.momentum = momentum
        self.scalar = scalar
        self.weight = weight
        self.size_average = size_average
        self.register_buffer('lut', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        """Compute the loss and return ``(loss, logits, lut)``.

        Args:
            inputs: feature matrix, shape (batch, num_features).
            targets: class indices, shape (batch,).
        """
        projected = oim(inputs, targets, self.lut, momentum=self.momentum)
        # Scale out of place: the original did ``inputs *= self.scalar``,
        # mutating the output of a custom autograd Function in place (a
        # version-counter hazard) and shadowing the ``inputs`` parameter.
        logits = projected * self.scalar
        loss = F.cross_entropy(logits, targets, weight=self.weight)
        return loss, logits, self.lut

 
