“class GATNet(nn.Module):
def __init__(self, node_num=83, time_steps=14):
super().__init__()
self.node_num = node_num
self.time_steps = time_steps
# 修正后的GATv2配置
self.gat = GATv2Conv(
in_channels=1,
out_channels=32,
heads=3,
add_self_loops=True,
)
# 时间特征提取
self.temporal = nn.GRU(
input_size=32 * 3,
hidden_size=64,
num_layers=1,
batch_first=True
)
# 特征融合
self.fusion = nn.Sequential(
nn.Linear(64 + 32 * 3, 128),
nn.LayerNorm(128),
nn.ReLU()
)
# 输出层
self.output = nn.Sequential(
nn.Linear(128, 64),
nn.Dropout(0.2),
nn.Linear(64, 1)
)
def forward(self, x, adj_matrix):
"""
输入:
x: [batch_size, time_steps, node_num] (64,14,83)
adj_matrix: [node_num, node_num] (83,83)
"""
device = next(self.parameters()).device
batch_size = x.size(0)
# 邻接矩阵预处理
adj = adj_matrix.to(device)
edge_index, edge_attr = dense_to_sparse(adj)
edge_attr = edge_attr.unsqueeze(-1).to(device)
# 空间特征提取
spatial_features = []
for t in range(self.time_steps):
x_t = x[:, t].unsqueeze(-1).to(device) # [64,83,1]
x_flat = x_t.reshape(-1, 1)
gat_out = self.gat(x_flat, edge_index, edge_attr)
spatial_features.append(gat_out.view(batch_size, self.node_num, -1))
# 时间特征提取
temporal_input = torch.stack(spatial_features, dim=1) # [64,14,83,96]
temporal_out, _ = self.temporal(temporal_input.view(batch_size * self.node_num, self.time_steps, -1))
# 特征融合
last_spatial = spatial_features[-1]
last_temporal = temporal_out[:, -1].view(batch_size, self.node_num, -1)
fused = torch.cat([last_spatial, last_temporal], dim=-1)
fused = self.fusion(fused)
output = self.output(fused).squeeze(-1)
return output.unsqueeze(1)”这段代码报错:“gat_out = self.gat(x_flat, edge_index, edge_attr)
File "E:\deeplearning\project\BiLSTM-CRF-NER-main\venv\lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "E:\deeplearning\project\BiLSTM-CRF-NER-main\venv\lib\site-packages\torch_geometric\nn\conv\gatv2_conv.py", line 149, in forward
edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
File "E:\deeplearning\project\BiLSTM-CRF-NER-main\venv\lib\site-packages\torch_geometric\utils\loop.py", line 84, in add_self_loops
loop_index = torch.arange(0, N, dtype=torch.long, device=edge_index.device)
TypeError: arange() received an invalid combination of arguments - got (int, Tensor, device=torch.device, dtype=torch.dtype), but expected one of:
* (Number end, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
*”请帮我修改一下
最新发布