知识蒸馏,使用随机输入预测作为伪标签训练新模型—— 一次不算失败的小实验

想法的开始

今天在学习知识蒸馏这方面的知识的时候,突发奇想,如果无法访问到原数据域,只有一个训练好的模型(比如很多医学图像数据,隐私保密),能否使用这个模型再得到一个新的模型(比如层数更少,规模更小的模型),用于部署到一些计算能力有限的边缘设备。

我的想法是类似 gan 中生成器一样,用随机噪声作为模型输入,得到模型输出的结果标签,作为软伪标签

然后再用随机噪声和生成的标签去训练新的模型

让学生模型去学习教师模型学习到的一些特征,而不是去学习,数据和标签之间过拟合的部分

实验的初始

于是我使用了 MNIST 数据集展开了一个小实验

先创建了两个结构相同的两层卷积网络加上全连接输出,分别命名为 model_studentmodel_teacher ,先正常使用 MNIST 数据集训练model_teacher
步骤一

随机生成28X28大小的随机噪声,输入到训练好的 model_teacher 中,产生一个维度为10的输出,经过 softmax 层,输出作为这个随机噪声图像的软伪标签,然后用这一对去训练新的 model_student

使用噪声训练好 model_student 后,再使用 MNIST 数据集去测试这个模型
在这里插入图片描述
直觉上认为可行

实验与结果分析

前三次实验以下条件相同:

  • 教师模型和学生模型均为简易的卷积神经网络,模型结构代码可见最后。
  • 先使用 teacher 模型再 MNIST 数据集上训练3个epoch,并在测试集上验证得到准确率为99%。

第一次实验

使用学生模型在10000张随机图像上训练了5个epoch,训练损失几乎不下降,最终再测试集上验证准确率为31%,相当于学习到一部分特征,但是很少

先下结论:模型结构相当简单,数据特征也很简单,但结果很差,目前看来看来不可行
原因可能有很多,给出了两条猜想:

  • 我是每次训练都生成新的噪声,这是为了不让模型学习这些噪声图像,而是学习其中的特征,但是训练损失一直无法下降,使用固定一批噪声图像能否解决这个问题
  • 训练次数较少,超参数的选择是否正确,我从训练5个epoch,再训练10个epoch,准确率提升了5%

而下面这三条分析是 AI 给我的回答

  1. 噪声输入的随机性:随机噪声不包含任何有意义的信息或结构。因此,模型对这些噪声的预测结果通常是无意义的,可能只是基于模型的内在偏差或随机选择的输出。将这些预测结果作为标签训练新模型可能导致新模型学到不相关的、甚至是错误的信息。

  2. 标签的不准确性:使用随机噪声产生的标签,实际上是基于模型在无意义输入上的预测,这些标签没有实际的语义或正确性。这种标签会严重影响新模型的训练效果,使其无法学到有用的特征。

  3. 模型的泛化能力差:由于输入的是噪声,训练出的新模型可能会过拟合于这些噪声生成的标签,而无法在实际数据上表现良好。

第二次实验

加大训练力度,训练的50个epoch,每个epoch有50000张图像,对于生成的随机噪声图像有两种策略:
- 第一种:每个epoch都生成新的随机噪声图像,即总共2500000张随机图像,最终再测试集上准确率为71.44%
- 第二种:总共使用50000张噪声图像,每个epoch的图像相同,即总共50000张图像,最终测试集准确率为49.01%
以下为我对本次实验结果的初步第一印象分析(没有参考价值),也与我的想法有一点符合(哭笑):

  • 每次epoch都使用相同的图像,会使得模型去学习这些噪声图像与标签的联系,而不是去学习其中特征,导致对这些图像过拟合,而无法学习到原数据中的特征
  • 每次都生成新的噪声图像,使得模型无法去模拟这些噪声图像,而使得模型学习到了部分特征和标签之间的关系,再测试集上进行验证时,正确率更高

第三次实验

由于第二次实验没有保存模型参数,无法进行后续的猜想,产生了第三次实验
实验条件和第二次实验相同,存在的差别仅仅为,生成的噪声图像以及网络参数初始化,随机产生导致的差距
- 第一种方案:每个epoch都生成新的随机噪声图像,最终再测试集上准确率为87.74%(第二次实验为71.44%)
- 第二种方案:总共使用50000张噪声图像,即每个epoch的图像相同,最终测试集准确率为85.34%(第二次实验为49.01%)

两次相同方案训练,在数据集上测试差距巨大,方案一差距伪16%,方案二差距为36%,而两者区别仅为初始的随机参数不同

实验二和实验三的条件可以认为相同,但是实验结果差距非常大,导致我有一点懵了

仅凭这几次实验,就做出结论很显然是不合理的,更多实验也只有等待后期完善了,这个小想法也不知道是成功还是失败了

代码

自定义数据加载

法一:每个 epoch 使用不同的图像

class noise_dataset(torch.utils.data.Dataset):
    def __init__(self, transform, model):
        self.len = 50000
        # self.data = np.random.random(size=(50000, 28, 28)).astype(np.float32)
        self.transform = transform
        self.model = model

    def __getitem__(self, index):
    	# 每次训练都生成一张新的噪声图像
        image = np.random.random((28, 28)).astype(np.float32)
        image = Image.fromarray((image * 255).astype(np.uint8))  # 将 numpy 数组转换为 PIL 图像
        image = self.transform(image).to(device)
        with torch.no_grad():
            label = self.model(image.unsqueeze(0)).squeeze(0)  # 获取教师模型的预测
            label = F.softmax(label, dim=-1)
        return image, label

    def __len__(self):
        return self.len

法二:每个 epoch 使用相同的图像

class noise_dataset(torch.utils.data.Dataset):
    def __init__(self, transform, model):
        self.len = 50000
        self.data = np.random.random(size=(50000, 28, 28)).astype(np.float32)
        self.transform = transform
        self.model = model

    def __getitem__(self, index):
    	# 每次训练都生成一张新的噪声图像
        # image = np.random.random((28, 28)).astype(np.float32)
        image = self.data[index]
        image = Image.fromarray((image * 255).astype(np.uint8))  # 将 numpy 数组转换为 PIL 图像
        image = self.transform(image).to(device)
        with torch.no_grad():
            label = self.model(image.unsqueeze(0)).squeeze(0)  # 获取教师模型的预测
            label = F.softmax(label, dim=-1)
        return image, label

    def __len__(self):
        return self.len

模型结构

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2, bias=False)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2, bias=False)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(32 * 7 * 7, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

后续小测试

使用添加了噪声的 MNIST 数据集分别对学生和老师模型进行验证

如图为噪声程度,统一使用第 5 级别
噪声程度
差距很小,原教师模型和学生模型的差距都1%。

还等待完善探索吧

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值