Training and Testing RGNN on the Public EEG SEED Dataset

All the posts below are part of my personal exploration of EEG. The project code here is an early, incomplete version; if you need the complete project code and materials, please message me privately.


Main contents:
1. In the EEG project, process EEG signals with graph neural networks: a baseline GCN architecture, an RGNN architecture reproducing the baseline paper, an attention-based graph architecture, a Transformer graph architecture, and an efficiency-oriented "simple" graph architecture, with experiments and comparisons.
2. Collected learning materials on graph neural networks; taken together, this is a complete project for learning GNNs.



Datasets:
1. EEG Project Exploration and Implementation (Part 1): Dataset Selection and an Introduction to SEED
Related paper reading and analysis:
1. The baseline paper by the authors of the EEG SEED dataset: reading and analysis
2. GNN-for-EEG paper reading and analysis: "EEG-Based Emotion Recognition Using Regularized Graph Neural Networks"
3. EEG-GNN paper reading and analysis: "EEG Emotion Recognition Using Dynamical Graph Convolutional Neural Networks"
4. Paper reading and analysis: "Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification"
5. Paper reading and analysis: "DeepGCNs: Can GCNs Go as Deep as CNNs?"
6. Paper reading and analysis: "How Attentive are Graph Attention Networks?"
7. Paper reading and analysis: "Simplifying Graph Convolutional Networks"
8. Paper reading and analysis: "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation"
9. A summary and overview of graph neural networks
Related experiments and code:
1. EEG data processing for graph neural networks
2. Training and testing GCN on the public EEG SEED dataset
3. Training and testing GAT on the public EEG SEED dataset
4. Training and testing SGC on the SEED dataset
5. Training and testing a Transformer on the public EEG SEED dataset
6. Training and testing RGNN on the public EEG SEED dataset
Supplementary learning materials:
1. Three simple Graph examples from the official docs, illustrating three levels of application
2. Learning graph neural networks with the PPI dataset example project
3. A detailed look at data processing in the PyTorch Geometric library
4. NetworkX dicts of dicts and the Seven Bridges of Königsberg problem
5. PyTorch Geometric source reading and analysis: the MessagePassing class explained and used
6. Learning graph neural networks with the Cora dataset example project
7. Graph aggregation
8. Learning graph neural networks with the QM9 dataset example project
9. Open-source libraries for graph processing

Part of the code is shown below.

Reproduced paper: "EEG-Based Emotion Recognition Using Regularized Graph Neural Networks"

For a paper walkthrough, see: GNN-for-EEG paper reading and analysis: "EEG-Based Emotion Recognition Using Regularized Graph Neural Networks" (KPer_Yang's CSDN blog)

Note: this code builds on the code open-sourced by the paper's authors, which contains a number of ambiguities and errors; this version fills in and corrects them, specifically:

(1) Runtime error: inv_mask = 1 - mask is invalid for a boolean mask; corrected to inv_mask = ~mask.

(2) The computation of edge_weight is not given. The paper uses the physical distance between electrodes, but the released code omits it. Since edge_weight is learnable anyway, this version defaults to equal weights (or no weights at all); see the sketch right after this list.

(3) The single convolution layer conv1 is replaced with a stack of convolution layers for feature extraction, which improves accuracy: self.conv_s = torch.nn.ModuleList().

(4) The final linear layer raised a matrix shape-mismatch error; corrected to self.fc = nn.Linear(hidden_channels, out_channels).

(5) Removed self.domain_classifier: it plays no role in this algorithm and only adds code complexity.

(6) The paper's authors provide no training and testing code; this version adds it.

(7) Rewrote everything in a single, more standardized style of my own.
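
As a minimal sketch of correction (2): one simple default is a fully connected graph over SEED's 62 electrodes with equal initial weights (the build_equal_edges helper below is hypothetical, not part of the original project; since edge_weight is learnable, the initialization only sets a starting point):

import torch

def build_equal_edges(num_nodes: int = 62):
    # fully connected edge_index (including self-loops), one column per node pair
    row = torch.arange(num_nodes).repeat_interleave(num_nodes)
    col = torch.arange(num_nodes).repeat(num_nodes)
    edge_index = torch.stack([row, col], dim=0)      # shape: (2, num_nodes ** 2)
    edge_weight = torch.ones(num_nodes * num_nodes)  # equal initial weights
    return edge_index, edge_weight

If electrode coordinates were available, the paper's physical-distance scheme could replace torch.ones with a decreasing function of inter-electrode distance.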

# -*- coding: utf-8 -*-
#
# Copyright (C) 2022 Emperor_Yang, Inc. All Rights Reserved
#
# @CreateTime    : 2023/2/20 21:51
# @Author        : Emperor_Yang
# @File          : EEG_RGNN.py
# @Software      : PyCharm


from abc import ABCMeta
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SGConv, global_add_pool
from torch_scatter import scatter_add
from easydict import EasyDict
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated
from data_process.seed_loader_gnn_memory import SeedGnnMemoryDataset

config = EasyDict()
config.learn_rate = 0.01
config.epoch = 20
config.node_feature_dim = 5      # DE features over 5 frequency bands per electrode
config.node_num = 62             # SEED uses a 62-channel electrode layout
config.hidden_channels = 16
config.class_num = 3             # negative / neutral / positive emotion
config.hidden_layers = 2
config.batch_size = 16
config.max_loss_increase_time = 3
config.learn_edge_weight = True  # treat edge weights as learnable parameters
config.K = 5                     # propagation steps per graph conv layer


def maybe_num_nodes(index, num_nodes=None):
    return index.max().item() + 1 if num_nodes is None else num_nodes


def add_remaining_self_loops(edge_index,
                             edge_weight=None,
                             fill_value=1,
                             num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    row, col = edge_index

    mask = row != col  # True for edges that are not self-loops
    inv_mask = ~mask   # True for existing self-loops (see correction (1))
    loop_weight = torch.full(
        (num_nodes,),
        fill_value,
        dtype=None if edge_weight is None else edge_weight.dtype,
        device=edge_index.device)

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        remaining_edge_weight = edge_weight[inv_mask]
        if remaining_edge_weight.numel() > 0:
            loop_weight[row[inv_mask]] = remaining_edge_weight
        edge_weight = torch.cat([edge_weight[mask], loop_weight], dim=0)

    loop_index = torch.arange(0, num_nodes, dtype=row.dtype, device=row.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)  # shape (2, num_nodes): one self-loop per node
    edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1)

    return edge_index, edge_weight
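# For intuition (illustrative toy values only): a 2-node graph with the single
# edge 0 -> 1 of weight 0.5 gains the missing self-loops 0 -> 0 and 1 -> 1 at
# the fill weight 1.0:
#   ei, ew = add_remaining_self_loops(torch.tensor([[0], [1]]), torch.tensor([0.5]))
#   # ei == [[0, 0, 1], [1, 0, 1]],  ew == [0.5, 1.0, 1.0]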


class RSGConv(SGConv, metaclass=ABCMeta):
    def __init__(self, num_features, out_channels, K=1, cached=False, bias=True):
        super(RSGConv, self).__init__(num_features, out_channels, K=K, cached=cached, bias=bias)

    # allow negative edge weights
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),),
                                     dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        row, col = edge_index
        # RGNN lets learned edge weights go negative, so degrees are built from
        # |w| to keep the symmetric normalization D^(-1/2) A D^(-1/2) well-defined
        deg = scatter_add(torch.abs(edge_weight), row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        # caching is effectively disabled here (cached=False), so the normalized
        # adjacency is recomputed and the K propagation steps rerun on every call
        if not self.cached or self.cached_result is None:
            edge_index, norm = RSGConv.norm(
                edge_index, x.size(0), edge_weight, dtype=x.dtype)

            for k in range(self.K):
                x = self.propagate(edge_index, x=x, norm=norm)
        return self.lin(x)

    def message(self, x_j, norm):
        # x_j:  (num_edges, num_features), features of the source node of each edge
        # norm: (num_edges,), one normalized weight per edge
        # broadcast the per-edge scalar across the feature dimension
        return norm.view(-1, 1) * x_j
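# In short, RSGConv computes lin(A_hat^K x), with A_hat the symmetrically
# normalized adjacency from norm() above: K propagation steps followed by a
# single linear map, as in SGC, but tolerant of negative learned edge weights.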


class SymSimGCNNet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 edge_weight=None, dropout=0.5):
        """
            edge_weight: initial edge weight matrix, flattened to (num_nodes * num_nodes,)
            dropout: dropout rate before the final linear layer
        """
        super(SymSimGCNNet, self).__init__()
        self.num_nodes = config.node_num
        # only the lower triangle (diagonal included) is stored; forward()
        # rebuilds the full symmetric matrix from it
        self.xs, self.ys = torch.tril_indices(self.num_nodes, self.num_nodes, offset=0)
        if edge_weight is not None:
            edge_weight = edge_weight.reshape(self.num_nodes, self.num_nodes)[
                self.xs, self.ys]  # lower-triangular values (including the diagonal)
            self.edge_weight = nn.Parameter(edge_weight, requires_grad=config.learn_edge_weight)
        else:
            self.edge_weight = None
        self.dropout = dropout

        # correction (3): a stack of conv layers instead of a single conv1
        self.conv_s = torch.nn.ModuleList()
        self.conv_s.append(RSGConv(num_features=in_channels, out_channels=hidden_channels, K=config.K))
        for i in range(config.hidden_layers - 1):
            self.conv_s.append(RSGConv(num_features=hidden_channels, out_channels=hidden_channels, K=config.K))

        # correction (4): hidden_channels -> out_channels
        self.fc = nn.Linear(hidden_channels, out_channels)

    def forward(self, data):
        batch_size = len(data.y)
        x, edge_index = data.x, data.edge_index
        edge_weight = None
        if self.edge_weight is not None:
            # rebuild the full symmetric weight matrix from the stored lower triangle
            edge_weight = torch.zeros((self.num_nodes, self.num_nodes), device=edge_index.device)
            edge_weight[self.xs.to(edge_weight.device), self.ys.to(edge_weight.device)] = self.edge_weight
            edge_weight = edge_weight + edge_weight.transpose(1, 0) - torch.diag(
                edge_weight.diagonal())  # copy values from lower tri to upper tri
            # tile once per graph in the batch (assumes each sample carries a
            # fully connected edge_index in the same node order)
            edge_weight = edge_weight.reshape(-1).repeat(batch_size)

        for conv in self.conv_s:
            x = conv(x, edge_index, edge_weight).relu()
        x = global_add_pool(x, data.batch)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc(x)
        return x
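# Shape walk-through for a batch of B graphs (62 nodes, 5 DE features each):
#   data.x: (B*62, 5) --conv stack--> (B*62, hidden_channels)
#   --global_add_pool--> (B, hidden_channels) --dropout + fc--> (B, class_num)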


model = SymSimGCNNet(config.node_feature_dim, config.hidden_channels, config.class_num)
data_set = SeedGnnMemoryDataset(root='../data/SEED/', processed_file='1_20131027.pt')
# simple 80/20 train/test split within a single session file
train_data_set = data_set[: int(0.8 * data_set.len())]
test_data_set = data_set[int(0.8 * data_set.len()):]
train_data_loader = DataLoader(train_data_set, batch_size=config.batch_size, shuffle=True)
test_data_loader = DataLoader(test_data_set, batch_size=config.batch_size, shuffle=False)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learn_rate)
criterion = torch.nn.CrossEntropyLoss()


def train():
    model.train()
    loss_sum = 0
    data_size = 0
    for mini_batch in train_data_loader:
        if mini_batch.num_graphs == config.batch_size:  # drop the last incomplete batch
            data_size += mini_batch.num_graphs
            optimizer.zero_grad()
            out = model(mini_batch)
            loss = criterion(out, mini_batch.y)
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss already averages over the batch, so weight it by
            # the batch size before averaging over the epoch
            loss_sum += loss.item() * mini_batch.num_graphs
    return loss_sum / data_size


def test():
    model.eval()  # disable dropout for evaluation
    count = 0
    data_size = 0
    with torch.no_grad():
        for mini_batch in test_data_loader:
            if mini_batch.num_graphs == config.batch_size:
                out = model(mini_batch)
                predict = torch.argmax(out, dim=1)
                count += int(predict.eq(mini_batch.y).sum())
                data_size += mini_batch.num_graphs
    print("Test Accuracy:{}%".format(count / data_size * 100))


if __name__ == '__main__':
    loss_increase_time = 0
    last_loss = float('inf')
    for epoch in range(config.epoch):
        avg_loss = train()
        print("epoch:{}, loss:{}".format(epoch + 1, avg_loss))
        if avg_loss > last_loss:
            loss_increase_time += 1
        else:
            last_loss = avg_loss
        # stop early once the loss has risen more than config.max_loss_increase_time times
        if loss_increase_time > config.max_loss_increase_time:
            break
    test()
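
To actually exercise the learnable edge weights from correction (2), the model could instead be constructed with an initial weight matrix, e.g. from the hypothetical build_equal_edges helper sketched earlier (assuming the dataset builds fully connected graphs, as forward() expects):

_, init_w = build_equal_edges(config.node_num)
model = SymSimGCNNet(config.node_feature_dim, config.hidden_channels,
                     config.class_num, edge_weight=init_w)

As written above, the model is created without edge_weight, so RSGConv.norm falls back to all-ones weights: equal, but not learnable.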

References:

1. EEG-GNN paper reading and analysis: "EEG Emotion Recognition Using Dynamical Graph Convolutional Neural Networks" (KPer_Yang's CSDN blog)

2. GNN-for-EEG paper reading and analysis: "EEG-Based Emotion Recognition Using Regularized Graph Neural Networks" (KPer_Yang's CSDN blog)

3. https://github.com/zhongpeixiang/RGNN
