20240303学习日志:实战分析学习开源项目Sleepgpt

1项目地址

项目地址:GitHub - yuty2009/sleepgpt: Code for our paper SleepGPT: A Sleep Language Model for Sleep Staging and Sleep Disorder Diagnosis

项目目录:

2项目介绍

2.1GPT

项目名称"SleepGPT"中的"GPT"代表"Generative Pre-trained Transformer",这是一类基于深度学习的自然语言处理技术。与ChatGPT相比,虽然两者都可能基于相似的技术架构,但SleepGPT专注于睡眠数据分析,特别是用于睡眠阶段分析和睡眠障碍诊断的应用,而ChatGPT主要用于文本生成和对话应用。具体到SleepGPT项目,未能直接找到明确说明其使用GPT技术的细节,但基于名称,可以推测该项目利用了类似的预训练模型技术进行睡眠数据的分析和处理。

Generative Pre-trained Transformer (GPT) 模型是一种先进的自然语言处理技术,特别在于其预训练和微调的机制。GPT通过在大量文本数据上进行预训练,学习语言的深层次结构和语境之间的关系。之后,它可以通过在特定任务上的微调来优化性能。这种模型能够生成连贯的文本、理解语境、回答问题、翻译语言等,显示出对语言深刻的理解能力。

2.2readme解读

 这个README文件是关于SleepGPT项目的概述。SleepGPT是一个用于睡眠分期和睡眠障碍诊断的睡眠语言模型。项目使用了层叠的变换器解码器模块来处理一整夜的睡眠阶段序列。每一块输入数据用于预测目标阶段块。整个模型结构包括一个本地特征提取器(即SleepGPT模型)、一个全局特征提取的变换器编码器以及一个用于诊断的分类头。底部的图表展示了一个整夜睡眠序列的例子,它被划分成了非重叠的片段,这些片段随后被输入到分层变换器网络中进行模型处理。

上面的图展示的是SleepGPT模型的概述。这是一个用于睡眠分期和睡眠障碍诊断的语言模型架构。它包含了一系列的Transformer解码器块,处理来自连夜睡眠阶段序列的重叠块。每个输入块通过模型预测目标块。模型的主要组件包括:

  • Layer Norm: 层归一化,用于稳定训练过程。
  • Masked Multi-Self-Attention: 被掩蔽的多头自注意力机制,可以处理序列内的各个元素之间的关系。
  • Feed Forward: 前馈网络,模型的一部分,用于处理注意力机制的输出。

模型的输入部分显示了睡眠阶段序列的示例,其中"N1"、"N2"等代表不同的睡眠阶段。这些阶段作为输入,模型将预测下一个阶段,这对于理解睡眠模式和诊断睡眠问题很有帮助。

 2.2.1Transformer解码器块

在Transformer模型中,解码器块是负责生成输出序列的部分。每个解码器块通常包含以下几个主要组件:

  • 多头自注意力机制(Multi-Head Self-Attention):能够让模型同时关注输入序列中的不同位置,这对于理解序列中各元素之间复杂的关系是非常重要的。
  • 层归一化(Layer Normalization):用于稳定深度网络的学习过程,通常在自注意力和前馈网络的输出之后进行。
  • 前馈网络(Feed Forward Network):通常是两层的全连接网络,它跟在每个自注意力层之后,用来进一步处理数据。

在Transformer的架构中,解码器块通常会串联多个,每个块的输出会成为下一个块的输入。通过这种方式,模型能够逐步构建出一个复杂的输出序列,每个步骤都依赖于之前的输出。在SleepGPT这个项目中,这些解码器块被用来处理和预测睡眠数据,实现睡眠阶段的分类和诊断。

在这张图中,解码器块指的是从"Stages & Position Embedded"开始到"Next Stage Prediction"结束的整个流程。这个块首先将输入的睡眠阶段(如W, N1, N2等)嵌入到一个连续的向量空间中,并结合位置信息,这是由"Stages & Position Embedded"所表示的。然后,解码器通过"Masked Multi Self-Attention"层来处理序列,允许模型在生成当前输出时考虑到之前的所有输入。接着是第一个"Layer Norm"层,用于规范化处理过的注意力输出。随后是"Feed Forward"层,它进一步处理数据并生成下一个阶段的预测。最后是"Next Stage Prediction",这是模型输出预测的睡眠阶段的部分。

2.2.2 LayerNorm层

Layer Normalization (LayerNorm) 是一种在神经网络中常用的规范化技术,主要目的是稳定深度网络的学习过程。它通过对每个样本的所有激活输出进行规范化处理,减少了不同初始化和批数据规模对训练过程的影响。

具体来说,LayerNorm 是这样实现的:

  1. 对于给定的输入,计算其均值和方差。
  2. 使用这些统计数据来规范化输入值,确保输出值有着零均值和单位方差。
  3. 最后,对规范化的值进行缩放和平移,这两个参数是模型学习得到的。

LayerNorm 通常在多头自注意力层和前馈网络的输出之后应用,在Transformer模型中起到了非常关键的作用。

0均值(或零均值)意味着数据经过规范化处理后,它的平均值变为0。单位方差则意味着数据的方差调整为1。这种处理使得数据在不同的特征或者层中具有相似的分布,从而有助于神经网络更快地学习和收敛。在LayerNorm应用中,这通常是通过从每个元素中减去均值并除以标准差(方差的平方根),然后乘以一个学习率的缩放因子(gamma)并加上一个学习率的位移因子(beta),完成规范化处理。

进行这样的规范化处理的目的是为了保证网络中各层的输入分布保持一致,这可以帮助改善网络的训练速度和稳定性,减少训练过程中的内部协变量偏移(internal covariate shift),即保证网络每层输入的分布稳定,不会因为上层参数的更新而产生大的波动。这样有助于加速收敛,还可以允许使用更高的学习率,使得训练过程更加稳定。

2.2.3内部协变量偏移

