Simple-STNDT使用Transformer进行Spike信号的表征学习(一)数据处理篇

数据集与评估标准参见 NLB 2021(Neural Latents Benchmark):
https://neurallatents.github.io/

以下代码依据
https://github.com/trungle93/STNDT

原代码使用了 Ray+Config文件进行了参数搜索,库依赖较多,数据流过程不明显,代码冗杂,这里进行了抽丝剥茧,将其中最核心的部分提取出来。

1.数据处理部分

1.1 下载数据集

需要先安装依赖:pip install dandi
downald.py

# Local directory that every dandiset is downloaded into.
root = "D:/NeuralLatent/"


def downald_data():
    """Download the NLB dandisets from the DANDI archive into ``root``.

    The dandisets are fetched one after another in a fixed order. ``dandi`` is
    imported inside the function, so this module can be imported without it
    (install with ``pip install dandi``).
    """
    from dandi.download import download

    dandiset_urls = (
        "https://dandiarchive.org/dandiset/000128",
        "https://dandiarchive.org/dandiset/000138",
        "https://dandiarchive.org/dandiset/000139",
        "https://dandiarchive.org/dandiset/000140",
        "https://dandiarchive.org/dandiset/000129",
        "https://dandiarchive.org/dandiset/000127",
        "https://dandiarchive.org/dandiset/000130",
    )
    for url in dandiset_urls:
        download(url, root)

1.2 数据集预处理

需要先安装官方工具包:pip install nlb_tools
主要是加载神经元脉冲(spike)序列数据,并将其重采样(分箱)为 5 ms 的时间槽(bin)
preprocess.py

## Example arguments:
# data_path = root + "000129/sub-Indy/"
# dataset_name = "mc_rtt"
## Output files go to ``save_dir`` ("./data" by default), which is created if missing.

import os

from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, combine_h5
from nlb_tools.nwb_interface import NWBDataset


def preprocess(data_path, dataset_name=None, *, bin_width=5, save_dir="./data"):
    """Bin an NLB NWB dataset and write train / val / combined HDF5 tensor files.

    Loads the NWB data at ``data_path`` and resamples it to ``bin_width`` ms bins.
    It then writes ``{save_dir}/{dataset_name}_train.h5`` (with behavior and
    forward-prediction targets) and ``{save_dir}/{dataset_name}_val.h5``.
    Finally, ``combine_h5`` merges both into ``{save_dir}/{dataset_name}_full.h5``.

    Args:
        data_path: Path handed to ``NWBDataset``, e.g. a dandiset subject folder.
        dataset_name: NLB dataset name such as ``"mc_rtt"``. It is passed to
            nlb_tools and used in the output file names.
        bin_width: Target bin width in ms, passed to ``NWBDataset.resample``.
        save_dir: Output directory; created if it does not already exist.

    Raises:
        ValueError: If ``dataset_name`` is not provided.
    """
    if dataset_name is None:
        # Without a name the outputs would be "None_*.h5" and nlb_tools
        # cannot look up the dataset's parameters.
        raise ValueError("dataset_name is required, e.g. 'mc_rtt'")

    os.makedirs(save_dir, exist_ok=True)
    train_path = os.path.join(save_dir, f"{dataset_name}_train.h5")
    val_path = os.path.join(save_dir, f"{dataset_name}_val.h5")
    full_path = os.path.join(save_dir, f"{dataset_name}_full.h5")

    dataset = NWBDataset(data_path)
    dataset.resample(bin_width)

    make_train_input_tensors(
        dataset, dataset_name=dataset_name, trial_split="train",
        include_behavior=True, include_forward_pred=True,
        save_file=True, save_path=train_path,
    )
    make_eval_input_tensors(
        dataset, dataset_name=dataset_name, trial_split="val",
        save_file=True, save_path=val_path,
    )
    # Merge train and val tensors into the single file used downstream.
    combine_h5([train_path, val_path], save_path=full_path)

