论文及代码地址
Paper: “Graph Neural Network-Based Anomaly Detection in Multivariate Time Series” (AAAI 2021)
Code: https://github.com/d-ailin/GDN
报错
Traceback (most recent call last):
File "/home/wangq/workshop/GDN/main.py", line 317, in <module>
main.run()
File "/home/wangq/workshop/GDN/main.py", line 126, in run
self.train_log = train(self.model, model_save_path,
File "/home/wangq/workshop/GDN/train.py", line 79, in train
out = model(x, edge_index).float().to(device) # out (64,50)
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wangq/workshop/GDN/models/GDN.py", line 170, in forward
gcn_out = self.gnn_layers[i](x, batch_gated_edge_index, node_num=node_num * batch_num,
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wangq/workshop/GDN/models/GDN.py", line 69, in forward
out, (new_edge_index, att_weight) = self.gnn(x, edge_index, embedding, return_attention_weights=True)
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wangq/workshop/GDN/models/graph_layer.py", line 79, in forward
out = self.propagate(edge_index, x=x, embedding=embedding, edges=edge_index,
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py", line 480, in propagate
out = self.aggregate(out, **aggr_kwargs)
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/conv/message_passing.py", line 604, in aggregate
return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/experimental.py", line 115, in wrapper
return func(*args, **kwargs)
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/aggr/base.py", line 133, in __call__
raise e
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/aggr/base.py", line 125, in __call__
return super().__call__(x, index=index, ptr=ptr, dim_size=dim_size,
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/aggr/basic.py", line 22, in forward
return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/nn/aggr/base.py", line 176, in reduce
return scatter(x, index, dim, dim_size, reduce)
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_geometric/utils/scatter.py", line 157, in scatter
return torch_scatter.scatter(src, index, dim, dim_size=dim_size,
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_scatter/scatter.py", line 152, in scatter
return scatter_sum(src, index, dim, out, dim_size)
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_scatter/scatter.py", line 11, in scatter_sum
index = broadcast(index, src, dim)
File "/home/wangq/anaconda3/envs/exp/lib/python3.9/site-packages/torch_scatter/utils.py", line 12, in broadcast
src = src.expand(other.size())
RuntimeError: The expanded size of the tensor (1) must match the existing size (96000) at non-singleton dimension 1. Target sizes: [96000, 1, 32]. Tensor sizes: [1, 96000, 1]
进程已结束,退出代码1
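解决方案
原因：报错中的 Target sizes 为 [96000, 1, 32]，说明 GraphLayer.message 返回的张量是三维的 [边数, heads, out_channels]（这里 heads=1）。新版 PyTorch Geometric 的 MessagePassing 会沿 node_dim 维做 scatter 聚合，而 node_dim 默认值为 -2；对三维张量来说，这正好是大小为 1 的 heads 维。于是 index（长度 96000）无法广播到该维，导致上面的 RuntimeError。
参考讨论：https://github.com/pyg-team/pytorch_geometric/discussions/4839
修改方法：在 models/graph_layer.py 中，修改 GraphLayer 的 message 方法的返回值，去掉大小为 1 的 heads 维，使边维成为倒数第二维（见下方代码）。
注意以下几点：
- 不带参数的 .squeeze() 会去掉所有大小为 1 的维度，因此只在 heads=1 时有效。
- 另一种做法是在 GraphLayer.__init__ 中把 node_dim=0 传给 super().__init__（PyG 自带的 GATConv 就是这样做的），让聚合沿边维进行，而不改变 message 的输出形状。
- 若采用 squeeze 方案，propagate 的输出会少一个 heads 维，建议核对 forward 中随后对 out 做的 view / mean 等操作是否仍符合预期。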
# 原本的 def message 返回值
return x_j * alpha.view(-1, self.heads, 1)
# 修改后的 def message 返回值
return (x_j * alpha.view(-1, self.heads, 1)).squeeze()