内部协变量偏移(Internal Covariate Shift)是指在深度网络的训练过程中,每一层的输入分布随着上一层参数的更新而改变的现象。由于每层的输入分布在训练过程中不断变化,网络需要不断适应这些变化,这会减慢训练速度,增加获取模型稳定状态的难度。规范化技术,如批量规范化(Batch Normalization)和层规范化(Layer Normalization),通过规范化层的激活输出来减少这种内部分布的变化,从而有助于加快训练过程并提高模型的稳定性。

区分一下规范化和激活函数:

Sigmoid激活函数并不是为了解决内部协变量偏移问题而设计的。Sigmoid激活函数主要是将输入压缩到(0,1)的范围内,它通常用于二分类问题中,在神经网络的早期被广泛使用。然而,Sigmoid函数容易导致梯度消失的问题,这是因为当输入值在绝对值较大时,Sigmoid函数的梯度趋近于零。这会导致在深层网络中,梯度在反向传播过程中逐层减小,使得网络难以学习。解决内部协变量偏移的规范化技术是在Sigmoid函数之后发展起来的,目的是为了改善训练过程的稳定性和效率。

反向传播是一种在神经网络中计算误差梯度的方法,然后这些梯度用于通过梯度下降算法更新网络的权重。以一个简单的三层神经网络为例,包括输入层、一个隐藏层和输出层:

  1. 前向传播:输入数据被送入网络并通过每一层,每一层的神经元将上一层的输出与权重相乘,加上偏置项,然后通过激活函数传递给下一层。

  2. 计算误差:网络的最终输出与真实值之间的差异被计算出来,通常使用损失函数来完成。

  3. 反向传播:误差被反向传播回网络,从输出层开始,逐层向后计算每层权重对误差的影响(即梯度)。这是通过链式法则完成的。

  4. 权重更新:一旦梯度被计算出来,就可以更新网络的权重,以期在下一次前向传播时减少误差。

反向传播使得深度学习能够在复杂的非线性模型中学习到有效的特征表示。

如果是全连接的神经网络(也称为密集连接网络),则第一隐藏层中的每个神经元的输出确实会连接到第二隐藏层的每个神经元。这意味着第二隐藏层的每个神经元会接收到第一隐藏层所有神经元输出的加权和(每个连接都有一个权重),加上一个偏置项,然后通过激活函数处理以产生自己的输出。这种结构使得网络可以捕捉到复杂的特征和模式。

加权和过程中使用的权重和偏置在反向传播过程中会被更新。反向传播算法通过计算损失函数相对于各权重的梯度,来指导如何调整权重和偏置,以减少模型的预测误差。这意味着每个神经元的输入加权和以及激活函数的选择都直接影响到网络的学习和性能表现。

在每一层都有n个神经元的全连接网络中,第一层每个神经元都有一个对应的输入权重和偏置。对于第二层(以及之后的每一层),每个神经元将接收前一层n个神经元的输出,这意味着会有n*n个权重(因为每个前层神经元的输出都要乘以一个权重,然后所有这些加权输出相加),以及n个偏置(每个神经元一个)。这还不包括激活函数可能引入的任何参数。

大多数激活函数,比如Sigmoid、ReLU、Tanh等,本身不含可训练参数,因此它们不直接参与权重更新。激活函数的主要作用是引入非线性,使得神经网络可以学习和模拟复杂的函数。在反向传播过程中,激活函数影响梯度的计算,但它们内部没有参数需要更新。然而,存在一些特殊的激活函数,比如Parametric ReLU (PReLU),它们包含可学习的参数,这些参数会在训练过程中更新。

 ReLU中的max(a,b)函数称为最大值函数,它输出两个数中较大的那个。

Rectified Linear Unit翻译为线性整流函数,最初是在神经网络领域被提出和使用的,作为一种激活函数。它因其简单和高效的特性而受到青睐,尤其是在深度学习模型中。ReLU通过解决梯度消失问题,加速了深度网络的训练过程。由于它在正数部分保持线性,而在负数部分输出为零,这使得网络在训练时更加稳定,并且能够处理非线性问题。

激活函数发展史:

最初引入激活函数是为了解决线性模型无法处理非线性问题的局限。没有激活函数,无论多少层的神经网络其实都只是一个线性回归模型,这大大限制了网络的表达能力。通过在神经元的输出上应用非线性激活函数,神经网络可以学习复杂的非线性关系,从而提高模型的灵活性和效能,使其能够解决更广泛的任务。

激活函数的引入使得网络能够学习复杂的非线性关系,即使激活函数本身没有可训练的参数。

尽管tanh(双曲正切函数)和tan(正切函数)在形式上相似,但它们基于不同的数学原理,不能直接通过tan(x)来计算tanh(x)。

激活函数的发展历程可以看作是对深度学习模型中遇到的问题的一系列迭代解决办法。最初,Sigmoid和Tanh因能够模拟神经元的开/关行为而被广泛使用,但它们很快因梯度消失问题而受限。接着,ReLU函数因其解决了梯度消失问题而成为默认选择,它简单且在正区间内保持激活,有助于加速收敛,然而,ReLU在负值区域完全不激活,导致了神经元死亡问题。为进一步优化,Leaky ReLU和PReLU等变体被提出,它们保证了负值区域也有微小的梯度。随着时间的推移,更多激活函数被提出,以适应不同的网络架构和训练需求。

2.2看一下项目目录

内容很多,我先让GPT看着目录解释一下(从项目文件命名习惯的角度推断一下各个文件是做什么用的):