## './data/mc_rtt_full.h5' 将成为后续的主要分析数据

1.3 划分train-val并创建Dataset对象

读取'./data/mc_rtt_full.h5'中的数据并创建dataset
dataset.py

import h5py
import numpy as np
import torch
from torch.utils import data
# data_path = "./data/mc_rtt_full.h5"

class SpikesDataset(data.Dataset):
    """In-memory map-style dataset of per-trial spike tensors.

    Indexing returns the tuple ``(spikes, heldout_spikes, forward_spikes)``
    for one trial, taken along the first dimension of each stored tensor.
    """

    def __init__(self, spikes, heldout_spikes, forward_spikes) -> None:
        # Presumably all three tensors share the same leading (trial) dimension,
        # since __getitem__ indexes them with the same index.
        (self.spikes,
         self.heldout_spikes,
         self.forward_spikes) = (spikes, heldout_spikes, forward_spikes)

    def __len__(self):
        # Length is taken from the held-in spikes tensor only.
        return self.spikes.size(0)

    def __getitem__(self, index):
        r"""Return ``(spikes, heldout_spikes, forward_spikes)`` for trial ``index``."""
        return tuple(
            tensor[index]
            for tensor in (self.spikes, self.heldout_spikes, self.forward_spikes)
        )


def make_datasets(data_path):
    """Build train and validation ``SpikesDataset`` objects from an NLB HDF5 file.

    Reads the tensors written by nlb_tools (for example ``./data/mc_rtt_full.h5``).
    Forward-prediction tensors concatenate held-in and held-out neurons on the
    last axis. Validation targets missing from the file are filled with zeros.
    This applies to held-out spikes and to forward-prediction spikes. Every
    array is converted to an int64 tensor.

    Args:
        data_path: Path to the HDF5 file.

    Returns:
        ``(train_dataset, val_dataset)``.

    Raises:
        ValueError: If the file has no ``eval_spikes_heldin`` key, i.e. it is
            not in the NLB tensor format this function expects.
    """
    with h5py.File(data_path, 'r') as h5file:
        if 'eval_spikes_heldin' not in h5file:
            # Fail loudly instead of hitting an undefined-variable error below.
            raise ValueError(
                f"{data_path!r} is not an NLB tensor file "
                "(missing 'eval_spikes_heldin')"
            )

        def get_key(key):
            # Read only the datasets that are actually needed, as float32.
            return h5file[key][()].astype(np.float32)

        train_data = get_key('train_spikes_heldin')
        train_data_heldout = get_key('train_spikes_heldout')
        train_data_fp = get_key('train_spikes_heldin_forward')
        train_data_heldout_fp = get_key('train_spikes_heldout_forward')
        train_data_all_fp = np.concatenate([train_data_fp, train_data_heldout_fp], -1)

        valid_data = get_key('eval_spikes_heldin')
        if 'eval_spikes_heldout' in h5file:
            valid_data_heldout = get_key('eval_spikes_heldout')
        else:
            # No ground truth available: zeros with the train held-out neuron count.
            valid_data_heldout = np.zeros(
                (valid_data.shape[0], valid_data.shape[1], train_data_heldout.shape[2]),
                dtype=np.float32,
            )
        if 'eval_spikes_heldin_forward' in h5file:
            valid_data_fp = get_key('eval_spikes_heldin_forward')
            valid_data_heldout_fp = get_key('eval_spikes_heldout_forward')
            valid_data_all_fp = np.concatenate([valid_data_fp, valid_data_heldout_fp], -1)
        else:
            # Zeros with the train forward horizon and held-in + held-out neurons.
            valid_data_all_fp = np.zeros(
                (valid_data.shape[0], train_data_fp.shape[1],
                 valid_data.shape[2] + valid_data_heldout.shape[2]),
                dtype=np.float32,
            )

    def to_long(array):
        # Spike counts as int64 tensors.
        return torch.from_numpy(array).long()

    # For mc_rtt the train shapes were reported as [810, 120, 98] / [810, 120, 32] /
    # [810, 40, 130]; validation has its own trial count (TODO confirm per dataset).
    train_dataset = SpikesDataset(
        to_long(train_data),            # (trials, time, heldin)
        to_long(train_data_heldout),    # (trials, time, heldout)
        to_long(train_data_all_fp),     # (trials, forward_time, heldin + heldout)
    )
    val_dataset = SpikesDataset(
        to_long(valid_data),
        to_long(valid_data_heldout),
        to_long(valid_data_all_fp),
    )
    return train_dataset, val_dataset

