时间序列知识蒸馏论文——Time Series Prediction of Battery EOL with Intermediate State Knowledge Distillation

“利用中间状态知识蒸馏进行电池 EOL 时间序列预测”

非原文翻译,是个人阅读记录。

摘要

这篇论文提出了一种新的知识蒸馏方法,称为中间状态知识蒸馏(Intermediate-State Knowledge Distillation,ISKD),用于时间序列模型的压缩。知识蒸馏是一种训练学生模型的方法,使其能够模仿教师模型的输出。然而,仅使用教师模型的最终输出可能不足以训练学生模型。为此,本文提出了ISKD,通过在时间序列预测中加入模型的中间状态作为知识蒸馏的指导。实验表明,该方法在电池寿命预测任务中,相比传统的知识蒸馏方法,误差降低了最多1.04%,并减少了模型参数数量和延迟。

背景

时间序列预测利用过去的序列数据来预测未来的数值,广泛应用于电池寿命预测等领域。电池的健康状态(State of Health,SoH)和寿命终止点(End of Life,EOL)是关键指标,电池的SoH是指当前最大容量与初始设计最大容量的比值。预测SoH的变化有助于确定电池的EOL和剩余使用寿命(Remaining Useful Life,RUL)。该研究使用NASA AMES提供的电池数据集,基于LSTM模型进行SoH预测,并通过知识蒸馏技术对模型进行压缩。

方法

基线模型

论文使用了三种模型作为基线:

  1. FLSTM:一个两层的LSTM模型,仅在前向方向上计算。
  2. BiLSTM:双向LSTM模型,结构与FLSTM相同,但在前向和后向两个方向上计算。
  3. MBiLSTM:包含MLP层和BiLSTM的组合,类似于用于异常检测任务的VAE-LSTM模型。

所有模型的输入均为包含七个因素(循环、测量电压、测量电流、负载电压、负载电流、时间、SoH)的向量。

中间状态知识蒸馏(ISKD)

ISKD是为了改进传统知识蒸馏在时间序列预测任务中的效果而提出的。传统知识蒸馏主要关注模型最终输出的知识传递,而ISKD进一步利用教师模型中间层的状态信息,帮助学生模型更好地学习教师模型的知识。

1. 中间状态知识蒸馏的核心思想

  • 传统知识蒸馏:通常只关注教师模型和学生模型的最终输出,通过最小化两者的差异,使学生模型学习到教师模型的知识。
  • 中间状态知识蒸馏(ISKD):不仅考虑模型的最终输出,还引入模型中间层的状态,作为蒸馏过程中要学习的目标。具体来说,学生模型学习教师模型的中间状态(隐藏层状态和细胞状态),从而捕获教师模型中间层提取的特征。

2. ISKD的模型架构

  • 模型选择:论文中主要使用了FLSTM、BiLSTM和MBiLSTM三种基线模型。其中,MBiLSTM包含一个多层感知机(MLP)层和双向LSTM层。论文特别关注MBiLSTM的中间状态,将其分为三部分:S1、S2和S3,分别代表LSTM的隐藏状态和细胞状态。
    • S1:LSTM第一层的隐藏状态。
    • S2:LSTM最后一层的细胞状态。
    • S3:模型的最终输出。

3. ISKD的实现

  • 损失函数:在ISKD中,损失函数由三部分组成,以确保学生模型学习到教师模型的中间状态和最终输出:
    L = α ⋅ M S E ( y t , y s ) + β ⋅ M S E ( y , y s ) + γ ⋅ M S E ( S t , S s ) L = \alpha \cdot MSE(y_t, y_s) + \beta \cdot MSE(y, y_s) + \gamma \cdot MSE(S_t, S_s) L=αMSE(yt,ys)+βMSE(y,ys)+γMSE(St,Ss)
    其中:

    • y t y_t yt y s y_s ys 分别是教师模型和学生模型的最终输出。
    • y y y 是真实标签,学生模型需要拟合的目标。
    • S t S_t St S s S_s Ss 分别是教师模型和学生模型的中间状态(例如S1或S2)。
    • α \alpha α, β \beta β, γ \gamma γ 是权重因子,用于调整损失函数中各部分的影响,满足 α + β + γ = 1 \alpha + \beta + \gamma = 1 α+β+γ=1
  • 选择中间状态:论文中测试了两种中间状态S1和S2:

    • S1:LSTM第一层的隐藏状态。通过损失函数使学生模型学习教师模型第一层的隐藏状态特征。
    • S2:LSTM最后一层的细胞状态。损失函数指导学生模型学习教师模型最后一层的细胞状态特征。
  • 损失函数的各部分解释

    1. M S E ( y t , y s ) MSE(y_t, y_s) MSE(yt,ys):教师模型和学生模型最终输出之间的均方误差,鼓励学生模型模仿教师模型的最终输出。
    2. M S E ( y , y s ) MSE(y, y_s) MSE(y,ys):学生模型的输出与真实标签之间的均方误差,确保学生模型的输出逼近真实值。
    3. M S E ( S t , S s ) MSE(S_t, S_s) MSE(St,Ss):教师模型和学生模型中间状态之间的均方误差,是ISKD的核心,指导学生模型学习教师模型的中间状态。

