论文翻译:How to Retrain Recommender System A Sequential Meta-Learning Method

论文翻译:How to Retrain Recommender System? A Sequential Meta-Learning Method

一、Abstract

实际的推荐系统需要周期性地进行再训练,以获得新的交互数据来更新模型。为了追求较高的模型逼真度,通常需要根据历史数据和新数据对模型进行再训练,因为它需要同时考虑长期和短期用户偏好。然而,一个完整的模型再训练可能非常耗时,而且内存开销很大,特别是在历史数据规模很大的时候。在本文中,我们研究了推荐系统的模型再培训机制,这是一个具有很高实用价值的课题,但在研究领域中探索相对较少。

我们的第一个想法是,根据历史数据对模型进行再训练是没有必要的,因为之前已经对模型进行过训练。然而,由于新数据的规模较小,包含的关于用户长期偏好的信息也较少,因此常规的对新数据的训练很容易导致过拟合和遗忘问题。为了解决这一困境,我们提出了一种新的训练方法,旨在通过学习转移过去的训练经验,从而在再训练时可以抛弃历史数据。具体来说,我们设计了一个基于神经网络的转换组件,它可以将旧模型转换为适合处理未来推荐问题的新模型。为了更好地学习好这个“转换组件”,我们要优化“未来绩效”。例如,下一个时间段的推荐准确度评估。我们的序列元学习(Sequential Meta-Learning SML)方法提供了一个通用的训练范例,适用于任何分类模型。我们基于矩阵分解进行了SML演示,并在两个真实数据集上进行了实验。实验结果表明,SML不仅能显著提高推荐速度,而且在推荐准确度方面优于全模型的训练,这证明了我们想法的可行性。(代码地址:https: //github.com/zyang1580/SML)

关键词:推荐、模型再训练、元学习

二、Introduction

在信息严重过载的Web2.0时代,推荐系统发挥着越来越重要的作用。推荐系统的关键技术是个性化模型,它根据用户与物品的历史交互信息来估计用户对物品的偏好[14,33]。既然用户不断与系统交互,不断收集新的交互数据,为用户偏好提供最新的依据。因此,使用新的交互数据对模型进行再培训,以提供及时的个性化,避免成为过时的模型[36],这是非常重要的。随着推荐模型的复杂性不断增加,以在线方式对模型进行实时更新在技术上具有挑战性,特别是对于那些表达能力强但计算成本昂贵的深度神经网络[13,26,43]。因此,行业中常见的做法是定期进行模型再培训,例如每天或每周。图1展示了模型再训练的过程。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QdsLeajJ-1599447810077)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200611200041258.png)]

直观地看,历史上的交互更多地反映了用户的长期(如内在)兴趣,而新收集的交互更多地反映了用户的短期偏好。到目前为止,根据数据利用情况,最广泛采用的再训练略有以下三种:

微调(Fine-tuning),仅基于新的交互作用更新模型[35,41]。这种方法在内存和时间上都很有效,因为只处理新数据。但是忽略了包含长期偏好信号的历史数据,容易造成[6]的过拟合和遗忘问题。

基于样本的再训练(Sample-based retraining),即对历史交互进行采样,并将其添加到新的交互中,形成训练数据[6,42]。如果期望采样的交互作用能保持长期的偏好信号,需要对其仔细选择以获得具有代表性的交互作用。在推荐准确度方面,由于采样[42]会造成信息损失,通常比使用全部的历史交互信息表现差。

全部再训练(Full retraining),即根据包括所有历史和新交互的整个数据对模型进行培训。虽然效果最好,但是要花费大量的资源和训练时间。

上述三种策略各有利弊,我们认为一个关键的局限性是,它们缺乏对再训练目标的明确优化。例如,再训练的模型应该在下一个时间阶段表现的更好才行。在实践中,下一个时间段的用户交互是当前模型泛化性能最重要的证据,通常用于模型选择或验证。因此,一种有效的再训练方法应该考虑到这个目标,并制定优化目标的再训练过程,这是一种比手工制作启发式选择数据示例更有原则的方法[6,35,40,42]。

