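Theory
Supervised learning: the techniques are relatively mature, but labeling massive amounts of data costs a great deal of time and resources.
Unsupervised learning: the model learns on its own the characteristics shared by similar data from large amounts of data and encodes them as high-level representations. These only need to be fine-tuned for each downstream task, which saves time and hardware resources.
Generative learning:
Generative learning is represented by generative models such as GANs and VAEs (the VAE being autoencoder-based). It generates data from data, so that the generated data resembles the training data as a whole or at the level of high-level semantics.
Contrastive learning:
Contrastive learning focuses on learning the features shared by instances of the same class and on distinguishing instances of different classes.
Compared with generative learning, contrastive learning does not need to model tedious instance-level details; it only has to learn to discriminate data in a feature space at the level of abstract semantics. As a result, the model and its optimization become simpler, and it tends to generalize better.
Understanding it from a clustering perspective: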
$$d(f(x),f(x^+))\ll d(f(x),f(x^-)) \quad \text{OR} \quad s(f(x),f(x^+))\gg s(f(x),f(x^-))$$
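- Shrink intra-class distances and enlarge inter-class distances (here $f$ is the encoder, $x^+$ a positive sample, $x^-$ a negative sample, $d$ a distance and $s$ a similarity).
Ways to measure the distance between two samples: Euclidean distance, cosine similarity, Mahalanobis distance, …
Goal: given an anchor, learn a transformation of the space so that the anchor is as close as possible to positive samples and as far as possible from negative samples. This is essentially the idea behind triplet loss.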
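As a small illustration of the three distance measures listed above, the sketch below compares a hypothetical anchor embedding with a positive and a negative one. The embeddings and the covariance matrix are made-up values chosen only for demonstration; in practice they would come from the encoder $f$.

```python
import numpy as np


def euclidean(a, b):
    return float(np.linalg.norm(a - b))


def cosine_similarity(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def mahalanobis(a, b, cov):
    # distance that rescales each direction by the (assumed) covariance of the embeddings
    diff = a - b
    return float(np.sqrt(diff @ np.linalg.inv(cov) @ diff))


# hypothetical embeddings f(x), f(x+), f(x-)
anchor = np.array([1.0, 0.0])
positive = np.array([0.9, 0.1])
negative = np.array([-1.0, 0.2])
cov = np.array([[1.0, 0.0], [0.0, 4.0]])  # assumed covariance, for illustration only

print("euclidean:", euclidean(anchor, positive), euclidean(anchor, negative))
print("cosine:   ", cosine_similarity(anchor, positive), cosine_similarity(anchor, negative))
print("mahalanobis:", mahalanobis(anchor, positive, cov), mahalanobis(anchor, negative, cov))
```

A distance should be small for the positive and large for the negative, while a similarity behaves the other way round, which is exactly the $d \ll$ / $s \gg$ pair of conditions above. With an identity covariance the Mahalanobis distance reduces to the Euclidean one.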
Contrastive Loss
$W$: network weights;
$Y$: label,
$$Y=\begin{cases}0, & X_1, X_2 \text{ belong to the same class}\\ 1, & X_1, X_2 \text{ belong to different classes}\end{cases}$$
$D_W$: the Euclidean distance between $X_1$ and $X_2$ in the latent space.
$i$: index of the $i$-th pair of vectors.
$L$: this is where research usually focuses; defining a sensible loss function that serves the end goal is often more than half the battle.
$$L(W,(Y,\vec{X_1},\vec{X_2})^i)=(1-Y)L_S(D_W^i(\vec{X_1},\vec{X_2}))+Y L_D(D_W^i(\vec{X_1},\vec{X_2}))$$
$$L(W)=\sum_{i=1}^{P}L(W,(Y,\vec{X_1},\vec{X_2})^i)$$
where $L_S$ is the partial loss for a similar pair, $L_D$ the partial loss for a dissimilar pair, and $P$ the number of training pairs.
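Positive samples:
When a sample is a positive of the anchor, contrastive training gradually pulls the two together.
The original paper models this as a spring with rest length $l \rightarrow 0$, which keeps pulling positive samples closer, thereby achieving clustering.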
$$\vec{F}=-\vec{x}$$
Taking the anchor as the zero of potential energy:
$$E=0-\int\vec{F}\cdot d\vec{x}=\frac{1}{2}x^2$$
Then $E$ can serve as $L_S$, and it satisfies the requirements of the definition:
$$L_S=\frac{1}{2}D_W^2$$
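Negative samples:
When a sample is a negative of the anchor, contrastive training gradually pushes the two apart. The original paper models this as a spring with rest length $l \rightarrow m$, which pushes negative samples to a distance of at least $m$, thereby separating them.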
$$\vec{F}=\vec{m}-\vec{x}$$
Taking the spring's rest position ($x=m$) as the zero of potential energy:
$$E=0-\int\vec{F}\cdot d\vec{x}=\frac{1}{2}(m-x)^2$$
Because this spring only repels (it exerts no force once the pair is at least $m$ apart), the energy is truncated to zero for $D_W \ge m$:
$$L_D=\frac{1}{2}(\max\{0,m-D_W\})^2$$
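To check the spring analogy numerically, the sketch below uses autograd to evaluate the "force" $-\partial L/\partial D_W$ produced by $L_S$ and $L_D$ at a few distances. The margin value `m = 2.0` and the test distances are arbitrary illustrative choices.

```python
import torch

m = 2.0  # margin, i.e. the rest length of the repulse-only spring (illustrative value)


def L_S(d):
    # attractive spring with rest length 0: 1/2 * D_W^2
    return 0.5 * d ** 2


def L_D(d):
    # repulse-only spring with rest length m: 1/2 * max(0, m - D_W)^2
    return 0.5 * torch.clamp(m - d, min=0) ** 2


def force(loss_fn, d_value):
    """Return -dL/dD at D = d_value, the spring force along the distance axis."""
    d = torch.tensor(d_value, requires_grad=True)
    loss_fn(d).backward()
    return -d.grad.item()


print(force(L_S, 1.5))  # negative: pulls the positive pair together
print(force(L_D, 0.5))  # positive: pushes the negative pair apart
print(force(L_D, 3.0))  # zero: beyond the margin the spring no longer acts
```

The gradient of $L_S$ is proportional to the distance (Hooke's law with rest length 0), while the gradient of $L_D$ vanishes once $D_W \ge m$, which is what the $\max\{0,\cdot\}$ truncation encodes.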
Definition in the original paper:
$$L(W,Y,\vec{X_1},\vec{X_2})=(1-Y)\frac{1}{2}D_W^2+Y\cdot\frac{1}{2}(\max\{0,m-D_W\})^2$$
$$\begin{cases}\text{if } Y=0: \text{adjust the parameters to minimize } D_W(\vec{X_1},\vec{X_2})\\ \text{if } Y=1: \text{the pair only contributes while its distance is below the margin } m\end{cases}$$
$$\begin{cases}\text{if } D_W(\vec{X_1},\vec{X_2})<m: \text{increase their distance toward } m;\\ \text{if } D_W(\vec{X_1},\vec{X_2})\geq m: \text{no optimization (zero loss).}\end{cases}$$
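The net effect: similar pairs are pulled together, while dissimilar pairs are pushed apart until they are at least $m$ apart.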
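A minimal PyTorch sketch of the pairwise contrastive loss defined above, assuming the embeddings of each pair are already computed and $Y$ uses the same convention (0 = same class, 1 = different class). The function name and the margin default are illustrative, not taken from a specific library.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(z1, z2, y, margin=1.0):
    """Pairwise contrastive loss, summed over the P pairs.

    z1, z2: [P, dim] embeddings of the two members of each pair
    y:      [P] labels, 0 = same class (similar), 1 = different class (dissimilar)
    """
    d = F.pairwise_distance(z1, z2)                                  # D_W per pair (adds eps=1e-6)
    loss_similar = 0.5 * d.pow(2)                                    # L_S = 1/2 D_W^2
    loss_dissimilar = 0.5 * torch.clamp(margin - d, min=0).pow(2)    # L_D = 1/2 max(0, m - D_W)^2
    y = y.float()
    return ((1 - y) * loss_similar + y * loss_dissimilar).sum()


z1 = torch.zeros(3, 2)
z2 = torch.tensor([[3.0, 4.0],   # similar pair at distance 5: penalized by L_S
                   [0.3, 0.4],   # dissimilar pair at distance 0.5 < m: penalized by L_D
                   [3.0, 4.0]])  # dissimilar pair at distance 5 >= m: no loss
y = torch.tensor([0, 1, 1])
print(contrastive_loss(z1, z2, y))  # about 12.5 + 0.125 + 0 = 12.625
```

`F.pairwise_distance` adds a tiny `eps` inside the norm, which keeps the gradient finite when two embeddings coincide; the result therefore differs from the exact distance only in the sixth decimal place.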
Papers to read
Some commonly used contrastive losses
Triplet Loss:
$$L=\max\{d(x,x^+)-d(x,x^-)+\alpha,0\}$$
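The triplet loss above can be written directly in PyTorch, assuming $d$ is the Euclidean distance and $\alpha$ is the margin. The sketch compares a hand-written version with PyTorch's built-in `torch.nn.TripletMarginLoss`, which implements the same formula averaged over the batch.

```python
import torch
import torch.nn.functional as F


def triplet_loss(anchor, positive, negative, alpha=1.0):
    """L = max{d(x, x+) - d(x, x-) + alpha, 0}, averaged over the batch (Euclidean d)."""
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    return torch.clamp(d_pos - d_neg + alpha, min=0).mean()


torch.manual_seed(0)
anchor = torch.randn(8, 16)
positive = anchor + 0.1 * torch.randn(8, 16)   # close to the anchor
negative = torch.randn(8, 16)                  # unrelated sample

ours = triplet_loss(anchor, positive, negative, alpha=1.0)
builtin = torch.nn.TripletMarginLoss(margin=1.0, p=2)(anchor, positive, negative)
print(ours.item(), builtin.item())
```

Unlike the pairwise contrastive loss, the triplet loss only constrains the *relative* distances: it is zero as soon as the negative is farther from the anchor than the positive by at least $\alpha$.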
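NCE Loss:
The losses above are formulated in terms of the vector space; NCE (Noise-Contrastive Estimation) takes a probabilistic view [the original proof is a Bayesian-style derivation]. NCE estimates the score function, which means it assesses whether the way distances are assigned in the space is reasonable.
In short, NCE learns to distinguish real data samples from samples drawn from a noise distribution, and thereby infers the true data distribution.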
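A minimal sketch of the binary NCE objective, assuming the model provides an unnormalized log-score $s(u)$ for each point (playing the role of $\log p_{model}(u)$) and that the log-probability of each point under the noise distribution $q$ is known. With $k$ noise samples per data sample, the classifier's log-odds are $s(u) - \log(k\,q(u))$; the function name and the toy numbers are illustrative.

```python
import math

import torch
import torch.nn.functional as F


def nce_loss(score_data, score_noise, log_q_data, log_q_noise, k):
    """Binary NCE loss: tell each real sample apart from its k noise samples.

    score_*: unnormalized model log-scores s(u)
    log_q_*: log-probabilities of the same points under the noise distribution q
    shapes:  score_data, log_q_data: [B]; score_noise, log_q_noise: [B, k]
    """
    # log-odds that a point came from the data rather than from the noise
    g_data = score_data - (math.log(k) + log_q_data)
    g_noise = score_noise - (math.log(k) + log_q_noise)
    # -log sigma(g) for real samples, -log(1 - sigma(g)) = -log sigma(-g) for noise samples
    loss = -F.logsigmoid(g_data) - F.logsigmoid(-g_noise).sum(dim=1)
    return loss.mean()


B, k = 4, 5
log_q_data = torch.full((B,), math.log(1 / 100))     # uniform noise over 100 outcomes (assumed)
log_q_noise = torch.full((B, k), math.log(1 / 100))

good = nce_loss(torch.full((B,), 2.0), torch.full((B, k), -8.0), log_q_data, log_q_noise, k)
bad = nce_loss(torch.full((B,), -8.0), torch.full((B, k), 2.0), log_q_data, log_q_noise, k)
print(good.item(), bad.item())  # scoring data high and noise low gives the smaller loss
```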
InfoNCE Loss (mutual information):
$$I(x,c)=\sum_x\sum_c p(x,c)\log\frac{p(x,c)}{p(x)p(c)}=\sum_{x,c}p(x,c)\log\frac{p(x|c)}{p(x)}$$
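The two forms of the mutual-information formula are equal because $p(x,c)/p(c) = p(x|c)$. The short NumPy check below evaluates both on a hypothetical 2×2 joint distribution.

```python
import numpy as np

# hypothetical joint distribution p(x, c): rows index x, columns index c
p_xc = np.array([[0.4, 0.1],
                 [0.1, 0.4]])
p_x = p_xc.sum(axis=1, keepdims=True)   # marginal p(x), shape (2, 1)
p_c = p_xc.sum(axis=0, keepdims=True)   # marginal p(c), shape (1, 2)
p_x_given_c = p_xc / p_c                # conditional p(x|c)

mi_joint = np.sum(p_xc * np.log(p_xc / (p_x * p_c)))
mi_conditional = np.sum(p_xc * np.log(p_x_given_c / p_x))
print(mi_joint, mi_conditional)  # both about 0.1927 nats
```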
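- Upper-bound estimation of mutual information: used to reduce mutual information, which is the goal in VAE-style models.
- Lower-bound estimation of mutual information: used to increase mutual information, which is the goal of contrastive learning (CL). [Later work also combines the CLUB upper-bound estimator with lower-bound estimators in contrastive learning.]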
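A minimal InfoNCE sketch, assuming a batch of $N$ paired samples where row $i$ of `keys` is the positive for row $i$ of `query` and the other $N-1$ rows act as negatives. The cosine similarity and temperature are an illustrative choice of score (CPC itself uses a bilinear score). The CPC paper shows that minimizing this loss maximizes a lower bound $I \ge \log N - \mathcal{L}_{\text{InfoNCE}}$.

```python
import math

import torch
import torch.nn.functional as F


def info_nce(query, keys, temperature=0.1):
    """InfoNCE over a batch: positives on the diagonal, all other rows are negatives."""
    query = F.normalize(query, dim=1)
    keys = F.normalize(keys, dim=1)
    logits = query @ keys.T / temperature      # [N, N] scores
    targets = torch.arange(query.shape[0])     # the positive for row i is column i
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
N, dim = 64, 32
x = torch.randn(N, dim)
aligned = info_nce(x, x + 0.05 * torch.randn(N, dim))   # keys carry information about x
unrelated = info_nce(x, torch.randn(N, dim))            # keys independent of x

# lower bounds on the mutual information, I >= log(N) - loss
print(math.log(N) - aligned.item(), math.log(N) - unrelated.item())
```

Because the bound can never exceed $\log N$, a larger batch (more negatives) allows a tighter estimate, one reason contrastive methods favor many negatives.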
The key question: how do we construct positive and negative instance pairs?
Paper
CPC
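Data are often high-dimensional while labels are relatively scarce, and we do not want to waste the unlabeled part of the data. So when labels are scarce, unsupervised learning can help us learn high-level information intrinsic to the data, which greatly benefits downstream tasks.
The paper Contrastive Predictive Coding (CPC) proposes the following:
- Compress high-dimensional data into a more compact latent space, in which conditional predictions are easier to model.
- Use an autoregressive model to predict future steps in the latent space.
- Rely on NCE to compute the loss (similar to how word embeddings are learned), so that the whole model can be trained end to end.
- Learn high-level information from data of multiple modalities.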
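Within a certain window, $x_t$ and $x_{t+k}$ can be taken as a positive pair, and an input $x_{t^*}$ randomly sampled from the input sequence serves as a negative.
- Using random samples as negatives: this idea is crucial!
Given the context $c_t$ of an audio sequence, we predict the audio signal at position $x_{t+k}$. The setup assumes that the audio sequence is accompanied by noise throughout.
To encode the noise and the signal as separately as possible, a randomly sampled $x_{t^*}$ replaces the signal at position $t+k$ and serves as a negative sample for contrastive learning.
- In other words: position $t+k$ originally holds normal data. Because this is a sequence and $t$ is a window, the sequence contains both normal and abnormal samples, but collected data usually contains many normal samples and few abnormal ones (noise), even though the latter are very important. Positives and negatives are then imbalanced, and the model also fails to learn the essence of the normal samples.
- Hence the noise idea: randomly sampled inputs are treated as noise and used as negatives, so the model learns the regularities and essence of the positives. In other words, we do not really care about the negatives themselves; they are only a reference, a backdrop against which the model learns the essential patterns of the positives.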
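The sampling scheme described above can be sketched as index selection: the true future step $t+k$ is the positive, and negatives are drawn uniformly from the other positions of the sequence. This is one simple choice made for illustration (the helper name is hypothetical); the CPC paper's exact negative-sampling scheme may differ, e.g. drawing negatives from other sequences in the minibatch.

```python
import random


def sample_cpc_pairs(seq_len, t, k, num_negatives, rng=random):
    """Return (positive_index, negative_indices) for predicting step t+k from context up to t.

    The positive is the true future step t+k; negatives are distinct indices drawn
    uniformly at random from the rest of the sequence (the role of x_{t*}).
    """
    if t + k >= seq_len:
        raise ValueError("t + k must lie inside the sequence")
    positive = t + k
    candidates = [i for i in range(seq_len) if i != positive]
    negatives = rng.sample(candidates, num_negatives)
    return positive, negatives


rng = random.Random(0)
pos, negs = sample_cpc_pairs(seq_len=100, t=10, k=3, num_negatives=8, rng=rng)
print(pos, negs)
```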
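Back to this example:
First, we pick time windows on the raw signal; each window passes through an encoder $g_{enc}$ to obtain a representation vector $z_t$.
$z_t$ is fed into an autoregressive model $g_{ar}$, which produces the context latent variable $c_t$.
Then a bilinear score is applied:
- It uses $c_t$ and $z_{t+k}$, so it works on compressed representations instead of the high-dimensional raw data, and it scores whether $z_{t+k}$ matches the future predicted from $c_t$.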
$$f_k(x_{t+k},c_t)=\exp\left(z_{t+k}^T W_k c_t\right)$$
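A minimal CPC-style model tying the pieces together: an encoder $g_{enc}$ produces $z_t$, a GRU plays the role of $g_{ar}$ and produces $c_t$, and one matrix $W_k$ per prediction step gives the bilinear score $\log f_k = z_{t+k}^T W_k c_t$. The single strided convolution, the layer sizes and the use of other batch elements as negatives are simplifying assumptions, loosely following the audio setup rather than reproducing the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCPC(nn.Module):
    """Minimal CPC-style model: g_enc -> z_t, g_ar -> c_t, score z_{t+k}^T W_k c_t."""

    def __init__(self, in_dim=1, z_dim=32, c_dim=32, max_k=3):
        super().__init__()
        # stand-in encoder: one strided convolution that halves the time resolution
        self.g_enc = nn.Conv1d(in_dim, z_dim, kernel_size=4, stride=2, padding=1)
        self.g_ar = nn.GRU(z_dim, c_dim, batch_first=True)   # autoregressive model
        # one bilinear matrix W_k per prediction step k = 1..max_k
        self.W = nn.ModuleList([nn.Linear(c_dim, z_dim, bias=False) for _ in range(max_k)])

    def forward(self, x, t, k):
        """x: [B, in_dim, T]. Returns the InfoNCE loss for predicting z_{t+k} from c_t."""
        z = self.g_enc(x).transpose(1, 2)      # [B, T', z_dim]
        c, _ = self.g_ar(z[:, : t + 1])        # contexts for steps 0..t
        c_t = c[:, -1]                         # [B, c_dim]
        pred = self.W[k - 1](c_t)              # W_k c_t, [B, z_dim]
        z_future = z[:, t + k]                 # [B, z_dim]
        # log f_k for every (context, candidate) pair; other samples' z_{t+k} are negatives
        logits = pred @ z_future.T             # [B, B]
        targets = torch.arange(x.shape[0])     # the true future sits on the diagonal
        return F.cross_entropy(logits, targets)


torch.manual_seed(0)
model = TinyCPC()
x = torch.randn(8, 1, 64)   # batch of 8 toy 1-D signals of length 64
loss = model(x, t=10, k=2)
loss.backward()
print(loss.item())
```

Since $f_k$ is an exponential, the cross-entropy over the logits $z_{t+k}^T W_k c_t$ is exactly the InfoNCE loss $-\log\frac{f_k(x_{t+k},c_t)}{\sum_j f_k(x_j,c_t)}$.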
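SimCLR
A Simple Framework for Contrastive Learning of Visual Representations
The idea behind SimCLR is very simple:
- A visual representation should be invariant to inputs showing the same object from different views.
SimCLR applies data augmentation to the input images to simulate different views of an image. It then uses a contrastive loss to maximize the similarity between different augmentations of the same image and to minimize the similarity between different images.
The SimCLR architecture consists of two identical, weight-sharing network branches. For each minibatch fed into the network:
- Apply two random data augmentations (random cropping, blur filters, color distortion, grayscale conversion, etc.) to every image in the minibatch to obtain two different views of it;
- Feed the two views into the (shared) convolutional encoder (e.g., ResNet) to obtain abstract representations, then apply a nonlinear projection to these representations to obtain projected representations;
- Use cosine similarity to measure the similarity between the projections.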
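The steps above end in SimCLR's NT-Xent loss. The sketch below assumes `z1` and `z2` are the projections of the two augmented views of the same $N$ images; for each of the $2N$ projections, the other view of the same image is the positive and the remaining $2N-2$ projections are negatives. The temperature value is an illustrative default.

```python
import torch
import torch.nn.functional as F


def nt_xent(z1, z2, temperature=0.5):
    """SimCLR-style NT-Xent loss for two views z1, z2 of shape [N, dim]."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)        # [2N, dim] unit vectors
    sim = z @ z.T / temperature                                # cosine similarity / temperature
    self_mask = torch.eye(2 * n, dtype=torch.bool)
    sim = sim.masked_fill(self_mask, float('-inf'))            # a sample is never its own negative
    # the positive of row i is row i + N, and vice versa
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
z1 = torch.randn(16, 32)
close = nt_xent(z1, z1 + 0.01 * torch.randn(16, 32))   # views agree
far = nt_xent(z1, torch.randn(16, 32))                 # views unrelated
print(close.item(), far.item())
```

In principle this matches `SupConLoss` below called without `labels` and with `temperature == base_temperature`, since its `labels is None and mask is None` branch marks only the other view of each image as positive.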
Reference article (survey): https://zhuanlan.zhihu.com/p/346686467
Paper: https://zhuanlan.zhihu.com/p/363900943
Hands-on code: main_linear
- Resnet + Classifier + CELoss
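```python
import torch
import torch.backends.cudnn as cudnn

# SupConResNet is defined in the Net section below; LinearClassifier (the linear head
# trained on top of the frozen encoder) is not shown in these notes.


def set_model(opt):
    model = SupConResNet(name=opt.model)
    criterion = torch.nn.CrossEntropyLoss()

    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)

    # load the encoder pretrained with the contrastive loss; weights_only=False is
    # needed on PyTorch >= 2.6 because the checkpoint also pickles the argparse options
    ckpt = torch.load(opt.ckpt, map_location='cpu', weights_only=False)
    state_dict = ckpt['model']

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        else:
            # the checkpoint was saved with a DataParallel-wrapped encoder, so its keys
            # contain "module."; remove it to match the single-GPU model
            new_state_dict = {}
            for k, v in state_dict.items():
                k = k.replace("module.", "")
                new_state_dict[k] = v
            state_dict = new_state_dict
        model = model.cuda()
        classifier = classifier.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

        model.load_state_dict(state_dict)
    else:
        raise NotImplementedError('This code requires GPU')

    return model, classifier, criterion
```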
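Why `k.replace("module.", "")` rather than stripping a leading prefix: only `model.encoder` is wrapped in `DataParallel`, so the saved keys look like `encoder.module.…`, with `module.` in the middle. The sketch below reproduces this with a hypothetical toy model (`Toy` stands in for `SupConResNet`); `nn.DataParallel` can be constructed without a GPU, in which case it simply wraps the module.

```python
import torch.nn as nn


class Toy(nn.Module):
    # hypothetical stand-in for SupConResNet: an encoder plus a head
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 3)
        self.head = nn.Linear(3, 2)


wrapped = Toy()
wrapped.encoder = nn.DataParallel(wrapped.encoder)   # what the multi-GPU branch does
saved = wrapped.state_dict()
print(list(saved))  # keys such as 'encoder.module.weight'

# same fix as in set_model: drop "module." wherever it appears in a key
cleaned = {k.replace("module.", ""): v for k, v in saved.items()}
plain = Toy()
plain.load_state_dict(cleaned)   # strict loading succeeds once "module." is gone
```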
```python
import sys
import time

import torch

# AverageMeter, accuracy and warmup_learning_rate are defined in the utils section below.


def train(train_loader, model, classifier, criterion, optimizer, epoch, opt):
    """one epoch training"""
    model.eval()          # the pretrained encoder stays frozen (BN statistics are not updated)
    classifier.train()    # only the linear classifier is trained

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss: encoder features without gradients, classifier with gradients
        with torch.no_grad():
            features = model.encoder(images)
        output = classifier(features.detach())
        loss = criterion(output, labels)

        # update metric
        losses.update(loss.item(), bsz)
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, idx + 1, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg
```
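The key pattern in the linear-evaluation loop is that the encoder runs under `torch.no_grad()`, so only the classifier receives gradients. The toy sketch below (hypothetical encoder and random data, not the real ResNet) checks exactly that.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Sequential(nn.Linear(10, 16), nn.ReLU())   # hypothetical stand-in for the pretrained encoder
classifier = nn.Linear(16, 3)                            # linear head being trained
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

images = torch.randn(32, 10)
labels = torch.randint(0, 3, (32,))

encoder.eval()
classifier.train()
with torch.no_grad():            # no autograd graph is built through the encoder
    features = encoder(images)
output = classifier(features)    # the graph starts here
loss = criterion(output, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(all(p.grad is None for p in encoder.parameters()))       # encoder untouched
print(all(p.grad is not None for p in classifier.parameters()))
```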
main_supcon
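```python
import torch
import torch.backends.cudnn as cudnn

try:
    import apex  # NVIDIA apex, only needed when opt.syncBN is set
except ImportError:
    pass

# SupConResNet is defined in the Net section, SupConLoss in the Loss section below.


def set_model(opt):
    model = SupConResNet(name=opt.model)
    criterion = SupConLoss(temperature=opt.temp)

    # enable synchronized Batch Normalization
    if opt.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion
```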
```python
import sys
import time

import torch

# AverageMeter and warmup_learning_rate are defined in the utils section below.


def train(train_loader, model, criterion, optimizer, epoch, opt):
    """one epoch training"""
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        # each batch carries two augmented views per image: images = [view1, view2]
        images = torch.cat([images[0], images[1]], dim=0)
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

        # compute loss
        features = model(images)
        # split the output back into the two views f1 and f2 (each [bsz, feat_dim]),
        # add a view dimension with unsqueeze(1) and concatenate along dim=1 to get
        # features of shape [bsz, n_views=2, feat_dim], the layout SupConLoss expects
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        if opt.method == 'SupCon':
            loss = criterion(features, labels)
        elif opt.method == 'SimCLR':
            loss = criterion(features)
        else:
            raise ValueError('contrastive method not supported: {}'.
                             format(opt.method))

        # update metric
        losses.update(loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                      epoch, idx + 1, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg
```
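The split/unsqueeze/cat step in the loop above is easiest to see on shapes. In this sketch random tensors stand in for the two views' projections (the sizes are arbitrary); the result has shape `[bsz, n_views, feat_dim]`, with view 1 in slot 0 and view 2 in slot 1.

```python
import torch

bsz, feat_dim = 4, 128
view1 = torch.randn(bsz, feat_dim)   # projections of the first augmented view
view2 = torch.randn(bsz, feat_dim)   # projections of the second augmented view

features = torch.cat([view1, view2], dim=0)                      # like model(images): [2*bsz, feat_dim]
f1, f2 = torch.split(features, [bsz, bsz], dim=0)               # undo the concatenation
stacked = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)  # [bsz, 2, feat_dim]

print(stacked.shape)  # torch.Size([4, 2, 128])
```

`torch.stack([f1, f2], dim=1)` produces the same tensor in a single call.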
- main function
```python
import os
import time

import tensorboard_logger as tb_logger

# parse_option and set_loader are not shown in these notes; set_optimizer,
# adjust_learning_rate and save_model are in the utils section below.


def main():
    opt = parse_option()

    # build data loader
    train_loader = set_loader(opt)

    # build model and criterion
    model, criterion = set_model(opt)

    # build optimizer
    optimizer = set_optimizer(opt, model)

    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # training routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        loss = train(train_loader, model, criterion, optimizer, epoch, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('loss', loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % opt.save_freq == 0:
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, save_file)

    # save the last model
    save_file = os.path.join(
        opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, save_file)
```
Well-written utils
```python
import math

import numpy as np
import torch
import torch.optim as optim


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    # linearly ramp the LR from warmup_from to warmup_to over the first warm_epochs epochs
    if args.warm and epoch <= args.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / \
            (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


def adjust_learning_rate(args, optimizer, epoch):
    lr = args.learning_rate
    if args.cosine:
        # cosine annealing from learning_rate down to eta_min
        eta_min = lr * (args.lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
                1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        # step decay: multiply by lr_decay_rate once per milestone already passed
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        if steps > 0:
            lr = lr * (args.lr_decay_rate ** steps)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape instead of view: correct[:k] is not contiguous after the transpose
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def set_optimizer(opt, model):
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    return optimizer


def save_model(model, optimizer, opt, epoch, save_file):
    print('==> Saving...')
    state = {
        'opt': opt,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, save_file)
    del state
```
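To see how `adjust_learning_rate` (once per epoch) and `warmup_learning_rate` (once per batch, overriding it during warm-up) combine, the sketch below reproduces both formulas standalone. The hyper-parameter values are hypothetical, and setting `warmup_to` to the cosine value at the end of warm-up is an assumption made here so that the two phases join smoothly.

```python
import math
from types import SimpleNamespace

# hypothetical hyper-parameters in the shape of the `opt`/`args` namespace used above
args = SimpleNamespace(learning_rate=0.5, lr_decay_rate=0.1, epochs=100, cosine=True,
                       warm=True, warm_epochs=10, warmup_from=0.01)


def cosine_lr(args, epoch):
    # same formula as adjust_learning_rate with args.cosine=True
    eta_min = args.learning_rate * (args.lr_decay_rate ** 3)
    return eta_min + (args.learning_rate - eta_min) * (
        1 + math.cos(math.pi * epoch / args.epochs)) / 2


# assumption: warm up toward the cosine value at the end of warm-up
args.warmup_to = cosine_lr(args, args.warm_epochs)


def lr_for_batch(args, epoch, batch_id, total_batches):
    """LR seen by one batch: the per-epoch cosine value, overridden by linear warm-up early on."""
    lr = cosine_lr(args, epoch)
    if args.warm and epoch <= args.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)
    return lr


for epoch in (1, 5, 10, 11, 50, 100):
    print(epoch, round(lr_for_batch(args, epoch, 0, total_batches=100), 5))
```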
Loss
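Contrastive loss: Supervised Contrastive Loss is the loss function used in supervised contrastive learning. It aims to learn representations that are both discriminative and invariant to variation within the same class.
The goal of supervised contrastive learning is to maximize the agreement of positive pairs (samples of the same class) and to minimize the agreement of negative pairs (samples of different classes). The supervised contrastive loss achieves this by pulling the representations of positive pairs closer together in the embedding space while pushing the representations of negative pairs apart.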
```python
import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # SimCLR: a sample is only positive with its own other views
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            # SupCon: all samples sharing a label are positives
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        # stack all views view-major: [n_views * bsz, dim]
        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # (with n_views >= 2 every anchor has at least one positive, so mask.sum(1) > 0)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
```
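The vectorized masking in `SupConLoss` is dense, so an explicit per-anchor loop of the paper's $\mathcal{L}^{sup}_{out}$ term can help when reading it. The sketch below assumes L2-normalized features of shape `[bsz, n_views, dim]`, averages over all anchors, and omits the `temperature / base_temperature` scaling (it equals 1 when the two are equal). It cross-checks the loop against a compact vectorized form; both function names are illustrative.

```python
import torch
import torch.nn.functional as F


def supcon_loss_loop(features, labels, temperature=0.07):
    """Per-anchor SupCon loss written as an explicit loop, averaged over all anchors."""
    bsz, n_views, _ = features.shape
    z = torch.cat(torch.unbind(features, dim=1), dim=0)   # [n_views*bsz, dim], view-major
    y = labels.repeat(n_views)                            # label of every row of z
    losses = []
    for i in range(z.shape[0]):
        others = [a for a in range(z.shape[0]) if a != i]       # A(i): all but the anchor
        positives = [p for p in others if y[p] == y[i]]         # P(i): same label as the anchor
        logits = torch.stack([z[i] @ z[a] for a in others]) / temperature
        log_denominator = torch.logsumexp(logits, dim=0)
        log_probs = [z[i] @ z[p] / temperature - log_denominator for p in positives]
        losses.append(-torch.stack(log_probs).mean())
    return torch.stack(losses).mean()


def supcon_loss_vectorized(features, labels, temperature=0.07):
    bsz, n_views, _ = features.shape
    z = torch.cat(torch.unbind(features, dim=1), dim=0)
    y = labels.repeat(n_views)
    self_mask = torch.eye(z.shape[0], dtype=torch.bool)
    logits = (z @ z.T / temperature).masked_fill(self_mask, float('-inf'))
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
    pos_mask = (y[:, None] == y[None, :]) & ~self_mask
    return -(log_prob.masked_fill(~pos_mask, 0).sum(1) / pos_mask.sum(1)).mean()


torch.manual_seed(0)
feats = F.normalize(torch.randn(6, 2, 8), dim=2)
labels = torch.tensor([0, 1, 0, 2, 1, 2])
print(supcon_loss_loop(feats, labels, 0.5).item(), supcon_loss_vectorized(feats, labels, 0.5).item())
```

In principle both should agree with `SupConLoss(temperature=t, base_temperature=t)(feats, labels)`, since subtracting `logits_max` does not change the log-probabilities.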
Net
```python
import torch.nn as nn
import torch.nn.functional as F

# resnet18 / resnet34 / resnet50 / resnet101 are the project's own backbone
# constructors (not shown here); each is assumed to return the pooled feature
# vector of size dim_in rather than class logits.
model_dict = {
    'resnet18': [resnet18, 512],
    'resnet34': [resnet34, 512],
    'resnet50': [resnet50, 2048],
    'resnet101': [resnet101, 2048],
}


class SupConResNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, name='resnet50', head='mlp', feat_dim=128):
        super(SupConResNet, self).__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        feat = self.encoder(x)
        # L2-normalize the projection so dot products in SupConLoss are cosine similarities
        feat = F.normalize(self.head(feat), dim=1)
        return feat


class SupCEResNet(nn.Module):
    """encoder + classifier"""
    def __init__(self, name='resnet50', num_classes=10):
        super(SupCEResNet, self).__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        self.fc = nn.Linear(dim_in, num_classes)

    def forward(self, x):
        return self.fc(self.encoder(x))
```
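A quick shape check of the backbone + projection head design, using a hypothetical `TinyEncoder` in place of the project's ResNet (it maps a 3×32×32 image to a 512-d pooled feature, matching the `resnet18` entry's `dim_in`). It shows that the encoder feature is what a classifier consumes, while the normalized 128-d projection lies on the unit hypersphere used by the contrastive loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for resnet18: 3x32x32 image -> dim_out pooled feature."""
    def __init__(self, dim_out=512):
        super().__init__()
        self.conv = nn.Conv2d(3, dim_out, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        return torch.flatten(self.pool(F.relu(self.conv(x))), 1)


dim_in, feat_dim = 512, 128
encoder = TinyEncoder(dim_in)
head = nn.Sequential(                     # same layout as head='mlp'
    nn.Linear(dim_in, dim_in),
    nn.ReLU(inplace=True),
    nn.Linear(dim_in, feat_dim),
)

x = torch.randn(4, 3, 32, 32)
feat = encoder(x)                         # what a linear classifier would consume
proj = F.normalize(head(feat), dim=1)     # what SupConResNet.forward returns
print(feat.shape, proj.shape, proj.norm(dim=1))
```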