论文代码复现
代码结构
Architectures
AttnClassifier.py
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
class Classifier(nn.Module):
    """Few-shot open-set classification head.

    Holds the base-class weights and semantic embeddings, refines support
    prototypes with a ``SupportCalibrator``, produces a negative ("fake")
    class prototype with an ``OpenSetGenerater`` and scores query/open-set
    features against the prototypes with ``Metric_Cosine``.
    """

    def __init__(self, args, feat_dim, param_seman, train_weight_base=False):
        """Build the head.

        Args:
            args: namespace providing ``n_ways``, ``base_seman_calib``,
                ``neg_gen_type`` and ``agg``.
            feat_dim: dimensionality of the visual features.
            param_seman: ``(params, seman_dict)`` pair consumed by
                ``init_representation``.
            train_weight_base: whether the base-class weights are trainable.
        """
        super(Classifier, self).__init__()
        # Weight & bias for the base classes.
        self.train_weight_base = train_weight_base
        self.init_representation(param_seman)
        if train_weight_base:
            print('Enable training base class weights')
        self.calibrator = SupportCalibrator(
            nway=args.n_ways, feat_dim=feat_dim, n_head=1,
            base_seman_calib=args.base_seman_calib,
            neg_gen_type=args.neg_gen_type)
        self.open_generator = OpenSetGenerater(
            args.n_ways, feat_dim, n_head=1,
            neg_gen_type=args.neg_gen_type, agg=args.agg)
        self.metric = Metric_Cosine()

    def forward(self, feature, cls_ids, test=False):
        """Score query and open-set features for one batch of episodes.

        Args:
            feature: ``(support_feat, query_feat, openset_feat)``.
            cls_ids: ``(supp_ids, base_ids)``.
            test: unused here; kept for interface compatibility.

        Returns:
            ``(test_cosine_scores, supp_protos, fakeclass_protos,
            (base_weights, base_wgtmem), funit_distance)``.
        """
        ## bs: feature[0].size(0)
        ## support_feat: bs*nway*nshot*D
        ## query_feat: bs*(nway*nquery)*D
        ## base_ids: bs*54
        (support_feat, query_feat, openset_feat) = feature
        nb, _, _, ndim = support_feat.size()
        (supp_ids, base_ids) = cls_ids
        base_weights, base_wgtmem, base_seman, support_seman = self.get_representation(supp_ids, base_ids)
        # Average over the shot axis to get one feature per class.
        support_feat = torch.mean(support_feat, dim=2)
        supp_protos, support_attn = self.calibrator(support_feat, base_weights, support_seman, base_seman)

        # Modified code: calibrate the query features chunk by chunk and
        # average the results into one extra prototype.
        # The calibrator reshapes its input assuming ``nway`` rows per
        # episode, so each chunk must hold exactly that many rows (the
        # original hard-coded 5).
        chunk = self.calibrator.nway
        n = query_feat.size()[1]
        sup_list = []
        for i in range(0, n, chunk):
            supp_fk = query_feat[:, i:i + chunk, :].contiguous()
            ss, _ = self.calibrator(supp_fk, base_weights, support_seman, base_seman)
            sup_list.append(ss)
        suppfake_protos = torch.cat(sup_list, dim=1)
        suppfake_protos = torch.mean(suppfake_protos, dim=1).view(nb, -1, ndim)
        # NOTE(review): new_supp_protos carries one more prototype than
        # supp_protos; confirm that open_generator and the semantic inputs
        # are meant to handle that extra row.
        new_supp_protos = torch.cat([supp_protos, suppfake_protos], dim=1)

        fakeclass_protos, recip_unit = self.open_generator(new_supp_protos, base_weights, support_seman, base_seman)
        cls_protos = torch.cat([supp_protos, fakeclass_protos], dim=1)

        query_cls_scores = self.metric(cls_protos, query_feat)
        openset_cls_scores = self.metric(cls_protos, openset_feat)
        test_cosine_scores = (query_cls_scores, openset_cls_scores)

        # Distance of every query / open-set feature to each generated unit.
        query_funit_distance = 1.0 - self.metric(recip_unit, query_feat)
        qopen_funit_distance = 1.0 - self.metric(recip_unit, openset_feat)
        funit_distance = torch.cat([query_funit_distance, qopen_funit_distance], dim=1)

        return test_cosine_scores, supp_protos, fakeclass_protos, (base_weights, base_wgtmem), funit_distance

    def init_representation(self, param_seman):
        """Register base-class weights/bias and the semantic embeddings.

        ``param_seman`` is ``(params, seman_dict)``: ``params`` must contain
        ``'cls_classifier.weight'`` and ``'cls_classifier.bias'``;
        ``seman_dict`` maps split names to NumPy arrays.
        """
        (params, seman_dict) = param_seman
        self.weight_base = nn.Parameter(params['cls_classifier.weight'], requires_grad=self.train_weight_base)
        self.bias_base = nn.Parameter(params['cls_classifier.bias'], requires_grad=self.train_weight_base)
        # Frozen copy of the initial base weights.
        self.weight_mem = nn.Parameter(params['cls_classifier.weight'].clone(), requires_grad=False)
        self.seman = {
            k: nn.Parameter(torch.from_numpy(v), requires_grad=False).float().cuda()
            for k, v in seman_dict.items()
        }

    def get_representation(self, cls_ids, base_ids, randpick=False):
        """Gather base weights and semantic vectors for an episode batch.

        With ``base_ids`` given, base entries are indexed by it and support
        semantics come from ``self.seman['base']``; otherwise all base
        entries are repeated per episode and support semantics come from
        ``self.seman['novel_test']``.

        Returns:
            ``(base_weights, base_wgtmem, base_seman, supp_seman)``.
        """
        if base_ids is not None:
            base_weights = self.weight_base[base_ids, :]
            base_wgtmem = self.weight_mem[base_ids, :]
            base_seman = self.seman['base'][base_ids, :]
            supp_seman = self.seman['base'][cls_ids, :]
        else:
            bs = cls_ids.size(0)
            base_weights = self.weight_base.repeat(bs, 1, 1)
            base_wgtmem = self.weight_mem.repeat(bs, 1, 1)
            base_seman = self.seman['base'].repeat(bs, 1, 1)
            supp_seman = self.seman['novel_test'][cls_ids, :]
        if randpick:
            num_base = base_weights.shape[1]
            # NOTE(review): ``self.base_size`` is not assigned anywhere in
            # the visible code; it must be set before randpick=True is used.
            base_size = self.base_size
            idx = np.random.choice(list(range(num_base)), size=base_size, replace=False)
            base_weights = base_weights[:, idx, :]
            base_seman = base_seman[:, idx, :]
        return base_weights, base_wgtmem, base_seman, supp_seman