这张图片展示的是GitHub上的一个项目仓库,具体到每个文件或文件夹,根据命名习惯,我们可以推测它们的作用如下:

  1. datasets - 这个文件夹可能包含项目使用的数据文件。对于SleepGPT项目,这些可能包括用于学习或做出预测的睡眠研究数据,如EEG读数。

  2. figures - 这个文件夹可能用于存储由项目生成或使用的图像、图表或插图。这可能包括模型性能的可视化或数据表示图。

  3. models - 这个文件夹很可能包含不同的模型文件或脚本,定义SleepGPT模型的架构,或者可能包含预训练的模型。

  4. .DS_Store - 这是由macOS创建的系统文件,存储了文件夹的自定义属性,如图标位置或背景图像选择。它与项目功能无关。

  5. .gitignore - 这个文件列出了Git应该忽略的所有文件或文件模式。它可能包括编译的Python文件(.pyc)、数据集或其他不打算提交到仓库的文件。

  6. LICENSE - 这包含了项目的许可协议,根据仓库信息,许可是Apache 2.0。这种许可允许非常宽松的使用,限制很少。

  7. README.md - 这通常是提供项目概览的文件,包括安装指南、使用示例以及用户或开发人员需要知道的其他重要信息。

  8. basemodel.py - 这个脚本可能定义了基础模型结构或类,项目中的其他模型可能继承或使用这些结构或类。

  9. eegreader.py - 鉴于项目的上下文,这个脚本很可能用于读取EEG数据,可能包括解析、预处理或将EEG数据格式化为适合输入模型的结构。

  10. gpt_longseq.py - 这可能是一个处理GPT长序列的脚本,因为长序列由于内存限制而具有挑战性。该脚本可能包括高效处理长序列的方法。

  11. gpt_transformers.py - 这个脚本很可能包含转换器模型的实现或为这个项目的需求量身定制的转换器模型修改。

  12. main_baseline.py - 这可能是一个建立模型基准性能指标的脚本,使用传统或更简单的方法与SleepGPT进行比较。

  13. main_evaluate_slm.py - 这个脚本可能用于评估项目中的睡眠语言模型(SLM)部分,可能包括测试或验证程序。

  14. main_seqclassification_cv.py - 这可能是一个主脚本,用于带交叉验证的序列分类任务,交叉验证是一种评估统计分析结果如何推广到独立数据集的方法。

  15. main_sleepmodel.py - 很可能是运行SleepGPT模型进行睡眠分期或睡眠障碍诊断的主脚本,可能包括训练、测试和应用模型。

  16. modules.py - 这个文件可能包含在项目中多个脚本使用的各种自定义模块或函数。

  17. torchutils.py - 这个脚本可能包括使用PyTorch的实用函数或类,PyTorch是一个深度学习框架,项目使用的是Python(根据仓库信息显示的唯一语言是Python,PyTorch是Python库)。

每个文件和目录在组织代码和项目所需资源方面发挥着作用,SleepGPT项目的目标是使用语言模型处理与睡眠数据相关的任务。

3.dataset

dataset中是这样的几个py文件,我听别人说似乎是运行py文件之后可以生成数据文件

这张图片显示的是GitHub仓库中名为“sleepgpt”的项目的“datasets”文件夹内容,里面包含了四个Python脚本文件:

  1. capreader.py - 从文件名来看,“cap”可能是某种数据格式或者数据来源的缩写。这个脚本很可能是用来读取和处理与“cap”相关的数据文件,并将其转换成模型可以使用的格式。

  2. massreader.py - 同样,这里的“mass”可能代表了一个特定的数据集或者数据类型。这个脚本用于读取“mass”数据,处理后可能会输出为模型训练或测试所需的数据格式。

  3. shhsreader.py - “shhs”可能代表“Sleep Heart Health Study”的缩写,这是一个著名的睡眠研究数据集。这个脚本可能用于解析该研究提供的数据文件。

  4. sleepedfreader.py - “edf”可能代表“European Data Format”,这是一种广泛用于睡眠研究和医疗领域记录生理信号的数据格式。这个脚本很可能是用于读取和处理EDF文件格式的数据。

通常情况下,这些读取器(reader)脚本的作用是把原始数据文件读入,进行必要的预处理,然后转换成一种格式,供模型进一步处理或分析。这些处理步骤可能包括数据清洗、格式化、特征提取等。一旦处理完成,数据就可以被用于训练、验证或测试机器学习模型,以实现如自动睡眠分期、异常检测等任务。

 那么现在我们去看看这几个py文件里面的代码

3.1 dataset/capreader.py

3.1.1代码解释

先看capreader.py


import os
import time
import glob
import mne
import numpy as np
from scipy import interpolate


# Have to manually define based on the dataset
ann2label = {
    "W": 0,
    "S1": 1,
    "S2": 2,
    "S3": 3, "S4": 3, # Follow AASM Manual
    "REM": 4, "R": 4,
    "MT": 5,
}

# Label values
W = 0       # Stage AWAKE
N1 = 1      # Stage N1
N2 = 2      # Stage N2
N3 = 3      # Stage N3
REM = 4     # Stage REM
MOVE = 5    # Movement
UNKNOWN = 5     # Unknown

stage_dict = {
    "W": W,
    "S1": N1,
    "S2": N2,
    "S3": N3,
    "REM": REM,
    "MOVE": MOVE,
    "UNKNOWN": UNKNOWN,
}

disorder_prefix = {
    "brux": 1,
    "ins": 2,
    "narco": 3,
    "nfle": 4,
    "plm": 5,
    "rbd": 6,
    "sdb": 7,
    "n": 0,
}


def get_timeStamp(str):
    try:
        timeArray = time.strptime(str, "%H:%M:%S")
    except:
        timeArray = time.strptime(str, "%H.%M.%S")
    return timeArray.tm_hour*3600 + timeArray.tm_min*60 + timeArray.tm_sec


def resample(signal, signal_frequency, target_frequency):
    resampling_ratio = signal_frequency / target_frequency
    x_base = np.arange(0, len(signal))

    interpolator = interpolate.interp1d(x_base, signal, axis=0, bounds_error=False, fill_value='extrapolate',)

    x_interp = np.arange(0, len(signal), resampling_ratio)

    signal_duration = signal.shape[0] / signal_frequency
    resampled_length = round(signal_duration * target_frequency)
    resampled_signal = interpolator(x_interp)
    if len(resampled_signal) < resampled_length:
        padding = np.zeros((resampled_length - len(resampled_signal), signal.shape[-1]))
        resampled_signal = np.concatenate([resampled_signal, padding])

    return resampled_signal


