数据、评估标准见NLB2021
https://neurallatents.github.io/
以下代码依据
https://github.com/trungle93/STNDT
原代码使用 Ray 和 Config 文件进行参数搜索,依赖的库较多,数据流不够清晰,代码也比较冗杂;这里抽丝剥茧,只把其中最核心的部分提取出来。
1.数据处理部分
1.1 下载数据集
需要依赖 pip install dandi
downald.py
root = "D:/NeuralLatent/"


def downald_data():
    """Download each dandiset listed below from the DANDI archive into ``root``.

    ``dandi`` is imported lazily so that the dependency is only needed when
    the download is actually requested.
    """
    from dandi.download import download

    # Dandiset ids, fetched one after another in exactly this order.
    dandiset_ids = (
        "000128",
        "000138",
        "000139",
        "000140",
        "000129",
        "000127",
        "000130",
    )
    for dandiset_id in dandiset_ids:
        download(f"https://dandiarchive.org/dandiset/{dandiset_id}", root)
1.2 数据集预处理
需要依赖官方工具包pip install nlb_tools
主要是加载锋电位(spike)序列数据,并将其重采样为 5 ms 宽的时间窗(bin)
preprocess.py
## 以下为参数示例
# data_path = root + "/000129/sub-Indy/"
# dataset_name = "mc_rtt"
## 注意 "./data" 必须提前创建好
from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, combine_h5
def preprocess(data_path, dataset_name=None):
    """Resample one NWB dataset to 5 ms bins and write train/val/full h5 files.

    Three files are written under ``./data`` (the directory must already
    exist): ``{dataset_name}_train.h5``, ``{dataset_name}_val.h5`` and their
    combination ``{dataset_name}_full.h5``.

    Args:
        data_path: Location of the NWB data, passed straight to ``NWBDataset``
            (e.g. ``root + "/000129/sub-Indy/"``).
        dataset_name: NLB dataset name (e.g. ``"mc_rtt"``); it is forwarded to
            the nlb_tools helpers and used as the output file prefix.
    """
    # The file-level import only brings in the make_tensors helpers, so the
    # dataset class is imported here to keep this function self-contained.
    from nlb_tools.nwb_interface import NWBDataset

    # The original referenced an undefined name ``datapath``.
    dataset = NWBDataset(data_path)

    bin_width = 5
    dataset.resample(bin_width)

    train_path = f"./data/{dataset_name}_train.h5"
    val_path = f"./data/{dataset_name}_val.h5"
    full_path = f"./data/{dataset_name}_full.h5"

    make_train_input_tensors(
        dataset,
        dataset_name=dataset_name,
        trial_split="train",
        include_behavior=True,
        include_forward_pred=True,
        save_file=True,
        save_path=train_path,
    )
    make_eval_input_tensors(
        dataset,
        dataset_name=dataset_name,
        trial_split="val",
        save_file=True,
        save_path=val_path,
    )
    # Merge both splits into the single file that the later steps read.
    combine_h5([train_path, val_path], save_path=full_path)
## './data/mc_rtt_full.h5' 将成为后续的主要分析数据
1.3 划分train-val并创建Dataset对象
读取 './data/mc_rtt_full.h5' 中的数据,划分为训练集和验证集,并分别创建 Dataset 对象
dataset.py
import h5py
import numpy as np
import torch
from torch.utils import data
# data_path = "./data/mc_rtt_full.h5"
class SpikesDataset(data.Dataset):
    """Dataset that serves three aligned spike tensors sample by sample.

    The three tensors are stored as given and are all indexed with the same
    index, so they are expected to share their first dimension.
    """

    def __init__(self, spikes, heldout_spikes, forward_spikes) -> None:
        self.spikes = spikes
        self.heldout_spikes = heldout_spikes
        self.forward_spikes = forward_spikes

    def __len__(self):
        # The dataset length is the size of the first dimension of ``spikes``.
        return self.spikes.shape[0]

    def __getitem__(self, index):
        """Return ``(spikes, heldout_spikes, forward_spikes)`` at ``index``."""
        sources = (self.spikes, self.heldout_spikes, self.forward_spikes)
        return tuple(tensor[index] for tensor in sources)