class SupportCalibrator(nn.Module):
    """Refine support features by attending over the base-class weights.

    ``neg_gen_type`` selects the attention memory: ``'semang'`` uses gated
    visual and semantic base memories, ``'attg'`` uses visual base weights
    only, ``'att'`` attends over the support features themselves; any other
    value returns the support features unchanged.
    """

    def __init__(self, nway, feat_dim, n_head=1, base_seman_calib=True, neg_gen_type='semang'):
        super(SupportCalibrator, self).__init__()
        self.nway = nway
        self.feat_dim = feat_dim
        self.base_seman_calib = base_seman_calib

        self.map_sem = nn.Sequential(nn.Linear(300, 300), nn.LeakyReLU(0.1), nn.Dropout(0.1), nn.Linear(300, 300))
        self.calibrator = MultiHeadAttention(feat_dim // n_head, feat_dim // n_head, (feat_dim, feat_dim))

        self.neg_gen_type = neg_gen_type
        if neg_gen_type == 'semang':
            # The gates consume the concatenation of a visual (feat_dim) and
            # a semantic (300) vector, hence feat_dim + 300 inputs.
            self.task_visfuse = nn.Linear(feat_dim + 300, feat_dim)
            self.task_semfuse = nn.Linear(feat_dim + 300, 300)

    def _seman_calib(self, seman):
        """Project semantic embeddings through the ``map_sem`` MLP."""
        seman = self.map_sem(seman)
        return seman

    def forward(self, support_feat, base_weights, support_seman, base_seman):
        """Return ``(support_center, support_attn)``.

        ``support_attn`` is ``None`` when ``neg_gen_type`` is not one of
        ``'semang'``, ``'attg'`` or ``'att'``.
        """
        ## support_feat: bs*nway*640, base_weights: bs*num_base*640, support_seman: bs*nway*300, base_seman: bs*num_base*300
        n_bs, n_base_cls = base_weights.size()[:2]

        # One copy of the base memory per support class, flattened so each
        # class becomes its own attention batch item.
        base_weights = base_weights.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, self.feat_dim)
        support_feat = support_feat.view(-1, 1, self.feat_dim)

        if self.neg_gen_type == 'semang':
            support_seman = self._seman_calib(support_seman)
            if self.base_seman_calib:
                base_seman = self._seman_calib(base_seman)

            base_seman = base_seman.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, 300)
            support_seman = support_seman.view(-1, 1, 300)

            base_mem_vis = base_weights
            base_mem_seman = base_seman

            # Task-level summary of the memory drives both gates; gates lie
            # in (1, 2) because of the +1.0 offset.
            avg_task_mem = torch.mean(torch.cat([base_mem_vis, base_mem_seman], -1), 1, keepdim=True)
            gate_vis = torch.sigmoid(self.task_visfuse(avg_task_mem)) + 1.0
            gate_sem = torch.sigmoid(self.task_semfuse(avg_task_mem)) + 1.0

            base_weights = base_mem_vis * gate_vis
            base_seman = base_mem_seman * gate_sem

        elif self.neg_gen_type == 'attg':
            base_mem_vis = base_weights
            base_seman = None
            support_seman = None

        elif self.neg_gen_type == 'att':
            base_weights = support_feat
            base_mem_vis = support_feat
            support_seman = None
            base_seman = None

        else:
            return support_feat.view(n_bs, self.nway, -1), None

        support_center, _, support_attn, _ = self.calibrator(support_feat, base_weights, base_mem_vis, support_seman, base_seman)

        support_center = support_center.view(n_bs, self.nway, -1)
        support_attn = support_attn.view(n_bs, self.nway, -1)
        return support_center, support_attn
