Building a GCN+LSTM Network with PyTorch

The script below stacks two GCNConv layers to produce node embeddings, feeds those embeddings to an LSTM as a single sequence, and maps the last LSTM output to a scalar prediction through a linear layer.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # in PyG 2.x, DataLoader lives in torch_geometric.loader

# Define the GCN+LSTM network
class GCN_LSTM(nn.Module):
    def __init__(self, num_node_features, gcn_hidden_channels, lstm_hidden_channels, lstm_layers, output_dim):
        super(GCN_LSTM, self).__init__()
        self.gcn1 = GCNConv(num_node_features, gcn_hidden_channels)
        self.gcn2 = GCNConv(gcn_hidden_channels, gcn_hidden_channels)
        self.lstm = nn.LSTM(gcn_hidden_channels, lstm_hidden_channels, lstm_layers, batch_first=True)
        self.fc = nn.Linear(lstm_hidden_channels, output_dim)

    def forward(self, x, edge_index, batch):
        # GCN part: two graph convolutions with ReLU on the node features
        x = self.gcn1(x, edge_index)
        x = F.relu(x)
        x = self.gcn2(x, edge_index)
        x = F.relu(x)

        # LSTM part: treat the node embeddings of the graph as one sequence
        # (batch size 1, sequence length = number of nodes).
        # `batch` is unused here; it would only be needed for graph-level
        # pooling when several graphs are batched together.
        x = x.unsqueeze(0)
        out, (hn, cn) = self.lstm(x)
        out = out[:, -1, :]  # keep the output of the last time step

        # Fully connected output layer
        out = self.fc(out)
        return out
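
# --- Optional shape check (illustrative sketch, not part of the original script) ---
# The unsqueeze / last-time-step logic in forward() is easy to get wrong, so the
# lines below run a tiny dummy graph through the model and print the output shape.
# All dimensions here are arbitrary and chosen only for the check.
_check_model = GCN_LSTM(num_node_features=3, gcn_hidden_channels=8,
                        lstm_hidden_channels=16, lstm_layers=1, output_dim=1)
_check_x = torch.rand(5, 3)                       # 5 nodes, 3 features each
_check_ei = torch.tensor([[0, 1, 2, 3, 4],
                          [1, 2, 3, 4, 0]], dtype=torch.long)
print(_check_model(_check_x, _check_ei, None).shape)  # torch.Size([1, 1]): one graph, one output value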

# Example data generation
def generate_data():
    # Build a small random graph: 10 nodes connected in a directed ring
    num_nodes = 10
    num_node_features = 5
    x = torch.rand((num_nodes, num_node_features))  # node features
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                               [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], dtype=torch.long)  # edges
    y = torch.tensor([[0.0]])  # graph-level target: float and shaped [1, 1] to match the model output for MSELoss

    data = Data(x=x, edge_index=edge_index, y=y)
    return data

# Create the model, define the loss function and optimizer
num_node_features = 5
gcn_hidden_channels = 16
lstm_hidden_channels = 32
lstm_layers = 1
output_dim = 1
model = GCN_LSTM(num_node_features, gcn_hidden_channels, lstm_hidden_channels, lstm_layers, output_dim)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Generate data
data = generate_data()
dataloader = DataLoader([data], batch_size=1, shuffle=True)

# Train the model
num_epochs = 100
for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

print("Training complete.")