在本文中,我们探讨了推荐模型再训练的中心主题,这是一个在行业推荐系统中具有很高实用价值的主题,但在研究中却很少受到关注。虽然基于全部历史信息的模型再训练方法提供了最高的保真度,但我们认为没有必要这样做。关键的原因是,在之前的训练中,已经训练了历史的交互作用,这意味着模型已经从历史数据中提取出了“知识”。如果有一种方法可以很好地保留知识并将其转移到关于新的交互训练中,我们应该能够保持与基于全部历史信息的模型再训练相同的绩效水平,即使我们在模型再训练中不使用历史数据。此外,如果知识转移器足够“聪明”,能够捕捉更多的模式,比如最近的数据更能反映近期的表现,我们甚至有机会提高基于全部历史信息的再训练模型效果。

最终,我们提出了一种新的再培训方法,主要考虑以下两方面:

(1)构建一个转换表达组件,将以前训练中获得的知识转移到新的交互训练中;(2)优化这个转换组件,以提高近期的推荐性能。

为了实现第一个目标,我们将转换组件设计为卷积神经网络(CNN),它将之前的模型参数作为常数输入,并将目前的模型作为可训练的参数输入。其合理性在于将以往训练中获得的知识浓缩在模型参数中,这样表达性神经网络就能够将知识提取出来,从而达到预期的目的。

为了达到第二个目标,除了对新收集到的交互信息进行正常的训练外,我们还对转换器 CNN进行了下一时间段的未来交互信息的训练。因此,CNN可以学习如何结合旧的参数和现在的参数,以预测不久之后的用户交互。

整个架构可以被视为一个元学习的示例[9]:每个时间段的训练是一项任务,将当前时间段的交互当做训练集,将下一个时间段的未来交互当做测试集,通过学习训练的历史任务,我们期望在测试集上表现更好。由于我们的元学习机制是在序列数据上操作的,我们将其命名为序列元学习((Sequential Meta-Learning SML)。

本文贡献:

  1. 强调了推荐系统再训练研究的重要性,并将推荐系统再训练过程作为一个可优化问题进行阐述;
  2. 我们提出了一个新的再训练方法:仅训练新交互却很高效;能够优化未来的推荐性能。
  3. 我们在两个真实数据集(Adressa news and Yelp business)上进行了实验,结果很好。

三、问题建模

在真实的推荐系统中,用户交互数据流是连续的。为了用最近的数据保持预测模型的新鲜度,一个通常的选择是定期对模型进行再训练。我们将数据表示为
D 0 , … D t D t + 1 {D_0,…D_t D_{t+1}} D0DtDt+1
其中Dt表示t时间段内新收集的数据。假设每次的再训练都是在收集Dt之后触发的。根据系统需求和实现能力,周期可以是任意长度的时间,例如,每天,每周,或者直到收集到一定数量的交互。

在t时间段的再训练中,系统可以访问之前所有的数据,即 {D0,…,Dt−1},以及新数据Dt。由于再训练的模型是为近期服务的,基于Dt+1——下一时间段收集的数据来判断其有效性是合理的。因此,我们将Dt+1上的推荐性能作为t期再训练的泛化目标。设t期再训练后的模型参数为Wt。我们将每一项再训练视为一项任务,并将其制定为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sXKCJuPg-1599447810081)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200611210828042.png)]

即,基于再训练时所有可访问的数据,再加上之前的再训练模型参数,进而得到一个新的模型参数,能够对近期数据Dt+1表现很好。工业上常用的方法是将Wt-1作为初始化参数进行全部的再训练。该方案实现简单,但是时间和资源消耗大,并且随着时间增加,计算需求更大。另一个问题是缺少优化,而这并不容易解决,因为直接使用Dt+1将导致信息泄露并使得泛化能力变差。

在本文中,我们的目标是利用新收集到的数据Dt加上先前的模型参数Wt-1,去获得一个好的再训练模型作为Dt+1的评估。因此我们将再训练模型重新制定为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qNT5Pwt7-1599447810085)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200611212721413.png)]

