AI Algorithm Practice, Part 1 — Reproducing the Code of "Task-Adaptive Negative Envision for Few-Shot Open-Set Recognition"

Reproducing the paper's code

Code structure
(figure: screenshot of the repository's code structure)

Architectures

AttnClassifier.py

import torch.nn as nn
import torch
import torch.nn.functional as F

import numpy as np

class Classifier(nn.Module):
	def __init__(self, args, feat_dim, param_seman, train_weight_base=False):
		super(Classifier, self).__init__()
		
		#weight & Bias for Base
		self.train_weight_base = train_weight_base
		self.init_representation(param_seman)
		if train_weight_base:
			print('Enable training base class weights')
		
		self.calibrator = SupportCalibrator(nway=args.n_ways, feat_dim=feat_dim, n_head=1, base_seman_calib=args.base_seman_calib, neg_gen_type=args.neg_gen_type)
		self.open_generator = OpenSetGenerater(args.n_ways, feat_dim, n_head=1, neg_gen_type=args.neg_gen_type, agg=args.agg)
		self.metric = Metric_Cosine()

	def forward(self, features, cls_ids, test=False):
		## bs: features[0].size(0)
		## support_feat: bs*nway*nshot*D
		## query_feat: bs*(nway*nquery)*D
		## base_ids: bs*54
		(support_feat, query_feat, openset_feat) = features
		
		(nb, nc, ns, ndim), nq = support_feat.size(), query_feat.size(1)
		(supp_ids, base_ids) = cls_ids
		base_weights, base_wgtmem, base_seman, support_seman = self.get_representation(supp_ids, base_ids)
		support_feat = torch.mean(support_feat, dim=2)
		supp_protos, support_attn = self.calibrator(support_feat, base_weights, support_seman, base_seman)

		# modified code: run the query features through the calibrator in chunks of 5 to build pseudo-support prototypes
		n = query_feat.size()[1]
		sup_list = []
		for i in range(0, n, 5):
			supp_fk = query_feat[:, i:i+5, :].contiguous()
			ss, _ = self.calibrator(supp_fk, base_weights, support_seman, base_seman)
			sup_list.append(ss)
		suppfake_protos = torch.cat(sup_list, dim=1)

		suppfake_protos = torch.mean(suppfake_protos, dim=1).view(nb, -1, ndim)
		new_supp_protos = torch.cat([supp_protos, suppfake_protos], dim=1)

		fakeclass_protos, recip_unit = self.open_generator(new_supp_protos, base_weights, support_seman, base_seman)
		cls_protos = torch.cat([supp_protos, fakeclass_protos], dim=1)
		
		query_cls_scores = self.metric(cls_protos, query_feat)
		openset_cls_scores = self.metric(cls_protos, openset_feat)
		
		test_cosine_scores = (query_cls_scores, openset_cls_scores)
		
		query_funit_distance = 1.0 - self.metric(recip_unit, query_feat)
		qopen_funit_distance = 1.0 - self.metric(recip_unit, openset_feat)
		funit_distance = torch.cat([query_funit_distance, qopen_funit_distance], dim=1)

		return test_cosine_scores, supp_protos, fakeclass_protos, (base_weights, base_wgtmem), funit_distance 

	def init_representation(self, param_seman):
		(params, seman_dict) = param_seman
		self.weight_base = nn.Parameter(params['cls_classifier.weight'], requires_grad=self.train_weight_base)
		self.bias_base = nn.Parameter(params['cls_classifier.bias'],requires_grad=self.train_weight_base)
		self.weight_mem = nn.Parameter(params['cls_classifier.weight'].clone(), requires_grad=False)
		self.seman = {k: nn.Parameter(torch.from_numpy(v), requires_grad=False).float().cuda() for k, v in seman_dict.items()}
	
	def get_representation(self, cls_ids, base_ids, randpick=False):
		if base_ids is not None:
			base_weights = self.weight_base[base_ids, :]
			base_wgtmem = self.weight_mem[base_ids, :]
			base_seman = self.seman['base'][base_ids, :]
			supp_seman = self.seman['base'][cls_ids, :]
		else:
			bs = cls_ids.size(0)
			base_weights = self.weight_base.repeat(bs, 1, 1)
			base_wgtmem = self.weight_mem.repeat(bs, 1, 1)
			base_seman = self.seman['base'].repeat(bs, 1, 1)
			supp_seman = self.seman['novel_test'][cls_ids, :]
		if randpick:
			num_base = base_weights.shape[1]
			base_size = self.base_size
			idx = np.random.choice(list(range(num_base)), size=base_size, replace=False)
			base_weights = base_weights[:, idx, :]
			base_seman = base_seman[:, idx, :]
		return base_weights, base_wgtmem, base_seman, supp_seman

class SupportCalibrator(nn.Module):
	def __init__(self, nway, feat_dim, n_head=1, base_seman_calib=True, neg_gen_type='semang'):
		super(SupportCalibrator, self).__init__()
		self.nway = nway
		self.feat_dim = feat_dim
		self.base_seman_calib = base_seman_calib
		
		self.map_sem = nn.Sequential(nn.Linear(300, 300), nn.LeakyReLU(0.1), nn.Dropout(0.1), nn.Linear(300, 300))
		self.calibrator = MultiHeadAttention(feat_dim // n_head, feat_dim // n_head, (feat_dim, feat_dim))
		self.neg_gen_type = neg_gen_type
		if neg_gen_type == 'semang':
			self.task_visfuse = nn.Linear(feat_dim + 300, feat_dim)
			self.task_semfuse = nn.Linear(feat_dim + 300, 300)

	def _seman_calib(self, seman):
		seman = self.map_sem(seman)
		return seman

	def forward(self, support_feat, base_weights, support_seman, base_seman):
		## support_feat: bs*nway*640, base_weights: bs*num_base*640, support_seman: bs*nway*300, base_seman: bs*num_base*300
		n_bs, n_base_cls = base_weights.size()[:2]
		base_weights = base_weights.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, self.feat_dim)
		support_feat = support_feat.view(-1, 1, self.feat_dim)
		
		if self.neg_gen_type == 'semang':
			support_seman = self._seman_calib(support_seman)
			if self.base_seman_calib:
				base_seman = self._seman_calib(base_seman)
			base_seman = base_seman.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, 300)
			support_seman = support_seman.view(-1, 1, 300)
			
			base_mem_vis = base_weights
			task_mem_vis = base_weights
			
			base_mem_seman = base_seman
			task_mem_seman = base_seman
			avg_task_mem = torch.mean(torch.cat([task_mem_vis, task_mem_seman], -1), 1, keepdim=True)

			gate_vis = torch.sigmoid(self.task_visfuse(avg_task_mem)) + 1.0
			gate_sem = torch.sigmoid(self.task_semfuse(avg_task_mem)) + 1.0
			
			base_weights = base_mem_vis * gate_vis
			base_seman = base_mem_seman * gate_sem
			
		elif self.neg_gen_type == 'attg':
			base_mem_vis = base_weights
			base_seman = None
			support_seman = None
		elif self.neg_gen_type == 'att':
			base_weights = support_feat
			base_mem_vis = support_feat
			support_seman = None
			base_seman = None

		else:
			return support_feat.view(n_bs, self.nway, -1), None
		
		support_center, _, support_attn, _ = self.calibrator(support_feat, base_weights, base_mem_vis, support_seman, base_seman)
		support_center = support_center.view(n_bs, self.nway, -1)
		support_attn = support_attn.view(n_bs, self.nway, -1)
		return support_center, support_attn
		 
class OpenSetGenerater(nn.Module):
	def __init__(self, nway, featdim, n_head=1, neg_gen_type='semang', agg='avg'):
		super(OpenSetGenerater, self).__init__()
		self.nway = nway
		self.att = MultiHeadAttention(featdim // n_head, featdim // n_head, (featdim, featdim))
		self.featdim = featdim
		self.neg_gen_type = neg_gen_type
		if neg_gen_type == 'semang':
			self.task_visfuse = nn.Linear(featdim + 300, featdim)
			self.task_semfuse = nn.Linear(featdim + 300, 300)

		self.agg = agg
		if agg == 'mlp':
			self.agg_func = nn.Sequential(
				nn.Linear(featdim, featdim),
				nn.LeakyReLU(0.5),
				nn.Dropout(0.5),
				nn.Linear(featdim, featdim))
		self.map_sem = nn.Sequential(nn.Linear(300, 300),
									nn.LeakyReLU(0.1),
									nn.Dropout(0.1),
									nn.Linear(300, 300))
		
	def _seman_calib(self, seman):
		### feat: bs * d*feat_dim, seman: bs*d*300
		seman = self.map_sem(seman)
		return seman

	def forward(self, support_center, base_weights, support_seman=None, base_seman=None):
		## support_center: bs*nway*D
		## weight_base: bs*nbase*D
		bs = support_center.shape[0]
		n_bs, n_base_cls = base_weights.size()[:2]

		base_weights = base_weights.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, self.featdim)
		support_center = support_center.view(-1, 1, self.featdim)
		
		if self.neg_gen_type == 'semang':
			support_seman = self._seman_calib(support_seman)
			base_seman = base_seman.unsqueeze(dim=1).repeat(1, self.nway, 1, 1).view(-1, n_base_cls, 300)
			support_seman = support_seman.view(-1, 1, 300)
			base_mem_vis = base_weights
			task_mem_vis = base_weights
	
			base_mem_seman = base_seman
			task_mem_seman = base_seman
			avg_task_mem = torch.mean(torch.cat([task_mem_vis, task_mem_seman], -1), 1, keepdim=True)
			
			gate_vis = torch.sigmoid(self.task_visfuse(avg_task_mem)) + 1.0
			gate_sem = torch.sigmoid(self.task_semfuse(avg_task_mem)) + 1.0
			
			base_weights = base_mem_vis * gate_vis
			base_seman = base_mem_seman * gate_sem
	
		elif self.neg_gen_type == 'attg':
			base_mem_vis = base_weights
			support_seman = None
			base_seman = None
			
		elif self.neg_gen_type == 'att':
			base_weights = support_center
			base_mem_vis = support_center
			support_seman = None
			base_seman = None
	
		else:
			fakeclass_center = support_center.mean(dim=0, keepdim=True)
			if self.agg == 'mlp':
				fakeclass_center = self.agg_func(fakeclass_center)
			return fakeclass_center, support_center.view(bs, -1, self.featdim)
		
		output, attcoef, attn_score, value = self.att(support_center, base_weights, base_mem_vis, support_seman, base_seman) ## bs*nway*nbase
		output = output.view(bs, -1, self.featdim)
		fakeclass_center = output.mean(dim=1, keepdim=True)
		
		if self.agg == 'mlp':
			fakeclass_center = self.agg_func(fakeclass_center)
		return fakeclass_center, output

class MultiHeadAttention(nn.Module):
	'''Multi-Head Attention module'''
	
	def __init__(self, d_k, d_v, d_model, n_head=1, dropout=0.1):
		super().__init__()
		self.n_head = n_head
		self.d_k = d_k
		self.d_v = d_v
		
		### Visual feature projection heads
		self.w_qs = nn.Linear(d_model[0], n_head * d_k, bias=False)
		self.w_ks = nn.Linear(d_model[1], n_head * d_k, bias=False)
		self.w_vs = nn.Linear(d_model[-1], n_head * d_v, bias=False)
		nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model[0] + d_k)))
		nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model[1] + d_k)))
		nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model[-1] + d_v)))

		### Semantic projection head 
		self.w_qs_sem = nn.Linear(300, n_head * d_k, bias=False)
		self.w_ks_sem = nn.Linear(300, n_head * d_k, bias=False)
		self.w_vs_sem = nn.Linear(300, n_head * d_k, bias=False)
		
		nn.init.normal_(self.w_qs_sem.weight, mean=0, std=np.sqrt(2.0 / 600))
		nn.init.normal_(self.w_ks_sem.weight, mean=0, std=np.sqrt(2.0 / 600))
		nn.init.normal_(self.w_vs_sem.weight, mean=0, std=np.sqrt(2.0 / 600))

		self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
		self.fc = nn.Linear(n_head * d_v, d_model[0], bias=False)
		nn.init.xavier_normal_(self.fc.weight)
		self.dropout = nn.Dropout(dropout)

	def forward(self, q, k, v, q_sem=None, k_sem=None, mark_res=True):
		### q: bs*nway*D
		d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
		sz_b, len_q, _ = q.size()
		sz_b, len_k, _ = k.size()
		sz_b, len_v, _ = v.size()
		residual = q
		q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
		k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
		v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
		
		q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
		k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
		v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
		
		if q_sem is not None:
			sz_b, len_q, _ = q_sem.size()
			sz_b, len_k, _ = k_sem.size()
			q_sem = self.w_qs_sem(q_sem).view(sz_b, len_q, n_head, d_k)
			k_sem = self.w_ks_sem(k_sem).view(sz_b, len_k, n_head, d_k)
			q_sem = q_sem.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)
			k_sem = k_sem.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)
		
		output, attn, attn_score = self.attention(q, k, v, q_sem, k_sem)

		output = output.view(n_head, sz_b, len_q, d_v)
		output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
		output = self.dropout(self.fc(output))
		if mark_res:
			output = output + residual
		return output, attn, attn_score, v
		
class ScaledDotProductAttention(nn.Module):
	'''Scaled Dot-Product Attention'''
	
	def __init__(self, temperature, attn_dropout=0.1):
		super().__init__()
		self.temperature = temperature
		self.dropout = nn.Dropout(attn_dropout)
		self.softmax = nn.Softmax(dim=2)
	
	def forward(self, q, k, v, q_sem=None, k_sem=None):
		attn_score = torch.bmm(q, k.transpose(1, 2))
		if q_sem is not None:
			attn_sem = torch.bmm(q_sem, k_sem.transpose(1, 2))  # semantic-only attention (computed but unused below)
			q = q + q_sem
			k = k + k_sem
			attn_score = torch.bmm(q, k.transpose(1, 2))
		
		attn_score /= self.temperature
		attn = self.softmax(attn_score)
		attn = self.dropout(attn)
		output = torch.bmm(attn, v)
		return output, attn, attn_score

class Metric_Cosine(nn.Module):
	def __init__(self, temperature=10):
		super(Metric_Cosine, self).__init__()
		self.temp = nn.Parameter(torch.tensor(float(temperature)))

	def forward(self, supp_center, query_feature):
		## supp_center: bs*nway*D
		## query_feature: bs*(nway*nquery)*D
		supp_center = F.normalize(supp_center, dim=-1) # eps=1e-6 default 1e-12
		query_feature = F.normalize(query_feature ,dim=-1)
		logits = torch.bmm(query_feature, supp_center.transpose(1, 2))
		return logits * self.temp
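
A quick sanity check of Metric_Cosine with random tensors (the batch size, way count, and feature dimension below are assumed values): the logits are temperature-scaled cosine similarities between each query and each prototype.

import torch
from architectures.AttnClassifier import Metric_Cosine

metric = Metric_Cosine(temperature=10)
supp_center = torch.randn(2, 5, 640)     # bs=2, nway=5, D=640 (assumed)
query_feature = torch.randn(2, 75, 640)  # 75 queries per episode (assumed)
logits = metric(supp_center, query_feature)
print(logits.shape)                      # torch.Size([2, 75, 5])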

conv2d_mtl.py

# conv2d_mtl
import math
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair

class _ConvNdMtl(Module):
	"""The class for meta-transfer convolution """
	def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias):
		super(_ConvNdMtl, self).__init__()
		if in_channels % groups != 0:
			raise ValueError('in_channels must be divisible by groups')
		if out_channels % groups != 0:
			raise ValueError('out_channels must be divisible by groups')
		self.in_channels = in_channels
		self.out_channels = out_channels
		self.kernel_size = kernel_size
		self.stride = stride
		self.padding = padding
		self.dilation = dilation
		self.transposed = transposed
		self.output_padding = output_padding
		self.groups = groups
		if transposed:
			self.weight = Parameter(torch.Tensor(in_channels, out_channels // groups, *kernel_size))
			self.mtl_weight = Parameter(torch.ones(in_channels, out_channels // groups , 1, 1))
		else:
			self.weight = Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size))
			self.mtl_weight = Parameter(torch.ones(out_channels, in_channels // groups, 1, 1))
		self.weight.requires_grad = False
		if bias:
			self.bias = Parameter(torch.Tensor(out_channels))
			self.bias.requires_grad = False
			self.mtl_bias = Parameter(torch.zeros(out_channels))
		else:
			self.register_parameter('bias', None)
			self.register_parameter('mtl_bias', None)
		self.reset_parameters()
	
	def reset_parameters(self):
		n = self.in_channels
		for k in self.kernel_size:
			n *= k
		stdv = 1. / math.sqrt(n)
		self.weight.data.uniform_(-stdv, stdv)
		self.mtl_weight.data.uniform_(1, 1) # keep the MTL scaling at its identity value of 1
		if self.bias is not None:
			self.bias.data.uniform_(-stdv, stdv)
			self.mtl_bias.data.uniform_(0, 0)
			
	def extra_repr(self):
		s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
			 ', stride={stride}')
		if self.padding != (0,) * len(self.padding):
			s += ', padding={padding}'
		if self.dilation != (1,) * len(self.dilation):
			s += ', dilation={dilation}'
		if self.output_padding != (0,) * len(self.output_padding):
			s += ', output_padding={output_padding}'
		if self.groups != 1:
			s += ', groups={groups}'
		if self.bias is None:
			s += ', bias=False'
		return s.format(**self.__dict__)

class Conv2dMtl(_ConvNdMtl):
	"""The class for meta-transfer convolution"""
	def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
		kernel_size = _pair(kernel_size)
		stride = _pair(stride)
		padding = _pair(padding)
		dilation = _pair(dilation)
		super(Conv2dMtl, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias)

	def forward(self, inp):
		new_mtl_weight = self.mtl_weight.expand(self.weight.shape)
		new_weight = self.weight.mul(new_mtl_weight)
		if self.bias is not None:
			new_bias = self.bias + self.mtl_bias
		else:
			new_bias = None
		return F.conv2d(inp, new_weight, new_bias, self.stride, self.padding, self.dilation, self.groups)
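
A short sketch of how the meta-transfer layer behaves (the module path is assumed from the repo layout): the pretrained kernel is frozen, while the per-channel mtl_weight scaling (initialized to ones) and the mtl_bias shift remain trainable, so fine-tuning only learns a scale and shift per channel.

import torch
from architectures.conv2d_mtl import Conv2dMtl

conv = Conv2dMtl(3, 16, kernel_size=3, padding=1)
print(conv.weight.requires_grad)      # False: the base kernel is frozen
print(conv.mtl_weight.requires_grad)  # True: only the scaling/shift are learned
out = conv(torch.randn(1, 3, 32, 32))
print(out.shape)                      # torch.Size([1, 16, 32, 32])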

GAttnClassifier.py

from architectures.AttnClassifier import *

class GClassifier(Classifier):
	def __init__(self, args, feat_dim, param_seman, train_weight_base=False):
		super(GClassifier, self).__init__(args, feat_dim, param_seman, train_weight_base)
		
		# weight & Bias for Base
		self.train_weight_base = train_weight_base
		self.init_representation(param_seman)
		if train_weight_base:
			print('Enable training base class weight')
		self.calibrator = SupportCalibrator(nway=args.n_ways, feat_dim=feat_dim, n_head=1, base_seman_calib=args.base_seman_calib, neg_gen_type=args.neg_gen_type)
		self.open_generator = OpenSetGenerater(args.n_ways, feat_dim, n_head=1, neg_gen_type=args.neg_gen_type, agg=args.agg)
		self.metric = Metric_Cosine()

	def forward(self, features, cls_ids, test=False):
		## bs:feature[0].size(0)
		## support_feat: bs*nway*shot*D
		## query_feat:bs*(nway*nquery)*D
		## base_ids: bs*54
		(support_feat, query_feat, openset_feat, baseset_feat) = features
		(nb, nc, ns, ndim), nq = support_feat.size(), query_feat.size(1)
		(supp_ids, base_ids) = cls_ids
		
		base_weights, base_wgtmem, base_seman, support_seman = self.get_representation(supp_ids.type(torch.int64), base_ids.type(torch.int64))
		support_feat = torch.mean(support_feat, dim=2)

		supp_protos, support_attn = self.calibrator(support_feat, base_weights, support_seman, base_seman)
		
		fakeclass_protos, recip_unit = self.open_generator(supp_protos, base_weights, support_seman, base_seman)
		cls_protos = torch.cat([base_weights, supp_protos, fakeclass_protos], dim=1)

		query_funit_distance = 1.0 - self.metric(recip_unit, query_feat)
		qopen_funit_distance = 1.0 - self.metric(recip_unit, openset_feat)
		funit_distance = torch.cat([query_funit_distance, qopen_funit_distance], dim=1)
		
		query_cls_scores = self.metric(cls_protos, query_feat)
		openset_cls_scores = self.metric(cls_protos, openset_feat)
		baseset_cls_scores = self.metric(cls_protos, baseset_feat)
		
		test_cosine_scores = (baseset_cls_scores, query_cls_scores, openset_cls_scores)
		return test_cosine_scores, supp_protos, fakeclass_protos, (base_weights, base_wgtmem), funit_distance
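
In the generalized setting the class prototypes are the concatenation [base class weights | calibrated support prototypes | fake (negative) prototype], so the logits span num_base + n_ways + 1 classes. A shape sketch under assumed episode sizes (54 base classes, 5-way, 15 queries per way):

import torch
from architectures.AttnClassifier import Metric_Cosine

bs, num_base, n_ways, D = 2, 54, 5, 640
base_weights = torch.randn(bs, num_base, D)
supp_protos = torch.randn(bs, n_ways, D)
fakeclass_protos = torch.randn(bs, 1, D)
cls_protos = torch.cat([base_weights, supp_protos, fakeclass_protos], dim=1)

query_feat = torch.randn(bs, n_ways * 15, D)
logits = Metric_Cosine()(cls_protos, query_feat)
print(logits.shape)  # torch.Size([2, 75, 60]): base + novel + negative class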

GNetworkPre.py

import torch.nn as nn
import torch
import torch.nn.functional as F

from architectures.ResNetFeat import create_feature_extractor
from architectures.GAttnClassifier import GClassifier
from architectures.AttnClassifier import Metric_Cosine

class GFeatureNet(nn.Module):
	def __init__(self, args, restype, n_class, param_seman):
		super(GFeatureNet, self).__init__()
		self.args = args
		self.restype = restype
		self.n_class = n_class
		self.featype = args.featype
		self.n_ways = args.n_ways
		self.tunefeat = args.tunefeat
		self.distance_label = torch.Tensor([i for i in range(self.n_ways)]).cuda().long()
		self.metric = Metric_Cosine()

		self.feature = create_feature_extractor(restype, args.dataset)
		self.feat_dim = self.feature.out_dim

		self.cls_classifier = GClassifier(args, self.feat_dim, param_seman, args.train_weight_base) if 'GOpenMeta' in self.featype else nn.Linear(self.feat_dim, n_class)
		
		assert 'GOpenMeta' in self.featype
		if self.tunefeat == 0.0:
			for _, p in self.feature.named_parameters():
				p.requires_grad = False
		else:
			if args.tune_part <= 3:
				for _, p in self.feature.layer1.named_parameters():
					p.requires_grad=False
			if args.tune_part <= 2:
				for _, p in self.feature.layer2.named_parameters():
					p.requires_grad=False
			if args.tune_part <= 1:
				for _, p in self.feature.layer3.named_parameters():
					p.requires_grad=False

	def forward(self, the_img, labels=None,conj_ids=None, base_ids=None, test=False):
		if labels is None:
			assert the_img.dim() == 4
			return (self.feature(the_img), None)
		else:
			return self.gen_open_forward(the_img, labels, conj_ids, base_ids, test)
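
Note that the forward dispatch above doubles as a plain feature extractor: calling the network without labels returns only the backbone features. A hypothetical usage, assuming an already constructed GFeatureNet instance named model:

import torch

images = torch.randn(8, 3, 84, 84)  # a plain 4-D image batch (shape assumed)
feats, _ = model(images)            # labels=None -> returns (self.feature(images), None)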

		







ResNetFeat.py

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli
from .conv2d_mtl import Conv2dMtl

def conv3x3(in_planes, out_planes, stride=1):
	"""3x3 convolution with padding """
	return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

def conv3x3mtl(in_planes, out_planes, stride=1):
	"""3x3 meta-transfer convolution with padding"""
	return Conv2dMtl(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class DropBlock(nn.Module):
	def __init__(self, block_size):
		super(DropBlock, self).__init__()
		
		self.block_size = block_size
	
	def forward(self, x, gamma):
		# shape: (bsize, channels, height, width)
		if self.training:
			batch_size, channels, height, width = x.shape
			bernoulli = Bernoulli(gamma)
			mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
			block_mask = self._compute_block_mask(mask)
			count_ones = block_mask.sum()
			countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
			return block_mask * x * (countM / count_ones)
		else:
			return x
	
	def _compute_block_mask(self, mask):
		left_padding = int((self.block_size - 1) / 2)
		right_padding = int(self.block_size / 2)
		
		batch_size, channels, height, width = mask.shape
		non_zero_idxs = mask.nonzero()
		nr_blocks = non_zero_idxs.shape[0]
		offsets = torch.stack(
		[
			torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # -  left_padding
			torch.arange(self.block_size).repeat(self.block_size), # - left_padding
		]
		).t().cuda()
		offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()),1)
			
		if nr_blocks > 0:
			non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
			offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
			offsets = offsets.long()
			block_idxs = non_zero_idxs + offsets
			# block_idxs += left_padding
			padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 
			padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
		else:
			padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
		block_mask = 1 - padded_mask #[:height, :width]
		return block_mask

class BasicBlock(nn.Module):
	expansion = 1
	
	def __init__(self, inplanes, planes, stride=2, downsample=None, drop_rate=0.0, drop_block=False, block_size=1):
		super(BasicBlock, self).__init__()
		self.conv1 = conv3x3(inplanes, planes)
		self.bn1 = nn.BatchNorm2d(planes)
		self.relu = nn.LeakyReLU(0.1)
		self.conv2 = conv3x3(planes, planes)
		self.bn2 = nn.BatchNorm2d(planes)
		self.conv3 = conv3x3(planes, planes)
		self.bn3 = nn.BatchNorm2d(planes)
		self.maxpool = nn.MaxPool2d(stride)
		self.downsample = downsample
		self.stride = stride
		self.drop_rate = drop_rate
		self.num_batches_tracked = 0
		self.drop_block = drop_block
		self.block_size = block_size
		self.DropBlock = DropBlock(block_size=self.block_size)

	def forward(self, x):
		self.num_batches_tracked += 1
		
		residual = x
		
		out = self.conv1(x)
		out = self.bn1(out)
		out = self.relu(out)
		
		out = self.conv2(out)
		out = self.bn2(out)
		out = self.relu(out)
		
		out = self.conv3(out)
		out = self.bn3(out)
		
		if self.downsample is not None:
			residual = self.downsample(x)
			
		out += residual
		out = self.relu(out)
		out = self.maxpool(out)
		if self.drop_rate > 0:
			if self.drop_block == True:
				feat_size = out.size()[2]
				keep_rate = max(1.0 - self.drop_rate / (20 * 2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
				gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
				out = self.DropBlock(out, gamma=gamma)
			else:
				out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
		return out

class BasicBlockMeta(nn.Module):
	expansion = 1
	
	def __init__(self, inplanes, planes, stride=2, downsample=None, drop_rate=0.0, drop_block=False, block_size=1):
		super(BasicBlockMeta, self).__init__()
		self.conv1 = conv3x3mtl(inplanes, planes)
		self.bn1 = nn.BatchNorm2d(planes)
		self.relu = nn.LeakyReLU(0.1)
		self.conv2 = conv3x3mtl(planes, planes)
		self.bn2 = nn.BatchNorm2d(planes)
		self.conv3 = conv3x3mtl(planes, planes)
		self.bn3 = nn.BatchNorm2d(planes)
		self.maxpool = nn.MaxPool2d(stride)
		self.downsample = downsample
		self.stride = stride
		self.drop_rate = drop_rate
		self.num_batches_tracked = 0
		self.drop_block = drop_block
		self.block_size = block_size
		self.DropBlock = DropBlock(block_size=self.block_size)
	
	def forward(self, x):
		self.num_batches_tracked += 1
		residual = x
		out = self.conv1(x)
		out = self.bn1(out)
		out = self.relu(out)
		
		out = self.conv2(out)
		out = self.bn2(out)
		out = self.relu(out)
		
		out = self.conv3(out)
		out = self.bn3(out)
		
		if self.downsample is not None:
			residual = self.downsample(x)
		out += residual
		out = self.relu(out)
		out = self.maxpool(out)
		
		if self.drop_rate > 0:
			if self.drop_block:
				feat_size = out.size()[2]
				keep_rate = max(1.0 - self.drop_rate / (20 * 2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
				gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
				out = self.DropBlock(out, gamma=gamma)
			else:
				out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
		return out

class ResNet(nn.Module):
	def __init__(self, block, n_blocks, keep_prob=1.0, drop_rate=0.0, dropblock_size=5, num_classes=-1):
		super(ResNet, self).__init__()
		channels = [64, 160, 320, 640]
		
		self.inplanes = 3
		self.layer1 = self._make_layer(block, n_blocks[0], channels[0], drop_rate=drop_rate)
		self.layer2 = self._make_layer(block, n_blocks[1], channels[1], drop_rate=drop_rate)
		self.layer3 = self._make_layer(block, n_blocks[2], channels[2], drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
		self.layer4 = self._make_layer(block, n_blocks[3], channels[3], drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
		self.avgpool = nn.AdaptiveAvgPool2d(1)
		self.keep_prob = keep_prob
		self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
		self.drop_rate = drop_rate
		self.out_dim = channels[-1]

		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
			elif isinstance(m, nn.BatchNorm2d):
				nn.init.constant_(m.weight, 1)
				nn.init.constant_(m.bias, 0)

	def _make_layer(self, block, n_block, planes, stride=2, drop_rate=0.0, drop_block=False, block_size=1):
		downsample = None
		if stride != 1 or self.inplanes != planes * block.expansion:
			downsample = nn.Sequential(
				nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False),
				nn.BatchNorm2d(planes * block.expansion),
			)
		the_blk = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size)
		self.inplanes = planes * block.expansion
		return the_blk

	def forward(self, x, rot=False, is_feat=False):
		x = self.layer1(x)
		x = self.layer2(x)
		
		x = self.layer3(x)
		x = self.layer4(x)
	
		x = self.avgpool(x)
		resfeat = x.view(x.size(0), -1)

		return resfeat

def create_feature_extractor(restype, dataset, **kwargs):
	# mode 0:pre-train, 1:finetune 2:bias_shift
	keep_prob = 1.0
	drop_rate = 0.1
	dropblock_size = 5 if 'ImageNet' in dataset else 2
	if restype == 'ResNet12':
		network = ResNet(BasicBlock, [1,1,1,1], keep_prob=keep_prob, drop_rate=drop_rate, dropblock_size=dropblock_size, **kwargs)
	elif restype == 'ResNet18':
		network = ResNet(BasicBlock, [1,1,2,2], keep_prob=keep_prob, drop_rate=drop_rate, dropblock_size=dropblock_size, **kwargs)
	else:
		raise ValueError("Not Implemented Yet")
	return network
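
A quick smoke test of the feature extractor (the dataset string is an assumed value; any name containing 'ImageNet' selects dropblock_size=5). eval() is used here because DropBlock's training path moves its mask to CUDA:

import torch
from architectures.ResNetFeat import create_feature_extractor

net = create_feature_extractor('ResNet12', 'miniImageNet')
net.eval()
x = torch.randn(4, 3, 84, 84)   # miniImageNet-sized images (assumed)
feat = net(x)
print(net.out_dim, feat.shape)  # 640 torch.Size([4, 640])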

NetworkPre.py

# NetworkPre
import torch.nn as nn
import torch 
import torch.nn.functional as F

import numpy as np

from architectures.ResNetFeat import create_feature_extractor
from architectures.AttnClassifier import Classifier

class FeatureNet(nn.Module):
	def __init__(self, args, restype, n_class, param_seman):
		super(FeatureNet, self).__init__()
		self.args = args
		self.restype = restype
		self.featype = args.featype
		self.n_ways = args.n_ways
		self.tunefeat = args.tunefeat
		self.distance_label = torch.Tensor([i for i in range(self.n_ways)]).cuda().long()
		self.metric = Metric_Cosine()
		self.feature = create_feature_extractor(restype, args.dataset)
		self.feat_dim = self.feature.out_dim # the extractor outputs 640-dim features
		
		self.cls_classifier = Classifier(args, self.feat_dim, param_seman, args.train_weight_base) if 'OpenMeta' in self.featype else nn.Linear(self.feat_dim, n_class)
		assert 'OpenMeta' in self.featype
		if self.tunefeat == 0.0:
			for _, p in self.feature.named_parameters():
				p.requires_grad = False
		else:
			if args.tune_part <= 3:
				for _, p in self.feature.layer1.named_parameters():
					p.requires_grad = False
			if args.tune_part <= 2:
				for _, p in self.feature.layer2.named_parameters():
					p.requires_grad = False
			if args.tune_part <= 1:
				for _, p in self.feature.layer3.named_parameters():
					p.requires_grad = False
	
	def forward(self, the_img, labels=None, conj_ids=None, base_ids=None, test=False):
		if labels is None:
			assert the_img.dim() == 4
			return (self.feature(the_img), None)
		else:
			return self.open_forward(the_img, labels, conj_ids, base_ids, test)
	
	def open_forward(self, the_input, labels, conj_ids, base_ids, test):
		# Hyper-parameter Preparation
		the_sizes = [_.size(1) for _ in the_input]
		(ne, _, nc, nh, nw) = the_input[0].size()
		
		# Data Preparation
		combined_data = torch.cat(the_input, dim=1).view(-1, nc, nh, nw)
		if not self.tunefeat: # tunefeat == 0: the backbone stays frozen
			with torch.no_grad():
				combined_feat = self.feature(combined_data).detach() # no gradients through the feature extractor
		else:
			combined_feat = self.feature(combined_data)
		support_feat, query_feat, supopen_feat, openset_feat = torch.split(combined_feat.view(ne, -1, self.feat_dim), the_sizes, dim=1)
		(support_label, query_label, supopen_label, openset_label) = labels
		(supp_idx, open_idx) = conj_ids
		cls_label = torch.cat([query_label, openset_label], dim=1)
		test_feats = (support_feat, query_feat, openset_feat)
		
		### First Task
		support_feat = support_feat.view(ne, self.n_ways, -1, self.feat_dim)
		test_cosine_scores, supp_protos, fakeclass_protos, loss_cls, loss_funit = self.task_proto((support_feat, query_feat, openset_feat), (supp_idx, base_ids), cls_label, test)
		cls_protos = torch.cat([supp_protos, fakeclass_protos], dim=1)
		test_cls_probs = self.task_pred(test_cosine_scores[0], test_cosine_scores[1])

		if test:
			test_feats = (support_feat, query_feat, openset_feat)
			return test_feats, cls_protos, test_cls_probs
		
		## Second task
		supopen_feat = supopen_feat.view(ne, self.n_ways, -1, self.feat_dim)
		_, supp_protos_aug, fakeclass_protos_aug, loss_cls_aug, loss_funit_aug = self.task_proto((supopen_feat, openset_feat, query_feat), (open_idx, base_ids), cls_label, test)
		loss_open_hinge = 0.0
		loss = (loss_cls + loss_cls_aug, loss_open_hinge, loss_funit+loss_funit_aug)
		return test_feats, cls_protos, test_cls_probs, loss

	def task_proto(self, features, cls_ids, cls_label, test=False):
		test_cosine_scores, supp_protos, fakeclass_protos, _, funit_distance = self.cls_classifier(features, cls_ids, test)
		(query_cls_scores, openset_cls_scores) = test_cosine_scores
		cls_scores = torch.cat([query_cls_scores, openset_cls_scores], dim=1)
		fakeunit_loss = fakeunit_compare(funit_distance, self.n_ways, cls_label)
		cls_scores, close_label, cls_label = cls_scores.view(-1, self.n_ways+1), cls_label[:, :query_cls_scores.size(1)].reshape(-1), cls_label.view(-1)
		loss_cls = F.cross_entropy(cls_scores, cls_label)
		return test_cosine_scores, supp_protos, fakeclass_protos, loss_cls, fakeunit_loss

	def task_pred(self, query_cls_scores, openset_cls_scores, many_cls_scores=None):
		query_cls_probs = F.softmax(query_cls_scores.detach(), dim=-1)
		openset_cls_probs = F.softmax(openset_cls_scores.detach(), dim=-1)
		if many_cls_scores is None:
			return (query_cls_probs, openset_cls_probs)
		else:
			many_cls_probs = F.softmax(many_cls_scores.detach(), dim=-1)
			return (query_cls_probs, openset_cls_probs, many_cls_probs, query_cls_scores, openset_cls_scores)



class Metric_Cosine(nn.Module):
	def __init__(self, temperature=10):
		super(Metric_Cosine, self).__init__()
		self.temp = nn.Parameter(torch.tensor(float(temperature)))

	def forward(self, supp_center, query_feature):
		## supp_center: bs*nway*D
		## query_feature: bs * (nway*nquery) * D
		supp_center = F.normalize(supp_center, dim=-1)
		query_feature = F.normalize(query_feature, dim=-1)
		logits = torch.bmm(query_feature, supp_center.transpose(1, 2))
		return logits * self.temp
	
def fakeunit_compare(funit_distance, n_ways, cls_label):
	cls_label_binary = F.one_hot(cls_label).float() # modified code: build one-hot targets for the fake-unit BCE loss
	loss = torch.sum(F.binary_cross_entropy_with_logits(input=funit_distance, target=cls_label_binary))
	return loss
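
A minimal shape sketch of fakeunit_compare under assumed episode sizes (5-way, labels 0..5, where label 5 marks open-set queries). Because the labels reach n_ways, F.one_hot emits n_ways+1 columns, matching the assumed width of funit_distance:

import torch
import torch.nn.functional as F

n_ways, n_query = 5, 75
cls_label = torch.randint(0, n_ways + 1, (2, 2 * n_query))  # closed- plus open-set queries
funit_distance = torch.rand(2, 2 * n_query, n_ways + 1)     # 1 - cosine similarities (width assumed)
target = F.one_hot(cls_label, num_classes=n_ways + 1).float()
loss = torch.sum(F.binary_cross_entropy_with_logits(funit_distance, target))
print(loss)  # a scalar loss value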