class OpenSetGenerater(nn.Module):
    """Generate a negative ("fake") class prototype from support centres.

    The centres attend over a memory selected by ``neg_gen_type`` (same
    options as ``SupportCalibrator``); the attended outputs are averaged
    into one prototype, optionally passed through an MLP when
    ``agg == 'mlp'``.
    """

    def __init__(self, nway, featdim, n_head=1, neg_gen_type='semang', agg='avg'):
        super(OpenSetGenerater, self).__init__()
        self.nway = nway
        self.att = MultiHeadAttention(featdim // n_head, featdim // n_head, (featdim, featdim))
        self.featdim = featdim

        self.neg_gen_type = neg_gen_type
        if neg_gen_type == 'semang':
            self.task_visfuse = nn.Linear(featdim + 300, featdim)
            self.task_semfuse = nn.Linear(featdim + 300, 300)

        self.agg = agg
        if agg == 'mlp':
            self.agg_func = nn.Sequential(
                nn.Linear(featdim, featdim),
                nn.LeakyReLU(0.5),
                nn.Dropout(0.5),
                nn.Linear(featdim, featdim))

        self.map_sem = nn.Sequential(nn.Linear(300, 300),
                                     nn.LeakyReLU(0.1),
                                     nn.Dropout(0.1),
                                     nn.Linear(300, 300))

    def _seman_calib(self, seman):
        """Project semantic embeddings through the ``map_sem`` MLP."""
        ### feat: bs * d*feat_dim, seman: bs*d*300
        seman = self.map_sem(seman)
        return seman

    def forward(self, support_center, base_weights, support_seman=None, base_seman=None):
        """Return ``(fakeclass_center, units)``.

        ``units`` are the per-centre attention outputs (or, when no
        attention variant is selected, the input centres) reshaped to
        ``(bs, -1, featdim)``.
        """
        ## support_center: bs*nway*D
        ## weight_base: bs*nbase*D
        bs = support_center.shape[0]
        n_bs, n_base_cls = base_weights.size()[:2]
        # Repeat the memory once per centre actually supplied; this equals
        # self.nway in the standard case and keeps the batch dimensions
        # aligned when a different number of centres is passed in.
        n_units = support_center.shape[1]

        base_weights = base_weights.unsqueeze(dim=1).repeat(1, n_units, 1, 1).view(-1, n_base_cls, self.featdim)
        support_center = support_center.view(-1, 1, self.featdim)

        if self.neg_gen_type == 'semang':
            support_seman = self._seman_calib(support_seman)
            base_seman = base_seman.unsqueeze(dim=1).repeat(1, n_units, 1, 1).view(-1, n_base_cls, 300)
            support_seman = support_seman.view(-1, 1, 300)

            base_mem_vis = base_weights
            base_mem_seman = base_seman

            # Both gates are driven by the task-level summary of the joint
            # visual+semantic memory (width featdim + 300).
            avg_task_mem = torch.mean(torch.cat([base_mem_vis, base_mem_seman], -1), 1, keepdim=True)
            gate_vis = torch.sigmoid(self.task_visfuse(avg_task_mem)) + 1.0
            gate_sem = torch.sigmoid(self.task_semfuse(avg_task_mem)) + 1.0

            base_weights = base_mem_vis * gate_vis
            base_seman = base_mem_seman * gate_sem

        elif self.neg_gen_type == 'attg':
            base_mem_vis = base_weights
            support_seman = None
            base_seman = None

        elif self.neg_gen_type == 'att':
            base_weights = support_center
            base_mem_vis = support_center
            support_seman = None
            base_seman = None

        else:
            # No attention: the fake class is simply the mean of the centres.
            fakeclass_center = support_center.mean(dim=0, keepdim=True)
            if self.agg == 'mlp':
                fakeclass_center = self.agg_func(fakeclass_center)
            return fakeclass_center, support_center.view(bs, -1, self.featdim)

        output, attcoef, attn_score, value = self.att(support_center, base_weights, base_mem_vis, support_seman, base_seman)  ## bs*nway*nbase

        output = output.view(bs, -1, self.featdim)
        fakeclass_center = output.mean(dim=1, keepdim=True)

        if self.agg == 'mlp':
            fakeclass_center = self.agg_func(fakeclass_center)

        return fakeclass_center, output
class MultiHeadAttention(nn.Module):
    '''Multi-head attention with optional semantic query/key projections.

    ``d_model`` is a sequence whose entries give the input widths of the
    query (``[0]``), key (``[1]``) and value (``[-1]``) projections.
    '''

    def __init__(self, d_k, d_v, d_model, n_head=1, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        ### Visual feature projection head
        self.w_qs = nn.Linear(d_model[0], n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model[1], n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model[-1], n_head * d_v, bias=False)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model[0] + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model[1] + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model[-1] + d_v)))

        ### Semantic projection head (300-d semantic embeddings)
        self.w_qs_sem = nn.Linear(300, n_head * d_k, bias=False)
        self.w_ks_sem = nn.Linear(300, n_head * d_k, bias=False)
        # Not used in forward; kept so the parameter set is unchanged.
        self.w_vs_sem = nn.Linear(300, n_head * d_k, bias=False)
        nn.init.normal_(self.w_qs_sem.weight, mean=0, std=np.sqrt(2.0 / 600))
        nn.init.normal_(self.w_ks_sem.weight, mean=0, std=np.sqrt(2.0 / 600))
        nn.init.normal_(self.w_vs_sem.weight, mean=0, std=np.sqrt(2.0 / 600))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))

        self.fc = nn.Linear(n_head * d_v, d_model[0], bias=False)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, q_sem=None, k_sem=None, mark_res=True):
        """Attend ``q`` over ``k``/``v``.

        Args:
            q, k, v: 3-D tensors ``(batch, length, width)``.
            q_sem, k_sem: optional semantic counterparts of ``q`` and ``k``.
            mark_res: add the input ``q`` to the output as a residual.

        Returns:
            ``(output, attn, attn_score, v)`` where ``v`` is the projected,
            head-split value tensor.
        """
        ### q: bs*nway*D
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        residual = q
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)  # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)  # (n*b) x lv x dv

        if q_sem is not None:
            sz_b, len_q, _ = q_sem.size()
            sz_b, len_k, _ = k_sem.size()
            q_sem = self.w_qs_sem(q_sem).view(sz_b, len_q, n_head, d_k)
            k_sem = self.w_ks_sem(k_sem).view(sz_b, len_k, n_head, d_k)
            q_sem = q_sem.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)
            k_sem = k_sem.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)

        output, attn, attn_score = self.attention(q, k, v, q_sem, k_sem)

        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)  # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        if mark_res:
            output = output + residual

        return output, attn, attn_score, v