def make_datasets(data_path):
    """Build the train and validation ``SpikesDataset`` objects from an h5 file.

    Every array is read as float32 and converted to a long tensor. The
    forward-prediction tensor of each split is the held-in and held-out
    forward arrays concatenated along the last axis. Evaluation arrays that
    are missing from the file are replaced by all-zero placeholders.

    Args:
        data_path: Path of the combined h5 file, e.g. "./data/mc_rtt_full.h5".

    Returns:
        A ``(train_dataset, val_dataset)`` tuple of ``SpikesDataset``.

    Raises:
        ValueError: If the file has no 'eval_spikes_heldin' entry, i.e. it is
            not in the NLB layout this loader understands.
        KeyError: If one of the required arrays is missing.
    """
    with h5py.File(data_path, 'r') as h5file:
        # ``[()]`` reads each dataset fully into memory, so nothing below
        # needs the file to stay open.
        h5dict = {key: h5file[key][()] for key in h5file.keys()}

    if 'eval_spikes_heldin' not in h5dict:
        # Fail with a clear message instead of an unbound-name error later on.
        raise ValueError(
            f"{data_path!r} does not contain 'eval_spikes_heldin'; "
            "expected an NLB-format h5 file"
        )

    def get_key(key):
        return h5dict[key].astype(np.float32)

    train_data = get_key('train_spikes_heldin')
    train_data_fp = get_key('train_spikes_heldin_forward')
    train_data_heldout_fp = get_key('train_spikes_heldout_forward')
    train_data_all_fp = np.concatenate([train_data_fp, train_data_heldout_fp], -1)
    valid_data = get_key('eval_spikes_heldin')
    train_data_heldout = get_key('train_spikes_heldout')

    if 'eval_spikes_heldout' in h5dict:
        valid_data_heldout = get_key('eval_spikes_heldout')
    else:
        # Placeholder: same leading dims as the eval held-in data, with the
        # last dim taken from the training held-out data.
        valid_data_heldout = np.zeros(
            (valid_data.shape[0], valid_data.shape[1], train_data_heldout.shape[2]),
            dtype=np.float32,
        )

    if 'eval_spikes_heldin_forward' in h5dict:
        valid_data_fp = get_key('eval_spikes_heldin_forward')
        valid_data_heldout_fp = get_key('eval_spikes_heldout_forward')
        valid_data_all_fp = np.concatenate([valid_data_fp, valid_data_heldout_fp], -1)
    else:
        # Placeholder: forward length copied from the training forward data,
        # last dim covering held-in plus held-out.
        valid_data_all_fp = np.zeros(
            (
                valid_data.shape[0],
                train_data_fp.shape[1],
                valid_data.shape[2] + valid_data_heldout.shape[2],
            ),
            dtype=np.float32,
        )

    # Shape notes are the original author's example values (presumably for
    # mc_rtt) - TODO confirm against the actual file.
    train_dataset = SpikesDataset(
        torch.tensor(train_data).long(),  # e.g. [810, 120, 98]
        torch.tensor(train_data_heldout).long(),  # e.g. [810, 120, 32]
        torch.tensor(train_data_all_fp).long(),  # e.g. [810, 40, 130]
    )
    val_dataset = SpikesDataset(
        torch.tensor(valid_data).long(),
        torch.tensor(valid_data_heldout).long(),
        torch.tensor(valid_data_all_fp).long(),
    )
    return train_dataset, val_dataset
1.4 掩码mask操作
dataset.py
# Sentinel written into the labels wherever there is no ground truth to
# predict. Spike counts are non-negative, so -100 cannot collide with a count.
UNMASKED_LABEL = -100


def mask_batch(
    batch,
    heldout_spikes,
    forward_spikes,
    *,
    mask_ratio=0.31254,
    mask_random_ratio=0.876,
    mask_token_ratio=0.527,
):
    """Mask a batch of held-in spikes and build the matching labels.

    Each position of ``batch`` is selected for masking with probability
    ``mask_ratio``. A selected position is zeroed with probability
    ``mask_token_ratio``; a remaining selected position is overwritten with a
    random count with probability ``mask_random_ratio``; the rest keep their
    value. Labels hold the original value at selected positions and
    ``UNMASKED_LABEL`` elsewhere.

    The held-out block is then appended along the last dim and the forward
    block along dim 1: the inputs get zeros there, the labels get the true
    ``heldout_spikes`` / ``forward_spikes``.

    Args:
        batch: Integer (long) tensor of held-in spike counts. It is cloned,
            never modified in place.
        heldout_spikes: Tensor appended along the last dim.
        forward_spikes: Tensor appended along dim 1.
        mask_ratio: Probability that a position is selected for masking.
        mask_random_ratio: Probability that a selected, non-zeroed position
            is replaced by a random count.
        mask_token_ratio: Probability that a selected position is zeroed.

    Returns:
        ``(masked_batch, labels)``, both on ``batch.device``.
    """
    batch = batch.clone()  # never corrupt the caller's (in-memory) data
    labels = batch.clone()
    device = batch.device

    def _draw(prob):
        # One Bernoulli draw per position, created on the batch's device.
        probs = torch.full(labels.shape, float(prob), device=device)
        return torch.bernoulli(probs).bool()

    mask = _draw(mask_ratio)
    labels[~mask] = UNMASKED_LABEL  # no ground truth here - used to mask the loss

    # Zero acts as the MASK token.
    indices_replaced = _draw(mask_token_ratio) & mask
    batch[indices_replaced] = 0

    # Random replacement makes the model learn embeddings for non-mask tokens
    # and rely on context.
    indices_random = _draw(mask_random_ratio) & mask & ~indices_replaced
    # Upper bound is taken after zeroing, as before. torch.randint rejects an
    # empty range, so skip the replacement when the bound is not positive.
    max_count = int(batch.max()) if batch.numel() > 0 else 0
    if max_count > 0:
        random_spikes = torch.randint(
            max_count, labels.shape, dtype=torch.long, device=device
        )
        batch[indices_random] = random_spikes[indices_random]

    # Held-out and forward positions are fully masked in the input; the zeros
    # are created on the batch's device so the concatenation cannot mix devices.
    batch = torch.cat([batch, torch.zeros_like(heldout_spikes, device=device)], -1)
    labels = torch.cat([labels, heldout_spikes.to(device)], -1)
    batch = torch.cat([batch, torch.zeros_like(forward_spikes, device=device)], 1)
    labels = torch.cat([labels, forward_spikes.to(device)], 1)
    return batch, labels
下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906187