我们将其定义为任务τt。对于任务τ0,其之前的模型参数是随机初始化的。一个直接的解决方案是使用SGD去更新Dt和Wt-1,但是这很容易出现用户长期兴趣的遗忘问题。因为更新的次数越多,初始化的影响就越小。另外,该方案也缺少面向Dt+1的优化策略。

与标准元学习中对任务的定义不同[9,21],这里的任务自然的形成了一个序列:
τ 0 , . . . , τ t , τ t + 1 , . . . {τ_0, ...,τ_t,τ_{t+1}, ...} τ0,...,τt,τt+1,...
在线上测试中,只有τt被完成,我们才能进入τt+1。因此离线训练也应该遵循类似的序列训练方式,以确保该方法能够线上推广。最后,解决问题可视为一个元学习的例子,因为学习的目标是怎样把任务处理好(带着好的泛化能力去处理未来的任务),与简单的学习Dt上的模型参数相比,这是一个更高水平的问题。

四、Method

首先给出了解决任务τt的模型概述,其核心是设计一个转换组件,将旧模型Wt-1有效地转换为新模型Wt,然后详细阐述了转换组件的设计。接下来,我们讨论如何训练模型使其在当前数据Dt上具有良好的性能,并对未来的数据Dt+1有良好的泛化。最后,我们演示了如何矩阵分解上实例化我们的方法,其中最经典和代表性的一个模型就是协同过滤。

1、Method Overview

我们的目标是解决在方程2中定义的任务τt,它仅利用新数据Dt就能比肩甚至超过全部历史交互再训练的结果。我们认为过去的信息被存储在了参数Wt-1里面,另一个考量是我们的技术应当具有通用性而非仅适应某一个模型。

为此,我们设计了一个通用模型框架,如图2所示

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gafQ5QiA-1599447810089)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200612093944675.png)]

它有以下三个部分:1. Wt-1代表基于过去数据训练的推荐模型 2. W ^ t \hat{W}_t W^t代表需要从当前数据Dt学习的新推荐模型 3. Transfer会结合Wt-1和 W ^ t \hat{W}_t W^t形成新的推荐模型Wt,被用于服务下一个时间阶段的推荐。在第t个时间阶段,Wt-1被设置为常数,再训练过程主要包含以下两步:

  1. 获得 W ^ t \hat{W}_t W^t,其包含了来自Dt的有用信息,这一步可以通过优化标准推荐Loss函数 L r ( W ^ t ∣ D t ) L_r(\hat{W}_t|D_t) Lr(W^tDt)来实现

  2. 获得Wt,它是transfer的输出:

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aRgUryLt-1599447810094)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200612095631367.png)]

    其中, f f f是transfer函数, Θ \Theta Θ是参数,Wt-1, W ^ t \hat{W}_t W^t是它的输入。

在这个框架中,Wt-1, W ^ t \hat{W}_t W^t可以是任何可微的推荐模型,只要它们结构相同(即,有相同的参数数目和语义)。只有转换部分需要仔细设计,这正是接下来要讲的我们的贡献所在。

2、Transfer Design

从功能性上来说,Wt-1, W ^ t \hat{W}_t W^t生成Wt,那么它们三者的shape应当是一致的,加权求和的操作可以很容易满足这点:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vPTqgoTH-1599447810096)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200612100649985.png)]

其中 α \alpha α是权重系数,可以预定义也可以通过学习得到。这个方法很容易解释,因为引入的参数很少,所以也很容易训练。但是表达能力有限,例如其不能说明不同维度的参数间的关系。

为了transfer的表现力,可以使用多层感知机(MLP):

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3WTRcsGd-1599447810098)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200612102840431.png)]

尽管MLP[19]具有普适的逼近效果,但在实际应用中可能难以很好地训练[1,13]。还有一个限制就是它不强调同维度参数间的相互作用,然而同维度参数的交互对于理解参数的演化很重要。举例来说,假设模型为矩阵分解,参数为用户embedding。然后差分 W ^ t \hat{W}_t W^t-Wt-1意味着参数变化可以捕获用户的兴趣转移。并且每一维的product W t − 1 ⊙ W t ^ W_{t-1}⊙\hat{W_t} Wt1Wt^表明在反映用户短期和长期的兴趣时维度的重要性。然而,MLP缺乏有效捕获这些模式的机制。