class ScaledDotProductAttention(nn.Module):
    '''Scaled dot-product attention with optional semantic augmentation.

    When semantic query/key tensors are supplied they are added to the
    visual query/key before the scores are computed.
    '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, q_sem=None, k_sem=None):
        """Return ``(output, attn, attn_score)``.

        ``attn_score`` is the temperature-scaled score before the softmax;
        ``attn`` is the softmax over the key axis with dropout applied.
        """
        if q_sem is not None:
            # Fuse the semantic projections into the visual query/key.
            q = q + q_sem
            k = k + k_sem

        attn_score = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        attn = self.softmax(attn_score)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn, attn_score
class Metric_Cosine(nn.Module):
    """Cosine-similarity scorer with a learnable temperature."""

    def __init__(self, temperature=10):
        super().__init__()
        # Learnable scale applied to the cosine similarities.
        self.temp = nn.Parameter(torch.tensor(float(temperature)))

    def forward(self, supp_center, query_feature):
        """Return the scaled cosine similarity of each query to each centre.

        Both inputs are L2-normalised along their last axis and compared
        with a batched matrix product, so the result has one row per query
        and one column per centre.
        """
        ## supp_center: bs*nway*D
        ## query_feature: bs*(nway*nquery)*D
        unit_centers = F.normalize(supp_center, dim=-1)  # eps defaults to 1e-12
        unit_queries = F.normalize(query_feature, dim=-1)
        cosine = torch.bmm(unit_queries, unit_centers.transpose(1, 2))
        return cosine * self.temp
conv2d_mtl.py
# conv2d_mtl
import math

import torch
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter
class _ConvNdMtl(Module):
"""The class for meta-transfer convolution """
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias):
super(_ConvNdMtl, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if out_channels % groups != 0:
raise ValueError('out_channels must be dicisible by groups')
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.transposed = transposed
self.output_padding = output_padding
self.groups = groups
if transposed:
self.weight = Parameter(torch.Tensor(in_channels, out_channels // groups, *kernel_size))
self.mtl_weight = Parameter(torch.ones(in_channels, out_channels // groups , 1, 1))
else:
self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size))
self.mtl_weight = Parameter(torch.ones(out_channels, in_channels // groups, 1, 1))
self.weight.requires_grad = False
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
self.bias.requires_grad = False
self.mtl_bias = Parameter(torch.zeros(out_channels))
else:
self.register_parameter('bias', None)
self.register_parameter('mtl_bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. /math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
self.mtl_weight.data.uniform(0, 0)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
self.mtl_bias.data.uniform_(0, 0)
def extra_repr(self):
s = ('{in_channels}, {out_channels}, kernel_size={kernerl_size}' ', stride={stride}')
if self.padding != (0,) * len(self.padding):
s += ', padding={padding}'
if self.dilation != (1,) * len(self.dialtion):
s += ', dilation={dilation}'
if self.output_padding != (0,) * len(self.output_padding):
s += 'output_padding={output_padding}'
if self.groups != 1:
s += ', groups={groups}'
if self.bias is None:
s += 'biase=False'
return s.format(**self.__dict__)
class Conv2dMtl(_ConvNdMtl):
    """2-D meta-transfer convolution.

    The effective kernel is the frozen ``weight`` scaled element-wise by
    ``mtl_weight`` (broadcast over the spatial dims); the effective bias is
    ``bias + mtl_bias``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        # Normalise scalar hyper-parameters to 2-tuples before delegating.
        super(Conv2dMtl, self).__init__(
            in_channels,
            out_channels,
            _pair(kernel_size),
            _pair(stride),
            _pair(padding),
            _pair(dilation),
            False,
            _pair(0),
            groups,
            bias,
        )

    def forward(self, inp):
        """Convolve ``inp`` with the scaled kernel and shifted bias."""
        scale = self.mtl_weight.expand(self.weight.shape)
        scaled_kernel = self.weight.mul(scale)
        shifted_bias = None if self.bias is None else self.bias + self.mtl_bias
        return F.conv2d(inp, scaled_kernel, shifted_bias, self.stride, self.padding, self.dilation, self.groups)
GAttnClassifier.py
from architectures.AttnClassifier import *
class GClassifier(Classifier):
    """Generalised variant of ``Classifier``.

    Besides the support and fake-class prototypes, the base-class weights
    are included among the class prototypes, and an additional base-set
    feature group is scored.
    """

    def __init__(self, args, feat_dim, param_seman, train_weight_base=False):
        super(GClassifier, self).__init__(args, feat_dim, param_seman, train_weight_base)

        # Weight & bias for the base classes.
        # NOTE(review): everything below repeats what the parent constructor
        # already sets up; kept as in the original.
        self.train_weight_base = train_weight_base
        self.init_representation(param_seman)
        if train_weight_base:
            print('Enable training base class weight')

        self.calibrator = SupportCalibrator(
            nway=args.n_ways, feat_dim=feat_dim, n_head=1,
            base_seman_calib=args.base_seman_calib,
            neg_gen_type=args.neg_gen_type)
        self.open_generator = OpenSetGenerater(
            args.n_ways, feat_dim, n_head=1,
            neg_gen_type=args.neg_gen_type, agg=args.agg)
        self.metric = Metric_Cosine()

    def forward(self, feature, cls_ids, test=False):
        """Score base-set, query and open-set features.

        Args:
            feature: ``(support_feat, query_feat, openset_feat,
                baseset_feat)``.
            cls_ids: ``(supp_ids, base_ids)``; both are cast to int64
                before indexing.
            test: unused here; kept for interface compatibility.

        Returns:
            ``(test_cosine_scores, supp_protos, fakeclass_protos,
            (base_weights, base_wgtmem), funit_distance)`` where
            ``test_cosine_scores`` is ``(baseset, query, openset)`` scores.
        """
        ## bs: feature[0].size(0)
        ## support_feat: bs*nway*nshot*D
        ## query_feat: bs*(nway*nquery)*D
        ## base_ids: bs*54
        (support_feat, query_feat, openset_feat, baseset_feat) = feature

        (supp_ids, base_ids) = cls_ids
        base_weights, base_wgtmem, base_seman, support_seman = self.get_representation(
            supp_ids.type(torch.int64), base_ids.type(torch.int64))
        # Average over the shot axis to get one feature per class.
        support_feat = torch.mean(support_feat, dim=2)

        supp_protos, support_attn = self.calibrator(support_feat, base_weights, support_seman, base_seman)
        fakeclass_protos, recip_unit = self.open_generator(supp_protos, base_weights, support_seman, base_seman)
        cls_protos = torch.cat([base_weights, supp_protos, fakeclass_protos], dim=1)

        query_funit_distance = 1.0 - self.metric(recip_unit, query_feat)
        qopen_funit_distance = 1.0 - self.metric(recip_unit, openset_feat)
        funit_distance = torch.cat([query_funit_distance, qopen_funit_distance], dim=1)

        query_cls_scores = self.metric(cls_protos, query_feat)
        openset_cls_scores = self.metric(cls_protos, openset_feat)
        baseset_cls_scores = self.metric(cls_protos, baseset_feat)

        test_cosine_scores = (baseset_cls_scores, query_cls_scores, openset_cls_scores)

        return test_cosine_scores, supp_protos, fakeclass_protos, (base_weights, base_wgtmem), funit_distance
GNetworkPre.py
import torch.nn as nn
import torch
import torch.nn.functional as F
from architectures.ResNetFeat import create_feature_extractor
from architectures.GAttnClassifier import GClassifier
class GFeatureNet(nn.Module):
    """Feature extractor plus ``GClassifier`` head for the 'GOpenMeta' setup."""

    def __init__(self, args, restype, n_class, param_seman):
        """Build the backbone and head, then freeze backbone parts.

        With ``args.tunefeat == 0.0`` the whole backbone is frozen;
        otherwise ``layer1``/``layer2``/``layer3`` are frozen when
        ``args.tune_part`` is at most 3/2/1 respectively.
        """
        super(GFeatureNet, self).__init__()
        # Metric_Cosine is not among this file's module-level imports, so
        # bring it in here.
        from architectures.AttnClassifier import Metric_Cosine

        self.args = args
        self.restype = restype
        self.n_class = n_class
        self.featype = args.featype
        self.n_ways = args.n_ways
        self.tunefeat = args.tunefeat
        self.distance_label = torch.Tensor([i for i in range(self.n_ways)]).cuda().long()
        self.metric = Metric_Cosine()

        self.feature = create_feature_extractor(restype, args.dataset)
        self.feat_dim = self.feature.out_dim

        self.cls_classifier = GClassifier(args, self.feat_dim, param_seman, args.train_weight_base) if 'GOpenMeta' in self.featype else nn.Linear(self.feat_dim, n_class)

        assert 'GOpenMeta' in self.featype

        if self.tunefeat == 0.0:
            self._freeze(self.feature)
        else:
            if args.tune_part <= 3:
                self._freeze(self.feature.layer1)
            if args.tune_part <= 2:
                self._freeze(self.feature.layer2)
            if args.tune_part <= 1:
                self._freeze(self.feature.layer3)

    @staticmethod
    def _freeze(module):
        """Disable gradients for every parameter of ``module``."""
        for _, p in module.named_parameters():
            p.requires_grad = False

    def forward(self, the_img, labels=None, conj_ids=None, base_ids=None, test=False):
        """Extract features only (no labels) or run the open-set episode.

        Without ``labels`` the input must be a 4-D tensor and the result is
        ``(features, None)``.
        """
        if labels is None:
            assert the_img.dim() == 4
            return (self.feature(the_img), None)
        else:
            # NOTE(review): gen_open_forward is not defined in the visible
            # part of this class; it must be provided elsewhere.
            return self.gen_open_forward(the_img, labels, conj_ids, base_ids, test)
ResNetFeat.py
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli
from .conv2d_mtl import Conv2dMtl
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
def conv3x3mtl(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 meta-transfer convolution with padding 1.

    Mirrors ``conv3x3`` but builds a ``Conv2dMtl`` layer.
    """
    return Conv2dMtl(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class DropBlock(nn.Module):
    """DropBlock regularisation: drops contiguous square regions.

    Active only in training mode; in eval mode the input is returned
    unchanged.
    """

    def __init__(self, block_size):
        super(DropBlock, self).__init__()
        self.block_size = block_size

    def forward(self, x, gamma):
        """Apply DropBlock to ``x`` with seed probability ``gamma``.

        The kept activations are rescaled by (#elements / #kept elements).
        """
        # shape: (bsize, channels, height, width)
        if self.training:
            batch_size, channels, height, width = x.shape
            bernoulli = Bernoulli(gamma)
            # Seeds are sampled only where a full block fits; the mask is
            # created on the same device as the input.
            mask = bernoulli.sample((
                batch_size,
                channels,
                height - (self.block_size - 1),
                width - (self.block_size - 1),
            )).to(x.device)
            block_mask = self._compute_block_mask(mask)
            countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
            count_ones = block_mask.sum()
            return block_mask * x * (countM / count_ones)
        else:
            return x

    def _compute_block_mask(self, mask):
        """Expand seed positions in ``mask`` into a keep-mask (1 = keep)."""
        left_padding = int((self.block_size - 1) / 2)
        right_padding = int(self.block_size / 2)

        non_zero_idxs = mask.nonzero()
        nr_blocks = non_zero_idxs.shape[0]

        # (row, col) offsets of every cell inside one block.
        offsets = torch.stack(
            [
                torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1),
                torch.arange(self.block_size).repeat(self.block_size),
            ]
        ).t().to(mask.device)
        # Prepend zero offsets for the batch and channel indices.
        offsets = torch.cat((torch.zeros(self.block_size ** 2, 2, device=mask.device).long(), offsets.long()), 1)

        if nr_blocks > 0:
            # NOTE(review): both tensors are tiled along dim 0, so seed
            # i % nr_blocks is paired with offset i % block_size**2; this
            # does not enumerate every (seed, offset) pair in general.
            # Kept as in the original; confirm whether that is intended.
            non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
            offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
            offsets = offsets.long()

            block_idxs = non_zero_idxs + offsets
            padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
            padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
        else:
            padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))

        block_mask = 1 - padded_mask
        return block_mask
class BasicBlock(nn.Module):
    """Residual block: three 3x3 conv+BN stages, skip connection, max-pool,
    then optional Dropout or DropBlock."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=2, downsample=None, drop_rate=0.0, drop_block=False, block_size=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(0.1)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv3x3(planes, planes)
        self.bn3 = nn.BatchNorm2d(planes)
        # Pool by the stride (not by the channel count) to down-sample.
        self.maxpool = nn.MaxPool2d(stride)
        self.downsample = downsample
        self.stride = stride
        self.drop_rate = drop_rate
        # Counts forward passes; drives the DropBlock keep-rate schedule.
        self.num_batches_tracked = 0
        self.drop_block = drop_block
        self.block_size = block_size
        self.DropBlock = DropBlock(block_size=self.block_size)

    def forward(self, x):
        """Run the block on ``x`` and return the pooled, regularised output."""
        self.num_batches_tracked += 1

        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        out = self.maxpool(out)

        if self.drop_rate > 0:
            if self.drop_block:
                feat_size = out.size()[2]
                # Keep-rate decays linearly with the number of forward
                # passes down to 1 - drop_rate.
                keep_rate = max(1.0 - self.drop_rate / (20 * 2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
                gamma = (1 - keep_rate) / self.block_size ** 2 * feat_size ** 2 / (feat_size - self.block_size + 1) ** 2
                out = self.DropBlock(out, gamma=gamma)
            else:
                out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)

        return out
class BasicBlockMeta(nn.Module):
    """Meta-transfer variant of ``BasicBlock`` built from ``conv3x3mtl``
    layers: three conv+BN stages, skip connection, max-pool, then optional
    Dropout or DropBlock."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=2, downsample=None, drop_rate=0.0, drop_block=False, block_size=1):
        super(BasicBlockMeta, self).__init__()
        self.conv1 = conv3x3mtl(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(0.1)
        self.conv2 = conv3x3mtl(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv3x3mtl(planes, planes)
        self.bn3 = nn.BatchNorm2d(planes)
        self.maxpool = nn.MaxPool2d(stride)
        self.downsample = downsample
        self.stride = stride
        self.drop_rate = drop_rate
        # Counts forward passes; drives the DropBlock keep-rate schedule.
        self.num_batches_tracked = 0
        self.drop_block = drop_block
        self.block_size = block_size
        self.DropBlock = DropBlock(block_size=self.block_size)

    def forward(self, x):
        """Run the block on ``x`` and return the pooled, regularised output."""
        self.num_batches_tracked += 1

        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        out = self.maxpool(out)

        if self.drop_rate > 0:
            if self.drop_block:
                feat_size = out.size()[2]
                # Keep-rate decays linearly with the number of forward
                # passes down to 1 - drop_rate.
                keep_rate = max(1.0 - self.drop_rate / (20 * 2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
                # Same gamma formula as BasicBlock (block area in the
                # denominator).
                gamma = (1 - keep_rate) / self.block_size ** 2 * feat_size ** 2 / (feat_size - self.block_size + 1) ** 2
                out = self.DropBlock(out, gamma=gamma)
            else:
                out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)

        return out
class ResNet(nn.Module):
    """Four-stage residual backbone with 64/160/320/640 channels.

    ``forward`` returns the globally average-pooled, flattened feature;
    ``out_dim`` holds its width (640).
    """

    def __init__(self, block, n_blocks, keep_prob=1.0, drop_rate=0.0, dropblock_size=5, num_classes=-1):
        """Build the four stages.

        Args:
            block: block class exposing ``expansion`` and accepting
                ``(inplanes, planes, stride, downsample, drop_rate,
                drop_block, block_size)``.
            n_blocks: per-stage block counts (forwarded to ``_make_layer``).
            keep_prob: keep probability of the ``dropout`` attribute.
            drop_rate: drop rate forwarded to every block.
            dropblock_size: DropBlock size used in stages 3 and 4.
            num_classes: unused; kept for interface compatibility.
        """
        super(ResNet, self).__init__()
        channels = [64, 160, 320, 640]
        self.inplanes = 3
        self.layer1 = self._make_layer(block, n_blocks[0], channels[0], drop_rate=drop_rate)
        self.layer2 = self._make_layer(block, n_blocks[1], channels[1], drop_rate=drop_rate)
        self.layer3 = self._make_layer(block, n_blocks[2], channels[2], drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
        self.layer4 = self._make_layer(block, n_blocks[3], channels[3], drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.keep_prob = keep_prob
        self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
        self.drop_rate = drop_rate
        self.out_dim = channels[-1]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, n_block, planes, stride=2, drop_rate=0.0, drop_block=False, block_size=1):
        """Create one stage and advance ``self.inplanes``.

        A 1x1 conv + BN projection is attached to the skip path whenever
        the stride is not 1 or the channel count changes.
        """
        # NOTE(review): n_block is not used; exactly one block is built per
        # stage regardless of its value.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        the_blk = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size)
        self.inplanes = planes * block.expansion
        return the_blk

    def forward(self, x, rot=False, is_feat=False):
        """Return the flattened pooled feature; ``rot``/``is_feat`` are unused."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        resfeat = x.view(x.size(0), -1)
        return resfeat
def create_feature_extractor(restype, dataset, **kwargs):
    """Build the ResNet backbone selected by ``restype``.

    The DropBlock size is 5 when ``'ImageNet'`` occurs in ``dataset`` and 2
    otherwise. Extra keyword arguments are forwarded to ``ResNet``.

    Raises:
        ValueError: if ``restype`` is neither 'ResNet12' nor 'ResNet18'.
    """
    # mode 0: pre-train, 1: finetune, 2: bias_shift
    block_size = 2
    if 'ImageNet' in dataset:
        block_size = 5

    known_layouts = (
        ('ResNet12', [1, 1, 1, 1]),
        ('ResNet18', [1, 1, 2, 2]),
    )
    for name, stage_depths in known_layouts:
        if restype == name:
            return ResNet(
                BasicBlock,
                stage_depths,
                keep_prob=1.0,
                drop_rate=0.1,
                dropblock_size=block_size,
                **kwargs
            )
    raise ValueError("Not Implemented Yet")
NetworkPre.py
# NetworkPre
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from architectures.ResNetFeat import create_feature_extractor
from architectures.AttnClassifier import Classifier
class FeatureNet(nn.Module):
    """Feature extractor plus ``Classifier`` head for the 'OpenMeta' setup."""

    def __init__(self, args, restype, n_class, param_seman):
        """Build the backbone and head, then freeze backbone parts.

        With ``args.tunefeat == 0.0`` the whole backbone is frozen;
        otherwise ``layer1``/``layer2``/``layer3`` are frozen when
        ``args.tune_part`` is at most 3/2/1 respectively.
        """
        super(FeatureNet, self).__init__()
        self.args = args
        self.restype = restype
        self.n_class = n_class
        self.featype = args.featype
        self.n_ways = args.n_ways
        self.tunefeat = args.tunefeat
        self.distance_label = torch.Tensor([i for i in range(self.n_ways)]).cuda().long()
        self.metric = Metric_Cosine()

        self.feature = create_feature_extractor(restype, args.dataset)
        self.feat_dim = self.feature.out_dim  # 640-dimensional output

        self.cls_classifier = Classifier(args, self.feat_dim, param_seman, args.train_weight_base) if 'OpenMeta' in self.featype else nn.Linear(self.feat_dim, n_class)

        assert 'OpenMeta' in self.featype

        if self.tunefeat == 0.0:
            for _, p in self.feature.named_parameters():
                p.requires_grad = False
        else:
            if args.tune_part <= 3:
                for _, p in self.feature.layer1.named_parameters():
                    p.requires_grad = False
            if args.tune_part <= 2:
                for _, p in self.feature.layer2.named_parameters():
                    p.requires_grad = False
            if args.tune_part <= 1:
                for _, p in self.feature.layer3.named_parameters():
                    p.requires_grad = False

    def forward(self, the_img, labels=None, conj_ids=None, base_ids=None, test=False):
        """Extract features only (no labels) or run the open-set episode.

        Without ``labels`` the input must be a 4-D tensor and the result is
        ``(features, None)``.
        """
        if labels is None:
            assert the_img.dim() == 4
            return (self.feature(the_img), None)
        else:
            return self.open_forward(the_img, labels, conj_ids, base_ids, test)

    def open_forward(self, the_input, labels, conj_ids, base_ids, test):
        """Run the two open-set tasks of an episode batch.

        Args:
            the_input: four image tensors (support, query, augmented
                support, open-set), each 5-D and concatenable along dim 1.
            labels: ``(support_label, query_label, supopen_label,
                openset_label)``.
            conj_ids: ``(supp_idx, open_idx)`` class ids of the two tasks.
            base_ids: base-class ids passed on to the classifier.
            test: when true, only the first task is evaluated.

        Returns:
            ``(test_feats, cls_protos, test_cls_probs)`` when ``test`` is
            true, otherwise the same triple followed by the loss tuple
            ``(classification, open_hinge, fake_unit)``.
        """
        # Hyper-parameter preparation
        the_sizes = [_.size(1) for _ in the_input]
        (ne, _, nc, nh, nw) = the_input[0].size()

        # Data preparation
        combined_data = torch.cat(the_input, dim=1).view(-1, nc, nh, nw)
        if not self.tunefeat:  # tunefeat = 0
            with torch.no_grad():
                combined_feat = self.feature(combined_data).detach()  # gradient computation disabled
        else:
            # Keep the graph attached so the backbone can be fine-tuned.
            combined_feat = self.feature(combined_data)
        support_feat, query_feat, supopen_feat, openset_feat = torch.split(combined_feat.view(ne, -1, self.feat_dim), the_sizes, dim=1)
        (support_label, query_label, supopen_label, openset_label) = labels
        (supp_idx, open_idx) = conj_ids
        cls_label = torch.cat([query_label, openset_label], dim=1)
        test_feats = (support_feat, query_feat, openset_feat)

        ### First task
        support_feat = support_feat.view(ne, self.n_ways, -1, self.feat_dim)
        test_cosine_scores, supp_protos, fakeclass_protos, loss_cls, loss_funit = self.task_proto((support_feat, query_feat, openset_feat), (supp_idx, base_ids), cls_label, test)
        cls_protos = torch.cat([supp_protos, fakeclass_protos], dim=1)
        test_cls_probs = self.task_pred(test_cosine_scores[0], test_cosine_scores[1])

        if test:
            test_feats = (support_feat, query_feat, openset_feat)
            return test_feats, cls_protos, test_cls_probs

        ## Second task: the roles of (query, open-set) are swapped.
        supopen_feat = supopen_feat.view(ne, self.n_ways, -1, self.feat_dim)
        _, supp_protos_aug, fakeclass_protos_aug, loss_cls_aug, loss_funit_aug = self.task_proto((supopen_feat, openset_feat, query_feat), (open_idx, base_ids), cls_label, test)

        loss_open_hinge = 0.0
        loss = (loss_cls + loss_cls_aug, loss_open_hinge, loss_funit + loss_funit_aug)
        return test_feats, cls_protos, test_cls_probs, loss

    def task_proto(self, features, cls_ids, cls_label, test=False):
        """Run the classifier head and compute both task losses.

        Returns:
            ``(test_cosine_scores, supp_protos, fakeclass_protos, loss_cls,
            fakeunit_loss)``.
        """
        test_cosine_scores, supp_protos, fakeclass_protos, _, funit_distance = self.cls_classifier(features, cls_ids, test)
        (query_cls_scores, openset_cls_scores) = test_cosine_scores
        cls_scores = torch.cat([query_cls_scores, openset_cls_scores], dim=1)

        fakeunit_loss = fakeunit_compare(funit_distance, self.n_ways, cls_label)
        # One row per sample, n_ways + 1 columns (the extra column is the
        # fake/open class).
        cls_scores = cls_scores.view(-1, self.n_ways + 1)
        cls_label = cls_label.view(-1)
        loss_cls = F.cross_entropy(cls_scores, cls_label)
        return test_cosine_scores, supp_protos, fakeclass_protos, loss_cls, fakeunit_loss

    def task_pred(self, query_cls_scores, openset_cls_scores, many_cls_scores=None):
        """Turn detached scores into class probabilities (softmax over the
        last, i.e. class, axis)."""
        query_cls_probs = F.softmax(query_cls_scores.detach(), dim=-1)
        openset_cls_probs = F.softmax(openset_cls_scores.detach(), dim=-1)
        if many_cls_scores is None:
            return (query_cls_probs, openset_cls_probs)
        else:
            many_cls_probs = F.softmax(many_cls_scores.detach(), dim=-1)
            return (query_cls_probs, openset_cls_probs, many_cls_probs, query_cls_scores, openset_cls_scores)
class Metric_Cosine(nn.Module):
    """Cosine-similarity scorer with a learnable temperature."""

    def __init__(self, temperature=10):
        super(Metric_Cosine, self).__init__()
        # Learnable scale applied to the cosine similarities.
        self.temp = nn.Parameter(torch.tensor(float(temperature)))

    def forward(self, supp_center, query_feature):
        """Return the scaled cosine similarity of each query to each centre."""
        ## supp_center: bs*nway*D
        ## query_feature: bs * (nway*nquery) * D
        supp_center = F.normalize(supp_center, dim=-1)
        # Both sides must be unit-length for the product to be a cosine.
        query_feature = F.normalize(query_feature, dim=-1)
        logits = torch.bmm(query_feature, supp_center.transpose(1, 2))
        return logits * self.temp
def fakeunit_compare(funit_distance, n_ways, cls_label):
    """Binary cross-entropy between unit distances and one-hot labels.

    Args:
        funit_distance: logits whose last axis has ``n_ways + 1`` entries.
        n_ways: number of closed-set classes; labels lie in
            ``[0, n_ways]``.
        cls_label: integer class labels.

    Returns:
        The mean-reduced BCE-with-logits loss as a scalar tensor.
    """
    # Modified code: the full one-hot (including the extra class) is used.
    # The width is fixed to n_ways + 1 so it does not depend on which
    # labels happen to occur in the batch.
    cls_label_binary = F.one_hot(cls_label, num_classes=n_ways + 1).float()
    # The loss is already mean-reduced to a scalar; the sum is a no-op kept
    # from the original.
    loss = torch.sum(F.binary_cross_entropy_with_logits(input=funit_distance, target=cls_label_binary))
    return loss