Fixing a Runtime Error in the GDN Time-Series Code

Paper and code

Paper: "Graph Neural Network-Based Anomaly Detection in Multivariate Time Series" (AAAI 2021)
Code: https://github.com/d-ailin/GDN

Error

Traceback (most recent call last):
  File "/home/wangq/workshop/GDN/main.py", line 317, in <module>
    main.run()
  File "/home/wangq/workshop/GDN/main.py", line 126, in run
    self.train_log = train(self.model, model_save_path,
  File "/home/wangq/workshop/GDN/train.py", line 79, in train
    out = model(x, edge_index).float().to(device) # out (64,50)
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wangq/workshop/GDN/models/GDN.py", line 170, in forward
    gcn_out = self.gnn_layers[i](x, batch_gated_edge_index, node_num=node_num * batch_num,
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wangq/workshop/GDN/models/GDN.py", line 69, in forward
    out, (new_edge_index, att_weight) = self.gnn(x, edge_index, embedding, return_attention_weights=True)
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wangq/workshop/GDN/models/graph_layer.py", line 79, in forward
    out = self.propagate(edge_index, x=x, embedding=embedding, edges=edge_index,
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py", line 480, in propagate
    out = self.aggregate(out, **aggr_kwargs)
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py", line 604, in aggregate
    return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/experimental.py", line 115, in wrapper
    return func(*args, **kwargs)
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/aggr/base.py", line 133, in __call__
    raise e
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/aggr/base.py", line 125, in __call__
    return super().__call__(x, index=index, ptr=ptr, dim_size=dim_size,
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/aggr/basic.py", line 22, in forward
    return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/aggr/base.py", line 176, in reduce
    return scatter(x, index, dim, dim_size, reduce)
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/utils/scatter.py", line 157, in scatter
    return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_scatter/scatter.py", line 152, in scatter
    return scatter_sum(src, index, dim, out, dim_size)
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_scatter/scatter.py", line 11, in scatter_sum
    index = broadcast(index, src, dim)
  File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_scatter/utils.py", line 12, in broadcast
    src = src.expand(other.size())
RuntimeError: The expanded size of the tensor (1) must match the existing size (96000) at non-singleton dimension 1.  Target sizes: [96000, 1, 32].  Tensor sizes: [1, 96000, 1]

Process finished with exit code 1

Solution

https://github.com/pyg-team/pytorch_geometric/discussions/4839

# Original return statement of message() in models/graph_layer.py
return x_j * alpha.view(-1, self.heads, 1)
# Modified return statement: drop the size-1 heads axis so the result is 2-D
return (x_j * alpha.view(-1, self.heads, 1)).squeeze()
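To see why squeezing helps, the shape check that fails in the traceback can be replayed without PyTorch. The sketch below is a pure-Python imitation of what the `broadcast` step in `torch_scatter/utils.py` appears to do with the index tensor; the helper names `broadcast_index_shape` and `can_expand` are made up for illustration, and the aggregation axis `dim=-2` is an assumption inferred from the tensor sizes in the error message.

```python
def broadcast_index_shape(index_len, src_shape, dim):
    """Shape a 1-D index of length index_len takes before it is expanded to src_shape."""
    ndim = len(src_shape)
    if dim < 0:
        dim += ndim
    # Leading singleton axes place the index on axis `dim` ...
    shape = [1] * dim + [index_len]
    # ... and trailing singleton axes pad it to the rank of src.
    shape += [1] * (ndim - len(shape))
    return shape


def can_expand(shape, target):
    """expand() may only grow size-1 axes; every other axis must already match."""
    return len(shape) == len(target) and all(
        s == t or s == 1 for s, t in zip(shape, target)
    )


# Sizes taken from the error message: 96000 edges, 1 head, 32 channels.
E, heads, channels = 96000, 1, 32

# Before the fix: message() returns a 3-D tensor [E, heads, channels].
before = broadcast_index_shape(E, [E, heads, channels], dim=-2)
print(before, can_expand(before, [E, heads, channels]))

# After the fix: squeeze() leaves a 2-D tensor [E, channels].
after = broadcast_index_shape(E, [E, channels], dim=-2)
print(after, can_expand(after, [E, channels]))
```

With the 3-D message the index ends up shaped `[1, 96000, 1]`, which cannot be expanded to `[96000, 1, 32]`, matching the RuntimeError. With the 2-D message the index becomes `[96000, 1]` and expands cleanly to `[96000, 32]`. Note that this reasoning relies on `self.heads` being 1, as it is in the sizes reported above.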