def load_annotations(ann_fname, interval_time=30):
    f = open(ann_fname, 'r')
    lines = f.readlines()
    f.close()
    start_index = 9999
    style = 0
    temp_stage = []
    temp_time = []

    start_time = 0
    for i in range(len(lines)):
        if i < start_index:
            if lines[i] == "Sleep Stage	Time [hh:mm:ss]	Event	Duration[s]	Location\n":
                start_index = i
                style = 1
            elif lines[i] == "Sleep Stage	Position	Time [hh:mm:ss]	Event	Duration[s]	Location\n":
                start_index = i
                style = 2
            elif lines[i] == "Sleep Stage	Position	Time [hh:mm:ss]	Event	Duration [s]	Location\n":
                start_index = i
                style = 2
                print ("in")
            else:
                continue

        if i == start_index + 1:
            temp = lines[i].split("\t")
            start_time = get_timeStamp(temp[style])
            
        if i > (interval_time / 30 - 1) / 2 + start_index and i + (interval_time / 30 - 1) / 2 < len(lines):
            if style == 0:
                print(f"{ann_fname} time column error")
                return 1, 1
            
            temp = lines[i].split("\t")
            try:
                if temp[style+2] != '30':
                    continue
                timeStamp = get_timeStamp(temp[style])
                temp_time.append(timeStamp)
                temp_stage.append(temp[0])
            except:
                if i - start_index > 3:
                    print(f"{ann_fname} duration warning")
                    break
                else:
                    print(f"{ann_fname} duration error")
                    exit()
    times = np.array(temp_time)
    if times.shape[0] == 0:
        print(f"{ann_fname} load txt error")
        return 1, 1

    return temp_stage, times


def load_eegdata_cap(psg_fname, ann_fname, select_ch=['C4-A1', 'C4A1', 'C3-A2', 'C3A2'], target_fs=100.):
    """
    https://github.com/emadeldeen24/AttnSleep
    """

    labels, times = load_annotations(ann_fname)
    if labels == 1 and times == 1:
        print(f"load annotations from {ann_fname} failed")
        return
    labels = np.array([ann2label[x] for x in labels])

    data = mne.io.read_raw_edf(psg_fname)
    sampling_rate = data.info['sfreq']
    try:
        signals_raw = data.get_data(picks=select_ch[0])[0]
    except:
        try:
            signals_raw = data.get_data(picks=select_ch[1])[0]
        except:
            try:
                signals_raw = data.get_data(picks=select_ch[2])[0]
            except:
                try:
                    signals_raw = data.get_data(picks=select_ch[3])[0]
                except:
                    try:
                        signals_0 = data.get_data(picks='A1')[0]
                        signals_1 = data.get_data(picks='C4')[0]
                        signals_raw = signals_1 - signals_0
                    except:
                        print(f"{psg_fname} does not have channel {select_ch[0]} or {select_ch[2]}")
                        return
    signals = signals_raw
    if sampling_rate != target_fs:
        signals = resample(signals_raw, sampling_rate, target_fs)

    #### dealing time ##########
    #timeArray = time.strptime(data.info['meas_date'], "%Y-%m-%d %H:%M:%S+00:00")
    #start_time = timeArray.tm_hour*3600 + timeArray.tm_min*60 + timeArray.tm_sec
    start_time = data.info['meas_date'].hour * 3600 + data.info['meas_date'].minute * 60 + data.info['meas_date'].second
    for i in range(len(times)):
        if times[i] >= start_time:
            times[i] = times[i] - start_time
        else:
            times[i] = 24*3600 - start_time + times[i]

    signals_epoched = []
    for i in range(len(times)):
        epoch_begin = int(times[i] * target_fs)
        epoch_end = int((times[i] + 30) * target_fs)
        if epoch_end > len(signals):
            print(f"{psg_fname} time overflow, rest epoch {len(times) - i}\n")
            break
        signal_epoch = signals[epoch_begin:epoch_end]
        signals_epoched.append(signal_epoch)
    signals = np.vstack(signals_epoched)

    # Get epochs and their corresponding labels
    x = signals.astype(np.float32)
    y = labels.astype(np.int32)

    print(x.shape)
    print(y.shape)
    # assert len(x) == len(y)
    y_full = y.copy()
    y = y[:len(x)]

    # Remove movement and unknown
    remove_idx = np.where(y >= 5)[0]
    if len(remove_idx) > 0:
        select_idx = np.setdiff1d(np.arange(len(x)), remove_idx)
        x = x[select_idx]
        y = y[select_idx]

    remove_idx = np.where(y_full >= 5)[0]
    if len(remove_idx) > 0:
        select_idx = np.setdiff1d(np.arange(len(y_full)), remove_idx)
        y_full = y_full[select_idx]

    data_dict = {
        "x": x,
        "y": y,
        "y_full": y_full,
        "fs": sampling_rate
    }

    return data_dict


def load_npz_file(npz_file):
    """Load data and labels from a npz file."""
    with np.load(npz_file) as f:
        data = f["x"]
        labels = f["y"]
        sampling_rate = f["fs"]
    return data, labels, sampling_rate

def load_npz_list_files(npz_files):
    """Load data and labels from list of npz files."""
    data = []
    labels = []
    fs = None
    for npz_f in npz_files:
        print("Loading {} ...".format(npz_f))
        tmp_data, tmp_labels, sampling_rate = load_npz_file(npz_f)
        if fs is None:
            fs = sampling_rate
        elif fs != sampling_rate:
            raise Exception("Found mismatch in sampling rate.")

        # Reshape the data to match the input of the model - conv2d
        # tmp_data = np.squeeze(tmp_data)
        # tmp_data = tmp_data[:, :, np.newaxis, np.newaxis]
        
        # # Reshape the data to match the input of the model - conv1d
        tmp_data = tmp_data[:, :, np.newaxis]

        # Casting
        tmp_data = tmp_data.astype(np.float32)
        tmp_labels = tmp_labels.astype(np.int32)

        data.append(tmp_data)
        labels.append(tmp_labels)

    return data, labels