为此,我们设计的transfer不仅能够强调Wt-1和 W ^ t \hat{W}_t W^t在某个维度之间的关系,还能捕捉其在不同维度间的关系。受CNN在图像处理中捕捉区域特征的启发,我们设计了基于CNN的transfer,这个CNN结构可以在图2的绿盒中被发现,其包含了一个stack层,两个卷积层和一个用于输出的全连接层。

接下来我们详细的介绍CNN的设计。不失一般性的,我们将 W ^ t \hat{W}_t W^t和Wt-1当做行向量,将其表示成 w t − 1 w_{t-1} wt1 w t ^ \hat{w_t} wt^。其原来的形式可以是维度或是张量,这方便我们对合并两个模型进行维度操作。

Stack层:

这一层用于stacks w t − 1 w_{t-1} wt1 w t ^ \hat{w_t} wt^,并将element-wise product作为2维矩阵,它将作为“image”被送入卷积层进行处理。更具体的,我们将其公式化描述为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KXVfySoW-1599447810099)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200612104955463.png)]

w d o t w_{dot} wdot的分子用于捕捉当$w_{t-1} $ 发展到 w t ^ \hat{w_t} wt^时,维度是变大还是变小, w d o t w_{dot} wdot的分母是为了标准化,加一个小量 ϵ = 1 0 − 15 \epsilon=10^{-15} ϵ=1015是为了防止分母等于零。 H 0 H^0 H0的size是3 x d,其中d是 w t − 1 w_{t-1} wt1 w t ^ \hat{w_t} wt^的size。

Convolution层:

H 0 H^0 H0被送入两个级联的卷积层以进一步建模维度关系。因为第二层与第一次形式相似,因此我们重点描述一下第一层。设第一层卷积层有n1个vertical filters,每个都被标记为 F j ϵ R 3 x 1 F_j \epsilon R^{3x1} FjϵR3x1(j=1,…,n1是filters的序号)。 F j F_j Fj H 0 H^0 H0从第一列到最后一列划分开,对每个列向量进行操作:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7kF76GBO-1599447810101)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200613154816465.png)]

其中 H : , m 0 H^0_{:,m} H:,m0 H 0 H^0 H0的第m个列向量,<>表示向量做内积运算, H j , m 1 ϵ R H^1_{j,m}\epsilon R Hj,m1ϵR F j F_j Fj H : , m 0 H^0_{:,m} H:,m0上卷积的结果,GELU是高斯误差线性激活函数[17],其可以被视为ReLU的一个平滑变量,其梯度是负值。

注意,vertical filter Fj可以学习 w t ^ \hat{w_t} wt^ w t − 1 w_{t-1} wt1在相同维度的各种关系。例如,如果这个filter是[-1,1,0],它可以表达 w t ^ \hat{w_t} wt^ w t − 1 w_{t-1} wt1的差分关系;如果filter是[1,1,1],它可以获得 w t ^ \hat{w_t} wt^ w t − 1 w_{t-1} wt1都有很高正值的突出特征。我们使用一维filters而不是二维filters的另一个原因是, w t ^ \hat{w_t} wt^ w t − 1 w_{t-1} wt1的维度order并不是对于所有推荐模型都有意义,例如,如果我们改变了因子分解模型的embedding顺序,其模型的预测结果将不会改变。

第一个卷积层的输出 H 1 H^1 H1 是一个n1 x d的矩阵,她被送入第二层的n2个filters中,其中每个filter的大小是 n1 x 1,所以我们最终获得的分量 H 2 H^2 H2的大小是一个 n2 x d的矩阵。

Full-connected and output layers:

H 2 H^2 H2被送入一个全连接层(FC),以获取不同维度间的关系。我们先将 H 2 H^2 H2平化成一个大小是dn2的向量,然后将其送入全连接层:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kPkOCkF2-1599447810103)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200613161617465.png)]