1.4 掩码mask操作

dataset.py

# Label written at positions that were NOT masked, so the loss can skip them.
# (-100 is also the default ``ignore_index`` of torch.nn.CrossEntropyLoss;
#  TODO confirm the downstream loss ignores this value.)
UNMASKED_LABEL = -100


def mask_batch(batch, heldout_spikes, forward_spikes, *,
               mask_ratio=0.31254, mask_random_ratio=0.876, mask_token_ratio=0.527):
    """Randomly mask held-in spike counts and append held-out / forward targets.

    Each held-in entry is selected with probability ``mask_ratio``. A selected
    entry is handled in one of three ways:
      * With probability ``mask_token_ratio`` it is set to 0 (the mask token).
      * Otherwise, with probability ``mask_random_ratio``, it gets a random count.
      * Otherwise it is left unchanged. With the defaults this is
        (1 - 0.527) * (1 - 0.876), about 5.9% of selected entries.

    Unselected positions get ``UNMASKED_LABEL`` in ``labels``. Held-out neurons
    (appended on the last dim) and forward bins (appended on dim 1) are always
    fully masked. They are zeros in the returned input and their true values
    are used as labels.

    Args:
        batch: Held-in spike counts. Dim 1 is extended by ``forward_spikes`` and
            the last dim by ``heldout_spikes`` (assumed (batch, time, neuron)).
        heldout_spikes: Held-out counts; all dims but the last must match ``batch``.
        forward_spikes: Forward-prediction counts; the last dim must equal
            held-in + held-out neurons.
        mask_ratio: Per-entry probability of selecting an entry for masking.
        mask_random_ratio: Probability of randomizing a selected, non-zeroed entry.
        mask_token_ratio: Probability of zeroing a selected entry.

    Returns:
        ``(masked_batch, labels)``, both on ``batch.device``.
    """
    batch = batch.clone()  # never modify the caller's tensor
    device = batch.device
    labels = batch.clone()

    def draw(prob):
        # Boolean Bernoulli sample shaped like `labels`, on the batch's device.
        probs = torch.full(labels.shape, prob, dtype=torch.float, device=device)
        return torch.bernoulli(probs).bool()

    mask = draw(mask_ratio)
    labels[~mask] = UNMASKED_LABEL  # no loss on unmasked positions

    # Mostly replace selected entries with the mask token (0) ...
    indices_replaced = draw(mask_token_ratio) & mask
    batch[indices_replaced] = 0

    # ... sometimes with a random count, so the model must rely on context.
    indices_random = draw(mask_random_ratio) & mask & ~indices_replaced
    # randint's bound is exclusive and must be >= 1. An all-zero batch would
    # otherwise raise (torch.randint(0, ...) is an empty range).
    high = max(int(batch.max()), 1)
    random_spikes = torch.randint(high, labels.shape, dtype=torch.long, device=device)
    batch[indices_random] = random_spikes[indices_random]

    # Held-out neurons are fully masked (zeros in, true counts as targets).
    batch = torch.cat([batch, torch.zeros_like(heldout_spikes, device=device)], -1)
    labels = torch.cat([labels, heldout_spikes.to(device)], -1)
    # Forward-prediction bins are fully masked along dim 1.
    batch = torch.cat([batch, torch.zeros_like(forward_spikes, device=device)], 1)
    labels = torch.cat([labels, forward_spikes.to(device)], 1)
    return batch, labels

下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906187

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值