4. 蒸馏过程

  • 训练过程
    • 教师模型:首先训练一个较大的教师模型(例如窗口长度为10或15的MBiLSTM),这个模型可以较好地捕获数据的时序特征,并生成中间状态S1、S2和最终输出S3。
    • 学生模型:定义一个较小的学生模型(例如窗口长度较短的MBiLSTM)。在训练学生模型时,不仅使用教师模型的最终输出作为学习目标,还将教师模型的中间状态作为额外的学习目标,通过损失函数中的中间状态项引导学生模型学习。
  • 选择中间状态:论文分别测试了使用S1和S2作为中间状态的效果,发现S2(最后一层的细胞状态)对学生模型性能提升更明显。

实验

  • 数据集:使用NASA AMES提供的四个电池数据集,其中三个用于训练,一个用于测试。
  • 实验设计:从电池充放电30次后开始预测,通过观察电池在30次循环后的状态,预测未来的趋势。将模型的窗口长度设为不同值,较大的窗口长度定义为教师模型,较小的窗口长度定义为学生模型。
  • 性能指标:使用R2得分、RMSE、EOL误差等指标来评估模型性能。

结果

  1. 基线模型:在不同窗口长度下评估了三种基线模型,最佳性能出现在窗口长度为10或15时,将其定义为教师模型。
  2. 知识蒸馏:使用教师模型对学生模型进行知识蒸馏。相比传统知识蒸馏方法,ISKD方法在加入中间状态(S1或S2)后,可以显著提升学生模型的性能,尤其是使用S2中间状态时性能提升最为明显。
  3. 延迟与模型压缩:通过ISKD减少了模型参数数量和延迟,提升了在低规格设备上的适用性。

结论

本文提出的中间状态知识蒸馏(ISKD)在时间序列模型中引入了中间层的知识蒸馏,使得学生模型在性能上接近甚至超过了传统知识蒸馏方法。此外,该方法还有效地降低了模型参数数量和计算延迟。实验结果表明,ISKD在电池寿命预测中取得了显著的性能提升。


Cite:Kim, Geunsu et al. “Time Series Prediction of Battery EOL with Intermediate State Knowledge Distillation.” 2022 IEEE International Conference on Consumer Electronics-Asia (ICCE-Asia) (2022): 1-4.

未找到源码,我复现的部分代码如下,完整代码请联系:

MBiLSTM模型
class MBiLSTM(nn.Module):
    def __init__(self, input_size=7, hidden_size=64, num_layers=2):
        super(MBiLSTM, self).__init__()
        self.mlp = nn.Linear(input_size, hidden_size)
        self.bilstm = nn.LSTM(hidden_size, hidden_size, num_layers, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(hidden_size * 2, 1)  # *2 because of bidirectional

    def forward(self, x):
        x = self.mlp(x)
        out, (hn, cn) = self.bilstm(x)
        s1 = out[:, :, :out.size(2) // 2]  # Extract S1 (hidden state)
        s2 = cn[-1]  # Extract S2 (last cell state)
        out = self.fc(out[:, -1, :])
        return out, s1, s2  # Return final output and intermediate states
def train_student_with_iskd(student_model, teacher_model, dataloader, epochs, learning_rate, alpha=0.3, beta=0.6, gamma=0.1):
    optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)
    
    teacher_model.eval()  # Freeze teacher model
    for epoch in range(epochs):
        student_model.train()
        for sequences, targets in dataloader:
            sequences = sequences.unsqueeze(-1)  # Reshape for input size
            targets = targets.unsqueeze(-1)
            
            with torch.no_grad():
                teacher_outputs, teacher_s1, teacher_s2 = teacher_model(sequences)
            
            student_outputs, student_s1, student_s2 = student_model(sequences)
            
            # 使用S2作为中间状态进行知识蒸馏
            loss = iskd_loss(student_outputs, targets, teacher_outputs, student_s2, teacher_s2, alpha, beta, gamma)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

机智的小神仙儿

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值