Anomaly-Transformer (ICLR 2022 Spotlight)复现过程及问题

官方代码:GitHub - thuml/Anomaly-Transformer: About Code release for "Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy" (ICLR 2022 Spotlight), https://openreview.net/forum?id=LzQQ89U1qm_

作者推荐的是 python3.6,pytorch 1.4 

1. 环境修改

尝试安装 pytorch 1.4 运行,但是代码会卡住,并且没有报错。定位错误在:Anomaly-Transformer/model/attn.py

self.distances = torch.zeros((window_size, window_size)).cuda()

.cuda() 卡住:原因是 安装的 pytorch 1.4 对应的CUDA 版本为 10.x,算力是 sm_86,CUDA 10.x 最高支持到 sm_75,因此需要CUDA 11.x来支持sm_8.x


因此升级 我的环境 python3.7,  pytorch 1.12  , 显卡3080Ti, CUDA 版本:11.3

conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

再次运行训练脚本,又报错:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512, 25]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

解决:注释掉Anomaly-Transformer/solver.py 的第一个 .step():

                # Minimax strategy
                loss1.backward(retain_graph=True)
                # self.optimizer.step()
                loss2.backward()
                self.optimizer.step()

参考:Why the optimizer.step() write twice? · Issue #8 · thuml/Anomaly-Transformer · GitHub


2.  恭喜!  成功运行!

python main.py --anormly_ratio 1 --num_epochs 3    --batch_size 128  --mode train --dataset PSM  --data_path dataset/PSM --input_c 25    --output_c 25

------------ Options -------------
anormly_ratio: 1.0
batch_size: 128
data_path: dataset/PSM
dataset: PSM
input_c: 25
k: 3
lr: 0.0001
mode: train
model_save_path: checkpoints
num_epochs: 3
output_c: 25
pretrained_model: None
win_size: 100 

======================TEST MODE======================
/opt/conda/lib/python3.7/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
  warnings.warn(warning.format(ret))
Threshold : 0.002150955616962149
pred:    (87800,)
gt:      (87800,)
pred:  (87800,)
gt:    (87800,)
Accuracy : 0.9848, Precision : 0.9713, Recall : 0.9739, F-score : 0.9726 

论文中的结果:对于PSM数据集

P: 96.91,R: 98.9,  F1: 97.89

复现的 Recall 略低。但是 Precision 略高。二者本就是需要权衡。可以通过调整上面的 Threshold : 0.002150955616962149 平衡二者。

  • 4
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 41
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 41
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

理心炼丹

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值