CenterLoss在MNIST上的实现

MNIST特征提取解释图像识别之CenterLoss

一、提出问题

       在图像识别中,一个很关键的要素就是图像中提取出来的特征,它关乎着图像识别的精准度。而通常用的softmax输出函数提取到的特征之间往往接的很紧,无太大的明显界限。在根据这些特征做识别的时候会出现模拟两可的情况,那么怎么让提取到的特征之间差异性更大从而提高识别的正确率就成了图像识别的一个重大问题。

二、解决办法:

      有研究就提出了解决问题的方法:减小类内聚,增大类间距,于是就有了后面的CenterLoss和ArcLoss
  CenterLoss是减小类内聚,间接增大类间距;ArcLoss直接增大类间距

1、CenterLoss公式

在这里插入图片描述

2、CenterLoss原理及效果

  它的目的是给每个类别的特征加一个中心点,然后使这一类别的特征点与它的中心的距离总和作为一个损失,然后去优化这个损失,使他们彼此无限靠近。从理论层面上讲,当学习到一定程度后,每个类别的特征会集中为一个点上,但从实际上说,这几乎是不太可能的,只能说接近于重叠在一个点。

如图1为 (log_softmax + NLLLoss)+Adam 输出的特征图

在这里插入图片描述

如图2为 CenterLoss 的原理

在这里插入图片描述

三、最终效果:
如图3为 (log_softmax + NLLLoss) + CenterLoss +Adam 的效果,网络中使用BacthNorm,且输出层bias=False,Centerloss中也对输入特征进行了Normalize

在这里插入图片描述

四、附:
  当网络不加Bacthnorm,Centerloss中对输入特征不做normalize时,训练将会很费时,而且效果也不是很理想。如下图4即为(log_softmax + NLLLoss) + CenterLoss + Adam,网络不加Bacthnorm,Centerloss未做normalize时的效果。

在这里插入图片描述

五、源码:
class CenterLoss(nn.Module):
	def __init__(self, cls_num, feature_num):
		"""
		:param cls_num: 类别数量
		:param feature_num: 特征维度
		"""
		super().__init__()
		self.cls_num = cls_num

		# 随机10个center
		self.center = nn.Parameter(torch.randn(cls_num, feature_num), requires_grad=True)

	def forward(self, feature, _target):
		"""
		:param feature: 特征输入
		:param _target: 标签输入
		:return: 中心损失值
		"""
		feature = F.normalize(feature)				# 对特征做归一化

		# 将center广播成特征点那么多个,每一个特征对应一个center
		centre = self.center.cuda().index_select(dim=0, index=_target.long())

		# 统计每个类别有多少的数据
		counter = torch.histc(_target, bins=self.cls_num, min=0, max=self.cls_num-1)
		# 将每个类别的统计数量广播,每个数据对应一个该类的总数,好做计算
		count = counter[_target.long()]
		centre_dis = feature - centre				# 做差,每个特征到它中心点的距离
		pow_ = torch.pow(centre_dis, 2)				# 平方
		sum_1 = torch.sum(pow_, dim=1)				# 横向求和,每个类别的距离总和
		dis_ = torch.div(sum_1, count.float())		# 类别差,每个类别的差除以该类的总量,得到该类均差
		# sqrt_ = torch.sqrt_(dis_)					# 开方
		sum_2 = torch.sum(dis_)						# 求总差,所有类别的差
		res = sum_2 / 2.0							# 乘:lambda / 2,
		return res
  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
RPCA(Robust Principal Component Analysis)是一种用于去除数据中噪声和异常值的技术。在MNIST数据集中,我们可以将RPCA用于去除图像中的噪声和异常值,从而提高图像分类的精度。 具体步骤如下: 1. 加载MNIST数据集并进行预处理。 2. 将MNIST数据集中的每个图像视为矩阵,并将这些矩阵按行展开成向量。 3. 对这些向量进行RPCA分解,得到低秩和稀疏矩阵。 4. 将低秩矩阵作为新的图像数据集,并使用机器学习算法进行训练和测试。 可以使用Python中的scikit-learn库来实现RPCA。具体代码实现如下: ```python from sklearn.decomposition import PCA from sklearn.linear_model import Lasso import numpy as np from sklearn.datasets import fetch_openml from sklearn.model_selection import train_test_split # 加载MNIST数据集 mnist = fetch_openml('mnist_784') X, y = mnist.data / 255., mnist.target # 将MNIST数据集中的每个图像视为矩阵,并将这些矩阵按行展开成向量 X = np.array([np.reshape(x, (28, 28)) for x in X]) X = X.reshape(X.shape[0], -1) # 对这些向量进行RPCA分解,得到低秩和稀疏矩阵 pca = PCA(n_components=20) X_pca = pca.fit_transform(X) clf = Lasso(alpha=0.1) clf.fit(X_pca.T, X.T) X_sparse = clf.coef_ # 将低秩矩阵作为新的图像数据集,并使用机器学习算法进行训练和测试 X_new = np.dot(X_pca, X_sparse).reshape(X.shape[0], 28, 28) X_train, X_test, y_train, y_test = train_test_split(X_new, y, test_size=0.2, random_state=42) # 在新的图像数据集上使用机器学习算法进行分类 # ... ``` 这里我们使用PCA将原始图像数据集降维到20维,然后使用Lasso进行RPCA分解,得到低秩和稀疏矩阵。最后,将低秩矩阵作为新的图像数据集,使用机器学习算法进行训练和测试。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值