其中, W f ϵ R ( d n 2 ) X d f W_f\epsilon R^{(dn2)Xd_f} WfϵR(dn2)Xdf b 1 ϵ R d f b_1\epsilon R^{d_f} b1ϵRdf分别是全连接层的权重矩阵和偏置向量。df是全连接层的size。向量z经过线性层的变化,输出新的参数向量 w t w_t wt:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lezeDYDw-1599447810105)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200613162219844.png)]

其中, W 0 ϵ R d f X d W_0 \epsilon R^{d_f X d} W0ϵRdfXd b 2 ϵ R d b_2 \epsilon R^d b2ϵRd分别是线性层的权重和偏置。最终, w t w_t wt变换为 W t W_t Wt,即再训练后的新的模型参数。

总结一下,transfer 的所有可训练参数为: Θ = ( F 1 , F 2 , W f , b 1 , W 0 , b 2 ) \Theta = ({F^1,F^2,W_f,b_1,W_0,b_2}) Θ=(F1,F2,Wf,b1,W0,b2),其中 F 1 ϵ R n 1 X 3 F^1 \epsilon R^{n_1X3} F1ϵRn1X3 F 2 ϵ R n 2 X n 1 F^2 \epsilon R^{n_2Xn_1} F2ϵRn2Xn1分别是第一卷积层和第二卷积层的filter。值得一提的是,我们可以将一个推荐模型的参数分为不同的组,并对每组应用一个单独的transfer网络。例如,矩阵分解模型有两组参数——user embedding和item Embedding,我们可以使用两个传输网络分别应对。(见4.4)

3、Sequential Training

我们现在考虑怎样去训练模型参数,包括transfer对于任务τt的输入 W t ^ \hat{W_t} Wt^,以及在所有人中使用的transfer的参数 Θ \Theta Θ。从功能上讲, Θ \Theta Θ将会从当前数据集Dt中获取推荐知识,而 Θ \Theta Θ则是将先前模型的Wt-1和 W t ^ \hat{W_t} Wt^相结合,以期望获得一个在未来数据集Dt+1上表现很好的transfer输出Wt。既然数据是序列输入,我们也应该按照相同的序列方式进行训练。即,在任务τt前应先解决任务τt-1,算法1显示了训练过程。接下来描述一下如何训练任务τt(3-11行),主要有以下两个步骤:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AbGHJ5y3-1599447810108)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200614163125797.png)]

第一步:学习transfer的输入 W t ^ \hat{W_t} Wt^

一个简单的方案是直接基于Dt的loss去学习。但是所得到的 W t ^ \hat{W_t} Wt^可能不适合作为transfer的输入。因为我们假设Wt-1, W t ^ \hat{W_t} Wt^,Wt在相同的空间中,也就是参数维度一致,其值是相同的scale range。为了解决这个问题,我们提出优化transfer在Dt上的输出,然后反向传播到 W t ^ \hat{W_t} Wt^,具体的loss公式如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OgxKAM6B-1599447810109)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200614164002480.png)]

其中 L 0 ( x ∣ D ) L_0(x|D) L0(xD)是数据Dt上的recommendation loss[15],或是pairwise loss[33],x是推荐模型的参数。 x = f Θ ( W t − 1 , W t ^ ) x=f_{\Theta}(W_{t-1},\hat{W_t}) x=fΘ(Wt1,Wt^)。在优化loss时, Θ \Theta Θ被视为常数不进行更新,所以只更新 W t ^ \hat{W_t} Wt^的参数。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-np8D60Jq-1599447810111)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200614164754887.png)]

在得到这个梯度后,我们可以用SGD或是Adam[22]去优化更新 W t ^ \hat{W_t} Wt^,通过这个方式,我们能够同时达到以下两个效果:1)从Dt中获取知识;2)使 W t ^ \hat{W_t} Wt^成为适合transfer的输入。

第二步:学习transfer的参数 Θ \Theta Θ

因为参数 Θ \Theta Θ在所有任务中共享,所以它可以捕获一些任务不变的模块。例如,哪些参数维度和用户短期兴趣更相关,并应该在组合Wt-1和 W t ^ \hat{W_t} Wt^时予以强调。总的目标是获得一个为下一时间段的推荐量身制作的patterns。因此,我们考虑对下一时间段的周期数据Dt+1进行优化。具体的目标函数如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-A9XbvwAo-1599447810112)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200614170221538.png)]

