一篇论文复现的整体思路和复现记录(三,基础实现篇)

工欲善其事,必先利其器,要写出好的MATLAB代码,先从最基础的代码开始做起。

1 ADMM原理

以下内容整理自Standford University的Boyd老师的课件论文
ADMM问题的基本形式:
最优化问题形式包括两组可分离自变量和线性等式约束:
min ⁡ x , z f ( x ) + g ( z ) s . t . A x + B z = c \begin{align*} \min_{x,z}\quad&f(\mathbf{x})+g(\textbf{z})\\ s.t. \quad&\mathbf{Ax}+\mathbf{Bz} = \mathbf{c} \end{align*} x,zmins.t.f(x)+g(z)Ax+Bz=c
写出该问题对应的拉格朗日函数式:
L ρ ( x , z , λ ) = f ( x ) + g ( z ) + y T ( A x + B z − c ) + ρ 2 ∥ A x + B z − c ∥ 2 2 L_{\rho}(\mathbf{x}, \mathbf{z}, \mathbf{\lambda}) =f(\mathbf{x})+g(\textbf{z})+\mathbf{y}^T(\mathbf{Ax}+\mathbf{Bz} - \mathbf{c})+\frac{\rho}{2}{\Vert\mathbf{Ax}+\mathbf{Bz} - \mathbf{c}\Vert^2_2} Lρ(x,z,λ)=f(x)+g(z)+yT(Ax+Bzc)+2ρAx+Bzc22
按如下步骤,按照Gauss-Seidel方法更新迭代:
x k + 1 = a r g m i n x L ρ ( x , z k , y k ) z k + 1 = a r g m i n z L ρ ( x k + 1 , z , y k ) y k + 1 = y k + ρ ( A x k + 1 + B z k + 1 − c ) \begin{align*} x^{k+1} &= \mathop{argmin}\limits_x L_{\rho}(\mathbf{x}, \mathbf{z}^k, \mathbf{y}^k)\\ z^{k+1} &= \mathop{argmin}\limits_z L_{\rho}(\mathbf{x}^{k+1}, \mathbf{z}, \mathbf{y}^k)\\ y^{k+1} &= y^k+\rho(\mathbf{Ax}^{k+1}+\mathbf{Bz}^{k+1}-c) \end{align*} xk+1zk+1yk+1=xargminLρ(x,zk,yk)=zargminLρ(xk+1,z,yk)=yk+ρ(Axk+1+Bzk+1c)
进一步,如在拉格朗日函数中定义 u k = ( 1 / ρ ) y k \mathbf{u}^k = (1/\rho)\mathbf{y}^k uk=(1/ρ)yk,放缩对偶变量,拉格朗日函数变为:
L ρ ( x , z , λ ) = f ( x ) + g ( z ) + ρ 2 ∥ A x + B z − c + u ∥ 2 2 + c o n s t L_{\rho}(\mathbf{x}, \mathbf{z}, \mathbf{\lambda}) =f(\mathbf{x})+g(\textbf{z})+\frac{\rho}{2}{\Vert\mathbf{Ax}+\mathbf{Bz} - \mathbf{c}+\mathbf{u}\Vert^2_2+const} Lρ(x,z,λ)=f(x)+g(z)+2ρAx+Bzc+u22+const
相对应的迭代式子变为:
x k + 1 = a r g m i n x f ( x ) + ρ 2 ∥ A x + B z k − c + u k ∥ 2 2 z k + 1 = a r g m i n z g ( z ) + ρ 2 ∥ A x k + 1 + B z − c + u k ∥ 2 2 u k + 1 = u k + A x k + 1 + B z k + 1 − c \begin{align*} \mathbf{x}^{k+1} &= \mathop{argmin}\limits_x f(\mathbf{x})+ \frac{\rho}{2}{\Vert\mathbf{Ax}+\mathbf{Bz}^k - \mathbf{c}+\mathbf{u}^k\Vert^2_2}\\ \mathbf{z}^{k+1} &= \mathop{argmin}\limits_z g(\mathbf{z})+ \frac{\rho}{2}{\Vert\mathbf{Ax}^{k+1}+\mathbf{Bz} - \mathbf{c}+\mathbf{u}^k\Vert^2_2}\\ \mathbf{u}^{k+1} &= \mathbf{u}^k+\mathbf{Ax}^{k+1}+\mathbf{Bz}^{k+1}-c \end{align*} xk+1zk+1uk+1=xargminf(x)+2ρAx+Bzkc+uk22=zargming(z)+2ρAxk+1+Bzc+uk22=uk+Axk+1+Bzk+1c
到此为止,上述仍然是一些抽象的概念,与实际问题并无什么联系。好在Boyd老师在网站上面挂出来了一些案例,但是这些案例难懂,我费了一些功夫终于看明白其中一二。