def load_subdata_preprocessed(datapath, subject):
    npz_f = os.path.join(datapath, subject+'.npz')
    data, labels, fs = load_npz_file(npz_f)
    return data, labels

def load_dataset_preprocessed(datapath, subsets=['shhs1'], n_subjects=None):
    if isinstance(subsets, str):
        subsets = [subsets]
    npzfiles = []
    for subset in subsets:
        subset_npzfiles = glob.glob(os.path.join(datapath, subset, "*.npz"))
        [npzfiles.append(npz_f) for npz_f in subset_npzfiles]
    npzfiles.sort()
    if n_subjects is not None:
        npzfiles = npzfiles[:n_subjects]
    subjects = [os.path.basename(npz_f)[:-4] for npz_f in npzfiles]
    data, labels = load_npz_list_files(npzfiles)
    return data, labels, subjects


if __name__ == '__main__':

    datapath = 'e:/eegdata/sleep/cap/'
    savepath = datapath + 'processed/'
    os.makedirs(savepath, exist_ok=True)

    record = open(datapath + 'RECORDS', 'r')
    lines = record.readlines()
    
    annotations = []
    ann_f = open(savepath+'annotations.txt', 'w')
    sub_f = open(savepath+'subject_labels.txt', 'w')
    for i in range(len(lines)):
        temp = lines[i].split('.')
        ann_fname = datapath + temp[0] + '.txt'
        edf_fname = datapath + temp[0] + '.edf'

        subject = os.path.basename(edf_fname)[:-4]
        for prefix, label in disorder_prefix.items():
            if subject.startswith(prefix):
                subject_label = label
                break
        sub_f.write(f"{subject_label}\n")
        
        print('Load and extract continuous EEG into epochs for subject '+subject)
        data_dict = load_eegdata_cap(edf_fname, ann_fname)
        annotations.append(data_dict["y_full"])
        ann_f.write(",".join([f"{ann}" for ann in data_dict["y_full"]]) + "\n")

        np.savez(savepath+subject+'.npz', **data_dict)
    ann_f.close()
    sub_f.close()

这段代码是一个Python脚本,用于从多通道EEG信号的EDF文件中读取和处理数据,并将其与睡眠阶段注释相关联。它使用了mne库来读取EDF文件,这是一种常用于神经科学数据分析的Python库。以下是代码的主要功能解释:

  1. 变量定义

    • ann2labelstage_dict是字典,用于将文本形式的睡眠阶段(如"W"代表清醒状态)映射到整数标签。
    • disorder_prefix是另一个字典,用于将某种睡眠障碍的前缀映射到一个整数标识。
  2. 函数定义

    • get_timeStamp函数将时间字符串转换为秒数。
    • resample函数用于将信号从一个采样频率重新采样到目标频率。
    • load_annotations函数用于加载注释文件,这个文件中包含了睡眠阶段的标记和时间戳。
    • load_eegdata_cap函数用于加载EDF文件,从中提取EEG信号,并将其与睡眠阶段的注释相关联。它也负责信号的重新采样(如果需要的话)。
    • load_npz_fileload_npz_list_files函数用于加载存储在.npz文件中的数据和标签。
    • load_subdata_preprocessedload_dataset_preprocessed函数用于加载预处理的数据集。
  3. 数据处理逻辑

    • 脚本首先从注释文件中加载睡眠阶段的标记和对应的时间戳。
    • 然后,它读取EEG数据,如果需要,将数据重新采样到指定的采样率。
    • 接着,脚本将时间戳用于将连续的EEG信号切分成与睡眠阶段相对应的30秒的片段(epoch)。
    • 这些信号片段和它们的标签被转换为适合机器学习模型输入的格式。
  4. 数据保存逻辑

    • 脚本在处理每个EDF文件后,会将处理好的信号片段和相应的标签保存为.npz格式,便于后续的分析或模型训练。
  5. 主执行逻辑if __name__ == '__main__':):

    • 这部分代码负责遍历数据集中的所有记录,加载EEG文件和注释文件,处理数据,并将结果保存到指定路径。

整个脚本的目的是自动化地处理和准备用于睡眠研究的EEG数据,这些数据可以用于训练和验证睡眠阶段分类模型。通过将原始的EEG信号和睡眠阶段标签关联起来,研究人员可以进一步分析睡眠模式或训练自动化的睡眠分期算法。

3.1.2问题解答

1.多通道EEG信号的EDF文件是什么

EEG(脑电图)信号是通过在头皮上放置电极来记录大脑电活动的一种方法。多通道EEG指的是同时从头部多个位置收集的脑电信号,每个位置的电极记录下的信号构成一个通道。

下面是AI生成的图,可以帮助理解

EDF(European Data Format,欧洲数据格式)文件是一种用于存储生物医学信号数据,如EEG、心电图(ECG)、睡眠多导睡眠图等的标准文件格式。EDF文件广泛用于睡眠研究和医疗领域,因为它允许同步记录多种生理信号,并且是一个开放的非专有格式

多通道EEG信号的EDF文件会包含以下信息:

  • 头部信息:包括记录的日期和时间、记录持续时间、使用的电极和它们的位置、各通道的标签(如C4-A1指的是C4和A1两个电极之间的差分信号),以及其他记录设备和患者信息。
  • 数据记录:每个通道的生理信号数据,以一定的采样频率数字化和记录。这些数据可以用于分析大脑的电活动,比如在睡眠研究中用来确定不同的睡眠阶段。

在睡眠研究中,通过分析EDF文件中的EEG信号,研究人员可以识别例如清醒、轻度睡眠、深度睡眠和REM(快速眼动)睡眠等不同的睡眠阶段。这些信息对于理解睡眠模式、诊断睡眠障碍和评估治疗效果等都是非常重要的。