注意,在第一步中获得的 W t ^ \hat{W_t} Wt^是参数 Θ \Theta Θ的函数。因此,在计算参数 Θ \Theta Θ的梯度时,将导致高阶梯度难以获取。因此,我们遵循一阶算法MAML[9],忽略对梯度影响小但是成本高的高阶梯度。在这一步中,通过将 W t ^ \hat{W_t} Wt^当做常数,我们将参数 Θ \Theta Θ的梯度计算为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xHileJM9-1599447810115)(C:\Users\X-i\AppData\Roaming\Typora\typora-user-images\image-20200614171406074.png)]

将以上两个步骤迭代,直到收敛或是达到最大迭代次数(第四行)。如算法1所示,由于下一个时间段的数据DT+1在训练中不可用,所以在上一个时间段T中不执行参数 Θ \Theta Θ的更新。注意,我们可以在 D t t = 0 T {D_t}^T_{t=0} Dtt=0T上运行多个这样的序列训练。然而我们根据经验发现一个训练足以。

值得一提的是,serving(evaluation)阶段的评估略有不同。算法2显示了我们如何对新收集的数据Dt+1进行测试或是验证的模型评估。首先我们使用它去测试要在t+1阶段服务的模型Wt;然后我们需要用Dt+1更新 Θ \Theta Θ W t + 1 ^ \hat{W_{t+1}} Wt+1^,以获取下一个时间阶段的Wt+1( W t + 1 = f Θ ( W t , W t + 1 ^ ) W_{t+1}=f_{\Theta}(W_t,\hat{W_{t+1}}) Wt+1=fΘ(Wt,Wt+1^))。如3-8行所展示,我们首先迭代更新 Θ \Theta Θ W t ^ \hat{W_t} Wt^,这与训练阶段相同。当满足停止条件时,用更新后的 Θ \Theta Θ去更新 W t + 1 ^ \hat{W_{t+1}} Wt+1^,最终送入 f Θ ( W t , W t + 1 ^ ) f_{\Theta}(W_t,\hat{W_{t+1}}) fΘ(Wt,Wt+1^)来获得Wt+1。

W t + 1 = f Θ ( W t , W t + 1 ^ ) W_{t+1}=f_{\Theta}(W_t,\hat{W_{t+1}}) Wt+1=fΘ(Wt,Wt+1^))。如3-8行所展示,我们首先迭代更新 Θ \Theta Θ W t ^ \hat{W_t} Wt^,这与训练阶段相同。当满足停止条件时,用更新后的 Θ \Theta Θ去更新 W t + 1 ^ \hat{W_{t+1}} Wt+1^,最终送入 f Θ ( W t , W t + 1 ^ ) f_{\Theta}(W_t,\hat{W_{t+1}}) fΘ(Wt,Wt+1^)来获得Wt+1。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
As an AI language model, I am not able to write a script but I can provide you with an overview of what you need to do to prune YOLOv5 by Python. Pruning is a technique used to reduce the size of deep neural networks by removing the weights with the lowest importance. The idea is to keep the most important weights while discarding the rest, reducing the number of parameters and making the model more efficient. Here are the steps you can follow to prune YOLOv5 by Python: 1. Load the pre-trained YOLOv5 model using PyTorch. 2. Define a pruning method, such as magnitude pruning, that will determine which weights to keep and which to discard. 3. Define a pruning scheduler that will determine when to prune the model, for example, after every epoch or after a certain number of iterations. 4. Train the YOLOv5 model on your dataset. 5. After each pruning iteration, retrain the model to fine-tune the remaining weights and improve its accuracy. 6. Repeat steps 3-5 until the desired level of pruning is achieved. To implement these steps, you can use PyTorch's pruning module, which provides functions for different pruning methods and schedulers. You can also refer to the PyTorch documentation and examples for more information on how to implement pruning in your YOLOv5 model. Note that pruning can significantly reduce the size of your model, but it may also affect its accuracy. Therefore, it's important to carefully select the pruning method and schedule and evaluate the performance of the pruned model on your validation set.

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值