2 部分特殊表示

有些符号看得不是很清楚,总结于以下

邻近算子( Proximity Operator \text{Proximity Operator} Proximity Operator

前面已经知道 x x x的更新值为
x + = a r g m i n x f ( x ) + ρ 2 ∥ A x + B z k − c + u k ∥ 2 2 {x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}{\Vert{Ax}+{Bz}^k - {c}+{u}^k\Vert^2_2} x+=xargminf(x)+2ρAx+Bzkc+uk22
v = − B z + c − u v = -Bz+c-u v=Bz+cu
x + = a r g m i n x f ( x ) + ρ 2 ∥ A x − v ∥ 2 2 {x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}{\Vert{Ax}-v\Vert^2_2} x+=xargminf(x)+2ρAxv22
又令 A = I A=I A=I
x + = a r g m i n x f ( x ) + ρ 2 ∥ x − v ∥ 2 2 (*) \begin{equation} {x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}{\Vert{x}-v\Vert^2_2}\end{equation}\tag{*} x+=xargminf(x)+2ρxv22(*)
该式子的右端项可以用 prox f , ρ ( v ) \textbf{prox}_{f,\rho}(v) proxf,ρ(v)来表示,我暂时将其翻译为:函数 f f f带惩罚项 ρ \rho ρ的邻近算子。所以,后面有一些式子会用式子(*)来简略表示。

向集合的投影( Projection \text{Projection} Projection

当函数 f f f足够简单, x x x的更新值成为前面所说到的邻近算子的形式,可以用解析的办法分析。其中一个例子是,假如说 f f f是一个非空闭凸集的指示函数( indicator function \text{indicator function} indicator function),那么也可以将x的更新值表示为:
x + = a r g m i n x f ( x ) + ρ 2 ∥ x − v ∥ 2 2 = Π C ( v ) {x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}{\Vert{x}-v\Vert^2_2}=\Pi_\mathcal{C}(v) x+=xargminf(x)+2ρxv22=ΠC(v)
其中, Π C \Pi_\mathcal{C} ΠC表示向 C \mathcal{C} C上的欧式范数的投影.

软阈值( Soft Thresholding \text{Soft Thresholding} Soft Thresholding

感谢前辈的文章12,基本了解软阈值的情况。
软阈值问题的形式为:
S κ ( a ) = { a − κ a > κ 0 ∣ a ∣ ≤ κ a + κ a < − κ = ( a − κ ) + − ( − a − κ ) + S_\kappa(a) =\begin{cases} a-\kappa&a>\kappa\\ 0&|a|\leq\kappa\\ a+\kappa&a<-\kappa \end{cases}=(a-\kappa)_+-(-a-\kappa)_+ Sκ(a)= aκ0a+κa>κaκa<κ=(aκ)+(aκ)+
是在求解优化问题形如:
a r g m i n x ∥ x − B ∥ 2 2 + λ ∥ x ∥ 1 \mathop{argmin}\limits_x\Vert x-B\Vert^2_2+\lambda\Vert x\Vert_1 xargminxB22+λx1
的时候作用的,用于标记 x x x更新值的取值。其中 B = [ b 1 , b 2 , … , b n ] B = [b_1, b_2, \dots, b_n] B=[b1,b2,,bn]。由于 ∥ x ∥ 1 \Vert x\Vert_1 x1并不可微,所以通过分类讨论的结果:
S λ / 2 ( b ) = { b − λ 2 b > λ 2 0 ∣ b ∣ < λ 2 b + λ 2 b < − λ 2 = ( b − λ 2 ) + − ( − b − λ 2 ) + S_{\lambda/2}(b) =\begin{cases} b-\frac{\lambda}{2}&b>\frac{\lambda}{2}\\ 0&|b|<\frac{\lambda}{2}\\ b+\frac{\lambda}{2}&b<-\frac{\lambda}{2} \end{cases}=(b-\frac{\lambda}{2})_+-(-b-\frac{\lambda}{2})_+ Sλ/2(b)= b2λ0b+2λb>2λb<2λb<2λ=(b2λ)+(b2λ)+

例如,在Boyd老师的论文中就提到, x i x_i xi更新值:
x i + = a r g m i n x i ( λ ∣ x i ∣ + ( ρ / 2 ) ( x i − v i ) 2 ) . x_i^+=\mathop{argmin}\limits_{x_i}(\lambda|x_i|+(\rho/2)(x_i-v_i)^2). xi+=xiargmin(λxi+(ρ/2)(xivi)2).
可以得到相应更新值:
x i + : = S λ / ρ ( v i ) x_i^+:=S_{\lambda/\rho}(v_i) xi+:=Sλ/ρ(vi)

3 部分代码解析

以下选择Lasso问题的代码进行分析。代码摘自网站,此处只是复刻编程的流程,整理一下思路。

3.1 代码架构

代码分为lasso.m, objective.m, shrinkage.m, factor.m等,在运行中实际起到作用分别是:

  1. function lasso.m是整个ADMM的算法执行流程,包括数据记录,数据迭代,数据输出,迭代终点判断等内容;
  2. objective.m是整个优化函数的目标函数;
  3. shrinkage.m表示整个优化函数的目标函数;
  4. factor.m则是根据A的形状的不同,进行的分解。
  5. 通过实际案例检验所写的代码是否正确反映了算法。

3.2 代码解构

3.2.1 lasso.m

lasso问题的目标函数为:
m i n i m i z e 1 2 ∥ A x − b ∥ 2 2 + λ ∥ x ∥ 1 minimize \quad \frac{1}{2}\Vert Ax-b\Vert^2_2+\lambda\Vert x\Vert_1 minimize21Axb22+λx1
写成ADMM方法所能够求解的格式:
m i n i m i z e 1 2 ∥ A x − b ∥ 2 2 + λ ∥ z ∥ 1 subject to x − z = 0 \begin{align*} minimize\quad &\frac{1}{2}\Vert Ax-b\Vert^2_2+\lambda\Vert z\Vert_1\\ \text{subject to}\quad &x-z = 0 \end{align*} minimizesubject to21Axb22+λz1xz=0
迭代的表达式为:
x k + 1 : = ( A T A + ρ I ) − 1 ( A T b + ρ ( z k − u k ) ) z k + 1 : = S λ / ρ ( x k + 1 + u k ) u k + 1 : = u k + x k + 1 − z k + 1 \begin{align*} x^{k+1}&:=(A^TA+\rho I)^{-1}(A^Tb+\rho(z^k-u^k))\\ z^{k+1}&:=S_{\lambda/\rho}(x^{k+1}+u^k)\\ u^{k+1}&:=u^k+x^{k+1}-z^{k+1} \end{align*} xk+1zk+1uk+1:=(ATA+ρI)1(ATb+ρ(zkuk)):=Sλ/ρ(xk+1+uk):=uk+xk+1zk+1

function [z, history] = lasso(A, b, lambda, rho, alpha)
% lasso  Solve lasso problem via ADMM
% [z, history] = lasso(A, b, lambda, rho, alpha);
% Solves the following problem via ADMM:
%   minimize 1/2*|| Ax - b ||_2^2 + \lambda || x ||_1
% The solution is returned in the vector x.
% history is a structure that contains the objective value, the primal and
% dual residual norms, and the tolerances for the primal and dual residual
% norms at each iteration.
% rho is the augmented Lagrangian parameter.
% alpha is the over-relaxation parameter (typical values for alpha are
% between 1.0 and 1.8).
% More information can be found in the paper linked at
%:http://www.stanford.edu/~boyd/papers/distr_opt_stat_learning_admm.html
%

t_start = tic;
Global constants and defaults
QUIET    = 0;
MAX_ITER = 1000;
ABSTOL   = 1e-4; 
RELTOL   = 1e-2;
Data preprocessing
[m, n] = size(A);

% save a matrix-vector multiply
Atb = A'*b;
ADMM solver
x = zeros(n,1);
z = zeros(n,1);
u = zeros(n,1);

% cache the factorization
[L U] = factor(A, rho);

if ~QUIET
    fprintf('%3s\t%10s\t%10s\t%10s\t%10s\t%10s\n', 'iter', ...
      'r norm', 'eps pri', 's norm', 'eps dual', 'objective');
end

for k = 1:MAX_ITER

    % x-update
    q = Atb + rho*(z - u);    % temporary value
    if( m >= n )    % if skinny
       x = U \ (L \ q);
    else            % if fat
       x = q/rho - (A'*(U \ ( L \ (A*q) )))/rho^2;
    end

    % z-update with relaxation
    zold = z;
    x_hat = alpha*x + (1 - alpha)*zold;
    z = shrinkage(x_hat + u, lambda/rho);

    % u-update
    u = u + (x_hat - z);

    % diagnostics, reporting, termination checks
    history.objval(k)  = objective(A, b, lambda, x, z);

    history.r_norm(k)  = norm(x - z);
    history.s_norm(k)  = norm(-rho*(z - zold));

    history.eps_pri(k) = sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z));
    history.eps_dual(k)= sqrt(n)*ABSTOL + RELTOL*norm(rho*u);

    if ~QUIET
        fprintf('%3d\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.2f\n', k, ...
            history.r_norm(k), history.eps_pri(k), ...
            history.s_norm(k), history.eps_dual(k), history.objval(k));
    end

    if (history.r_norm(k) < history.eps_pri(k) && ...
       history.s_norm(k) < history.eps_dual(k))
         break;
    end

end

if ~QUIET
    toc(t_start);
end
end

3.2.2 objective.m

function p = objective(A, b, lambda, x, z)
    p = ( 1/2*sum((A*x - b).^2) + lambda*norm(z,1) );
end

3.2.3 shrinkage.m

此处表达的是式子:
z k + 1 : = S λ / ρ ( x k + 1 + u k ) = ( x k + 1 + u k − λ / ρ ) + − ( − x k + 1 − u k − λ / ρ ) z^{k+1}:=S_{\lambda/\rho}(x^{k+1}+u^k)=(x^{k+1}+u^k-\lambda/\rho)_+-(-x^{k+1}-u^k-\lambda/\rho) zk+1:=Sλ/ρ(xk+1+uk)=(xk+1+ukλ/ρ)+(xk+1ukλ/ρ)

function z = shrinkage(x, kappa)
    z = max( 0, x - kappa ) - max( 0, -x - kappa );
end

3.2.4 factor.m

A A A是一个 m × n m\times n m×n的矩阵。当 m < n m<n m<n,由于在 x x x的更新式中, A T A + ρ I A^TA+\rho I ATA+ρI是一个 n × n n\times n n×n的矩阵,而提一个 1 / ρ 1/\rho 1/ρ后,变为 I + ( 1 / ρ ) A A T I+(1/\rho)AA^T I+(1/ρ)AAT,是一个 m × m m\times m m×m大小的矩阵。由于矩阵求逆的复杂度为 O ( n 3 ) O(n^3) O(n3),因此,后面这种做法更便于求解。结合稀疏矢量技术完成文中代码。

function [L U] = factor(A, rho)
    [m, n] = size(A);
    if ( m >= n )    % if skinny
       L = chol( A'*A + rho*speye(n), 'lower' );
    else            % if fat
       L = chol( speye(m) + 1/rho*(A*A'), 'lower' );
    end

    % force matlab to recognize the upper / lower triangular structure
    L = sparse(L);
    U = sparse(L');
end

3.2.5 难点和实施

难点主要在于要矩阵表达式的运算和推导,容易出错,因此给复现带来了困难。实际案例(Example)见于链接,为上面函数的应用。

<think>好的,我需要帮助用户找到关于文本分类论文复现方法或代码实现,特别是使用PyTorch或TensorFlow的。首先,我需要整理用户提供的引用内容,看看是否有相关的信息。 根据引用[4],用户已经在SST-2MR数据集上测试了复现效果,使用了一些模型如CNN-static、CNN-rand、CNN-fine-tunedBert-base-cased,但结果与原论文有差距。这说明用户可能在复现过程中遇到了困难,需要更详细的实现方法或代码参考。 引用[2]提到用户将TensorFlow代码转为PyTorch时遇到了挑战,耗时一个月,这表明用户可能需要一些转换技巧或工具。引用[3]建议检查模型是否有现有的实现,可以查阅相关表格,这可能涉及到Hugging Face的库,比如Transformers,里面有很多预训练模型的实现。 引用[1]对比了PyTorchTensorFlow在Reddit上的讨论,可能涉及两者的优缺点,但用户的主要需求是复现方法,所以可能需要推荐两者中的最佳实践。 接下来,我应该从以下几个方面回答用户的问题: 1. **论文复现的基本步骤**:帮助用户系统性地了解复现流程,比如理解论文、选择框架、数据预处理等。 2. **PyTorchTensorFlow的代码实现资源**:提供具体的代码库或示例,如Hugging Face Transformers、TensorFlow官方模型等。 3. **模型转换技巧**:根据用户将TF转PyTorch的经验,给出转换中的常见问题解决方法,比如参数名称对齐、动态图与静态图的差异。 4. **调试与优化建议**:针对引用[4]中复现结果不理想的情况,提供调试策略,如检查超参数、数据增强、模型结构等。 5. **相关工具社区资源**:推荐有用的工具如ONNX、模型转换脚本,以及论坛如GitHub、Stack Overflow等。 需要确保回答结构清晰,分点说明,并且引用提供的参考文献来支持建议。同时,生成的相关问题要围绕用户可能的后续疑问,比如复现技巧、框架对比、提升复现效果的方法等。</think>### 文本分类论文复现方法与代码实现 针对文本分类论文复现,需结合论文细节、框架选择(如PyTorch或TensorFlow)代码实践。以下是关键步骤与资源推荐: --- #### 1. **复现基本步骤** - **理解论文核心**:明确模型结构(如CNN、RNN、BERT)、数据处理方法(如词嵌入、分词方式)训练策略(如学习率、正则化)。 - **选择框架**:根据团队熟悉度选择PyTorch(动态图调试灵活)或TensorFlow(静态图生产部署友好)[^1][^2]。 - **数据预处理**: - 使用标准化工具(如Hugging Face Tokenizers)处理文本,对齐论文的数据划分增强方法[^3]。 - 示例代码(PyTorch): ```python from transformers import BertTokenizer tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") inputs = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt") ``` - **模型实现**: - 复现模型时,优先参考官方代码(如有)或社区实现(如Hugging Face、TensorFlow Models)。 - 若需从TensorFlow转为PyTorch,需注意层名称对齐(如`Conv1D` vs `nn.Conv1d`)参数初始化差异[^2][^4]。 --- #### 2. **代码资源推荐** - **PyTorch**: - Hugging Face Transformers库:提供BERT、CNN等预训练模型,可直接用于文本分类。 ```python from transformers import BertForSequenceClassification model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2) ``` - [Text Classification Examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification):官方示例涵盖训练、评估流程。 - **TensorFlow**: - TensorFlow官方模型库(如[TensorFlow Models](https://github.com/tensorflow/models))中的文本分类实现。 - [TensorFlow Text Classification Tutorial](https://www.tensorflow.org/tutorials/keras/text_classification):基于Keras的实战教程。 --- #### 3. **复现优化与调试** - **超参数调整**:若复现效果低于论文(如引用[4]中SST-2差距1-2%),需检查学习率、批次大小、优化器(如AdamW vs SGD)是否与原文一致。 - **模型细节**: - 词嵌入初始化:CNN-rand(随机初始化)CNN-static(预训练词向量)性能差异显著。 - 微调策略:BERT模型需分层调整学习率,避免全网络粗暴微调。 - **实验记录**:使用工具(如Weights & Biases)记录超参数结果,便于对比分析。 --- #### 4. **工具与社区支持** - **模型转换工具**: - [ONNX](https://onnx.ai/):支持跨框架模型转换(如TF→PyTorch)。 - [torchfy](https://github.com/antonio-f/torchfy):部分TensorFlow代码转PyTorch的实用脚本。 - **社区求助**: - GitHub Issues:在相关仓库中提问(如Hugging Face社区)。 - Stack Overflow:使用标签[pytorch]、[tensorflow]、[text-classification]。 --- ###
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值