EDF文件的后缀名通常是“.edf”。这种文件格式用于存储医学和生物医学测量数据,如睡眠监测中的脑电图(EEG)数据。EDF代表European Data Format,即欧洲数据格式。

EOG(眼电图)

我找到一篇论文,是这么说的现在还有一种EOG

近年来,使用脑电图 (EEG) 信号的自动睡眠分期方法取得了具有竞争力的性能。然而,脑电信号的采集既繁琐又不方便。因此,我们提出了一种使用眼电图(EOG)信号的新睡眠分期方法,该方法比EEG更易于获取。一个双尺度卷积神经网络首先从原始EOG信号中提取周期性的时间等效特征。然后,递归神经网络捕获长期顺序信息。所提出的方法在来自两个开放获取数据库的 101 个整夜睡眠数据上进行了验证,蒙特利尔睡眠研究档案和 Sleep-EDF,总体准确率分别为 81.2% 和 76.3%。结果与那些用脑电图信号训练的模型相当。此外,与六种最先进方法的比较进一步证明了所提出的方法的有效性。总体而言,这项研究为睡眠监测提供了一条新的途径。

前沿 |EOGNET:一种基于单通道EOG信号的睡眠阶段分类深度学习模型 (frontiersin.org)

2.mne库

mne是一个用于Python的开源库,专门用于处理神经电生理数据(例如脑电图EEG、磁脑图MEG、以及其他相关数据),以及神经成像数据(如功能磁共振成像fMRI)。MNE全称是Magnetoencephalography (MEG) Neuroimaging and Electrophysiology (EEG),它提供了一系列强大的工具来进行数据加载、信号处理、统计分析、可视化等任务。

mne库特别适合于处理和分析时间序列数据,并且包含了许多用于数据预处理的功能,如滤波、去噪、伪迹去除等。此外,它也支持多种神经电生理数据格式,包括但不限于EDF。

总的来说,mne库是神经科学研究者和数据分析师处理和分析EEG、MEG和fMRI数据的强大工具。它的社区活跃,定期更新,并有详细的文档和教程支持新用户学习如何使用它。

官方文档:MNE — MNE 1.6.1 documentation

Github页面:GitHub - mne-tools/mne-python: MNE: Magnetoencephalography (MEG) and Electroencephalography (EEG) in Python

论坛:MNE Forum 

3.时间戳

“时间戳”指的是代表特定时间点的数据。它可以是日期和时间的表示,也可以是自一个特定时刻(如UNIX纪元:1970年1月1日)以来的秒数。

在这个脚本中,时间戳是用来标记每个epoch开始的具体时间点,以确保数据的准确切分和标注。

4.npz

至于npz文件,它是一种NumPy文件格式,用于存储NumPy数组。npz文件是一个压缩的包,里面可以包含多个NumPy数组。这种文件格式并不是专门为transformers框架设计的,而是在需要存储和传输大量NumPy数组数据时的一种便捷方式。在机器学习和深度学习项目中,npz文件因其高效的存储方式而被广泛使用,不限于transformers框架。这种格式便于将数据集整理、打包,并在不同的计算任务中重复使用,但它不是任何特定框架的专属格式。

3.1.3概述

        capreader.py是用来通过注释文件和EDF(可能特指脑电图EEG的EDF)文件来生成模型训练用的npz文件的。

        如果没有脑电波的EDF文件可以忽略这个capreader。

3.2 dataset/massreader.py

先让AI解读一下吧。

这段代码是一个处理睡眠EEG数据的Python脚本,特别是针对MASS数据库的数据。它主要由几个关键部分组成:

  1. 数据标签映射(ann2label:这是一个字典,定义了睡眠阶段与数值标签之间的映射。这样做是为了将文本标签(如"Sleep stage W")转换为模型可以理解的数值标签。

  2. 重采样(resample函数):这个函数用于将信号从原始采样频率转换到目标采样频率。这是通过插值来实现的,如果重采样后的信号长度小于目标长度,则会添加填充。

  3. 加载EEG数据和注释(load_eegdata_mass函数):这个函数负责从EDF文件中读取EEG信号和睡眠阶段注释。它使用pyedflib库来读取EDF文件,然后根据注释提取每个睡眠阶段的信号,并将其划分为30秒的epoch。函数最后返回一个包含信号、标签和其他信息的字典。

  4. 数据加载和预处理(load_npz_fileload_npz_list_files等函数):这些函数用于加载预处理后的数据,支持从多个.npz文件中加载数据和标签。这对于批量处理和模型训练非常有用。

  5. 数据保存和组织:脚本最后部分包含代码,用于遍历指定数据集的子集,对每个主题的EEG数据进行处理,提取特征,然后将处理后的数据保存为.npz格式。同时,还会保存所有主题的睡眠阶段标签到一个文本文件中。

整个脚本展示了一个从读取原始EEG信号到提取特征、重采样、划分epoch、标记睡眠阶段、保存处理后数据的完整流程,适用于睡眠研究和睡眠阶段分类任务。

SleepstageW是代码里面的名词吗

还真是

定义了5个睡眠阶段

W,1,2,3,R,?

所以简单地说它可以把EEF转成一种固定长度频率(格式)的数字样本。

3.3猜测

合理推测4个py文件对应4种原始格式的数据输出都是npz,比如说massreader.py,应该就是吧mass数据库的.edf转换为npz.

但是有一个问题是为什么这里有psg和ann都要做sort

3.2.1实践

启动一下massreader.py试一试

4.models

我觉得这个应该是重头戏

4.1seqsleepnet.py

这段代码定义了一个名为SeqSleepNet的神经网络模型,旨在处理序列化的睡眠数据。模型由两个主要部分组成:一个编码器和一个长短期记忆(LSTM)网络。编码器用于提取特征,而LSTM部分负责处理序列依赖性。模型还包含一个分类头,用于将LSTM的输出映射到预定的类别上。代码中提供了参数初始化、编码器冻结选项以及前向传播逻辑。此外,还展示了如何实例化并使用该模型的示例。

包含多个关键概念:

  • 序列化:处理时间序列数据,如连续的睡眠阶段信号。
  • 编码器:将输入数据(如EEG信号)转换成有用的表示(特征)。
  • LSTM网络:一种长短期记忆网络,能够捕捉序列数据中的长期依赖性。
  • 序列依赖性:指序列中先前的元素对后续元素的影响。
  • 分类头:网络的最后一部分,用于将特征映射到类别标签上。
  • 参数初始化:为模型的学习设置起始点。
  • 编码器冻结选项:决定是否更新编码器的权重,以利用预先训练的特征。
  • 前向传播逻辑:定义了数据如何通过网络流动进行预测

长短期记忆网络(LSTM)是一种特殊的循环神经网络(RNN),能够学习长期依赖信息。LSTM的关键是其内部结构,包括三个门(输入门、遗忘门、输出门)和一个单元状态,这些结构使得LSTM可以在长序列中有效地保持和传递信息,解决了传统RNN在处理长序列数据时面临的梯度消失或梯度爆炸问题。

传统的RNN(循环神经网络)是一种用于处理序列数据的深度学习模型,它通过循环连接能够记忆之前的信息,并利用这些信息来处理当前的输入。RNN特别适合于文本处理、语音识别等序列任务。但是,它们在长序列上容易遇到梯度消失或梯度爆炸的问题,这限制了它们学习长期依赖性的能力。

RNN(循环神经网络)技术是基于早期神经网络的研究迭代发展而来的。早期神经网络主要关注于处理静态输入,如感知机和前馈网络。随着对序列数据处理需求的增长,比如语言模型和时间序列分析,研究者开始探索能够处理序列依赖性的模型。RNN通过引入循环连接,使得网络能够保持状态,从而处理序列数据中的时间依赖性,标志着神经网络从静态处理向序列处理的重要转变。

CNN(卷积神经网络)和RNN(循环神经网络)并不是一个直接的技术迭代关系,而是两种针对不同类型数据和任务设计的神经网络结构。CNN主要用于处理具有空间关联的数据,如图像和视频,而RNN设计用于处理序列数据,如文本和时间序列。两者都是深度学习技术的重要组成部分,但它们解决问题的方式和应用领域各不相同。

CNN是通过全连接层迭代而来的,就是输入层和全连接层间加了卷积层和池化层。

CNN的处理流程大致是:首先通过卷积层提取输入数据(如图像)的特征,然后通过池化层降低数据的空间维度,减少后续计算的复杂性。经过这一系列的处理后,数据会传递到全连接层,其中进行权重求和、加偏置并通过激活函数,以进行最终的分类或回归任务。

CNN网络有时被描述为沙漏型,因为它们通常从宽(即输入层,接收大量输入数据)开始,通过卷积和池化层逐渐减少维度,达到一个“瓶颈”,然后可能通过上采样或全连接层再次扩展,尤其是在需要生成图像或重建特征的任务中。这种结构有助于有效地提取和压缩信息,然后用于特定任务的决策。

在一些特定类型的CNN结构中,尤其是用于图像分割、图像生成等任务的网络,全连接层确实可能会再次扩展,增加神经元的数量。这种设计允许网络从压缩的特征中恢复更多细节信息,为最终的输出提供丰富的上下文。然而,并非所有CNN都遵循这种设计;它依赖于特定任务的需求和网络架构的选择。

,编码器(encoder)是一个预定义的网络结构,用于将输入数据转换成有用的表示(特征)。具体的编码器实现细节(如DeepSleepNetTinySleepNet)并未在这段代码中给出。编码器的作用是提取输入数据(如EEG信号)的特征,这些特征随后被送入LSTM网络进行序列处理。编码器如何具体提取特征取决于它的内部结构,比如卷积层用于捕获空间特征,池化层用于降维。由于具体的编码器实现细节不在代码片段中,具体规则需查看相应的DeepSleepNetTinySleepNet实现。

所以我们就要看sleepnet.py里这个DeepSleepNet是怎么实现的。

4.2sleepnet.py

# -*- coding:utf-8 -*-

import numpy as np
import torch
import torch.nn as nn
from functools import reduce
from operator import __add__


class Conv2dSamePadding(nn.Conv2d):
    def __init__(self,*args,**kwargs):
        super(Conv2dSamePadding, self).__init__(*args, **kwargs)
        self.zero_pad_2d = nn.ZeroPad2d(reduce(__add__,
            [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in self.kernel_size[::-1]]))

    def forward(self, input):
        return  self._conv_forward(self.zero_pad_2d(input), self.weight, self.bias)
    

class Conv2dBnReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
        if isinstance(stride, int): stride = (stride, stride)
        if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size)
        ConvLayer = Conv2dSamePadding(
            in_channels, out_channels, kernel_size, stride, bias=False, **kwargs
        )
        super(Conv2dBnReLU, self).__init__(
            ConvLayer,
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        
class DeepSleepNet(nn.Module):
    """
    Reference:
    A. Supratak, H. Dong, C. Wu, and Y. Guo, "DeepSleepNet: A Model for Automatic
    Sleep Stage Scoring Based on Raw Single-Channel EEG," IEEE Trans Neural Syst 
    Rehabil Eng, vol. 25, no. 11, pp. 1998-2008, 2017.
    https://github.com/akaraspt/deepsleepnet
    """
    def __init__(
        self, n_classes, n_timepoints, dropout = 0.5,
        # Conv layers
        n_filters_1 = 64, filter_size_1 = 50, filter_stride_1 = 6,
        n_filters_2 = 64, filter_size_2 = 400, filter_stride_2 = 50,
        pool_size_11 = 8, pool_stride_11 = 8, 
        pool_size_21 = 4, pool_stride_21 = 4,
        n_filters_1x3 = 128, filter_size_1x3 = 8,
        n_filters_2x3 = 128, filter_size_2x3 = 6,
        pool_size_12 = 4, pool_stride_12 = 4, 
        pool_size_22 = 2, pool_stride_22 = 2,
    ):
        super().__init__()
        self.conv1 = nn.Sequential(
            Conv2dBnReLU(1, n_filters_1, (filter_size_1, 1), (filter_stride_1, 1)),
            nn.MaxPool2d((pool_size_11, 1), (pool_stride_11, 1)),
            nn.Dropout(dropout),
            Conv2dBnReLU(n_filters_1,   n_filters_1x3, (filter_size_1x3, 1), stride=1),
            Conv2dBnReLU(n_filters_1x3, n_filters_1x3, (filter_size_1x3, 1), stride=1),
            Conv2dBnReLU(n_filters_1x3, n_filters_1x3, (filter_size_1x3, 1), stride=1),
            nn.MaxPool2d((pool_size_12, 1), (pool_stride_12, 1)),
        )
        self.conv2 = nn.Sequential(
            Conv2dBnReLU(1, n_filters_2, (filter_size_2, 1), (filter_stride_2, 1)),
            nn.MaxPool2d((pool_size_21, 1), (pool_stride_21, 1)),
            nn.Dropout(dropout),
            Conv2dBnReLU(n_filters_2,   n_filters_2x3, (filter_size_2x3, 1), stride=1),
            Conv2dBnReLU(n_filters_2x3, n_filters_2x3, (filter_size_2x3, 1), stride=1),
            Conv2dBnReLU(n_filters_2x3, n_filters_2x3, (filter_size_2x3, 1), stride=1),
            nn.MaxPool2d((pool_size_22, 1), (pool_stride_22, 1)),
        )
        self.drop1 = nn.Dropout(dropout)

        outlen_conv1 = n_timepoints // filter_stride_1 // pool_stride_11 // pool_stride_12
        outlen_conv2 = n_timepoints // filter_stride_2 // pool_stride_21 // pool_stride_22
        outlen_conv = outlen_conv1*n_filters_1x3 + outlen_conv2*n_filters_2x3

        self.feature_dim = outlen_conv
        self.classifier = nn.Linear(outlen_conv, n_classes) if n_classes > 0 else nn.Identity()

        self._reset_parameters()

    def _reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.xavier_uniform_(m.weight, gain=1)
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x1 = x1.view(x1.size(0), -1) # flatten (b, c, t, 1) -> (b, c*t)
        x2 = x2.view(x2.size(0), -1) # flatten (b, c, t, 1) -> (b, c*t)
        x = torch.cat((x1, x2), dim=1) # concat in feature dimention
        x = self.drop1(x)
        x = self.classifier(x)
        return x


class TinySleepNet(nn.Module):
    """
    Reference:
    A. Supratak and Y. Guo, "TinySleepNet: An Efficient Deep Learning Model
    for Sleep Stage Scoring based on Raw Single-Channel EEG," Annu Int Conf
    IEEE Eng Med Biol Soc, vol. 2020, pp. 641-644, Jul 2020.
    https://github.com/akaraspt/tinysleepnet
    """
    def __init__(
        self, n_classes, n_timepoints, dropout = 0.5,
        # Conv layers
        n_filters_1 = 128, filter_size_1 = 50, filter_stride_1 = 6,
        pool_size_1 = 8, pool_stride_1 = 8, 
        n_filters_1x3 = 128, filter_size_1x3 = 8,
        pool_size_2 = 4, pool_stride_2 = 4, 
    ):
        super().__init__()
        self.conv1 = nn.Sequential(
            Conv2dBnReLU(1, n_filters_1, (filter_size_1, 1), (filter_stride_1, 1)),
            nn.MaxPool2d((pool_size_1, 1), (pool_stride_1, 1)),
            nn.Dropout(dropout),
            Conv2dBnReLU(n_filters_1,   n_filters_1x3, (filter_size_1x3, 1), stride=1),
            Conv2dBnReLU(n_filters_1x3, n_filters_1x3, (filter_size_1x3, 1), stride=1),
            Conv2dBnReLU(n_filters_1x3, n_filters_1x3, (filter_size_1x3, 1), stride=1),
            nn.MaxPool2d((pool_size_2, 1), (pool_stride_2, 1)),
            nn.Dropout(dropout)
        )

        outlen_conv1 = n_timepoints // filter_stride_1 // pool_stride_1 // pool_stride_2
        outlen_conv = outlen_conv1*n_filters_1x3

        self.feature_dim = outlen_conv
        self.classifier = nn.Linear(outlen_conv, n_classes) if n_classes > 0 else nn.Identity()

        self._reset_parameters()

    def _reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.xavier_uniform_(m.weight, gain=1)
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1) # flatten (b, c, t, 1) -> (b, c*t)
        x = self.classifier(x)
        return x


if __name__ == '__main__':

    x = torch.randn((20, 1, 3000, 1))
    model = DeepSleepNet(5, 3000)
    # model = TinySleepNet(5, 3000)
    print(model)
    y = model(x)
    print(y.shape)

 DeepSleepNet的实现主要基于卷积神经网络(CNN),通过多层卷积和池化操作提取EEG信号的特征。初始层捕获低级特征,随后的层通过更复杂的卷积操作提取高级特征。这种结构使得模型能够从原始EEG信号中学习到用于睡眠阶段分类的有效特征表示。特别地,模型采用了两个并行的卷积路径,每个路径包含多个卷积层,最后将这些特征进行合并并通过全连接层进行分类。

DeepSleepNet使用两个并行的卷积路径来从单通道EEG信号中提取特征,这两条路径设计为捕获不同尺度上的特征。第一条路径利用较小的卷积核和较快的卷积步长,而第二条路径使用较大的卷积核和较慢的卷积步长,目的是捕获EEG信号中的时间依赖性和特征。两条路径的输出在特征维度上合并,然后通过全连接层进行分类。此外,网络在关键部位使用了dropout技术来防止过拟合,参数初始化使用Kaiming正态初始化方法来优化训练过程。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值