Simple-STNDT：使用 Transformer 进行 Spike 信号的表征学习（三）：训练与评估

1. 评估指标

import numpy as np
from scipy.special import gammaln
import torch

def neg_log_likelihood(rates, spikes, zero_warning=True):
    """Calculate the Poisson negative log-likelihood of ``spikes`` under ``rates``.

    formula: -log(e^(-r) / n! * r^n)
           = r - n*log(r) + log(n!)

    Entries where ``spikes`` is NaN are dropped from both arrays before the
    likelihood is computed. Zero rates are replaced by 1e-9 so that ``log(r)``
    stays finite; this is done on a new array, so the caller's ``rates`` is
    never modified.

    Parameters
    ----------
    rates : np.ndarray
        Non-negative rate predictions, same shape as ``spikes``.
    spikes : np.ndarray
        True spike counts; NaN marks entries to ignore.
    zero_warning : bool, optional
        Whether to emit a warning when zero-rate predictions are found
        and replaced by 1e-9.

    Returns
    -------
    float
        Total negative log-likelihood of the data.

    Raises
    ------
    AssertionError
        If the shapes differ, or if any (unmasked) rate is NaN or negative.
    """
    import warnings

    assert spikes.shape == rates.shape, \
        f"neg_log_likelihood: Rates and spikes should be of the same shape. spikes: {spikes.shape}, rates: {rates.shape}"

    if np.any(np.isnan(spikes)):
        # Drop positions with missing spike counts from both arrays.
        mask = np.isnan(spikes)
        rates = rates[~mask]
        spikes = spikes[~mask]

    assert not np.any(np.isnan(rates)), \
        "neg_log_likelihood: NaN rate predictions found"

    assert np.all(rates >= 0), \
        "neg_log_likelihood: Negative rate predictions found"

    zero_rates = rates == 0
    if np.any(zero_rates):
        if zero_warning:
            warnings.warn("neg_log_likelihood: Zero rate predictions found. Replacing zeros with 1e-9")
        # np.where builds a new array: the caller's input is not mutated, and
        # integer-typed rates are promoted so 1e-9 is not truncated back to 0.
        rates = np.where(zero_rates, 1e-9, rates)

    result = rates - spikes * np.log(rates) + gammaln(spikes + 1.0)
    return np.sum(result)

def bits_per_spike(rates, spikes):
    """Score rate predictions in bits per spike.

    The score is the log-likelihood improvement (in bits) of ``rates`` over a
    null model, normalised by the total spike count. The null model predicts,
    for every neuron, that neuron's mean spike count over the first two axes.

    Parameters
    ----------
    rates : np.ndarray
        3d numpy array containing rate predictions
    spikes : np.ndarray
        3d numpy array containing true spike counts (NaN entries are ignored)

    Returns
    -------
    float
        Bits per spike of rate predictions
    """
    model_nll = neg_log_likelihood(rates, spikes)

    # Per-neuron mean over axes 0 and 1, repeated to the full shape of `spikes`.
    n_outer, n_inner = spikes.shape[0], spikes.shape[1]
    per_neuron_mean = np.nanmean(spikes, axis=(0, 1), keepdims=True)
    baseline_rates = np.tile(per_neuron_mean, (n_outer, n_inner, 1))
    baseline_nll = neg_log_likelihood(baseline_rates, spikes, zero_warning=False)

    total_spikes = np.nansum(spikes)
    return (baseline_nll - model_nll) / total_spikes / np.log(2)

2. 训练准备

from torch.utils.data import DataLoader
from dataset import make_datasets, mask_batch
from model import SpatioTemporalNDT
from metric import bits_per_spike
import torch
from torch.optim import AdamW
from torch import nn

# --- Hyperparameters -------------------------------------------------------
batch_size = 16
lr = 1e-3
num_epochs = 50
log_interval = 20  # print the running train loss every `log_interval` batches
# Model input size. These match the example shapes printed in debug_test
# (120 + 40 time bins, 98 + 32 neurons) — TODO confirm against the dataset.
trial_length = 160
neuron_num = 130

# --- Data (validation uses a larger batch since no gradients are kept) -----
train_dataset, val_dataset = make_datasets()
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size*2, shuffle=False)

# --- Model & optimizer -----------------------------------------------------
model = SpatioTemporalNDT(trial_length, neuron_num)
optim = AdamW(model.parameters(), lr=lr)

3. debug测试

def param_num(model):
    """Return how many scalar parameters of ``model`` are trainable (``requires_grad``)."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total

def debug_test():
    """Smoke-test the data pipeline, masking, one forward pass and validation.

    Takes one batch from the module-level ``train_dataloader``, prints the
    tensor shapes at each stage, runs the module-level ``model`` once and
    finally calls ``valid`` on ``val_dataloader``. The trailing comments are
    example output from the author's run.
    """
    spikes, heldout_spikes, forward_spikes = next(iter(train_dataloader))
    print(spikes.shape)             # [16, 120, 98]
    print(heldout_spikes.shape)     # [16, 120, 32]
    print(forward_spikes.shape)     # [16, 40, 130]
    masked_spikes, labels = mask_batch(spikes, heldout_spikes, forward_spikes)
    print(masked_spikes.shape)      # [16, 160, 130]
    print(labels.shape)             # [16, 160, 130]

    print(param_num(model))         # 256886
    # Call the module itself (not .forward) so nn.Module hooks run, exactly
    # as they do in train()/valid().
    loss, decoder_rates = model(masked_spikes, labels)
    print(loss)                     # tensor(1.2356, grad_fn=<MeanBackward0>)
    print(decoder_rates.shape)      # torch.Size([16, 160, 130])

    val_loss, val_score = valid(val_dataloader, model)
    print(val_loss)
    print(val_score)

4. train-val函数

def train(model, dataloader, val_dataloader, num_epochs, optim):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Each batch is masked with ``mask_batch``. The loss returned by the model
    is backpropagated, gradients are clipped to norm 200, and ``optim``
    takes a step. The running mean train loss is printed every
    ``log_interval`` batches (module-level setting). Validation loss and
    bits-per-spike score from ``valid`` are printed at the end of each epoch.
    """
    for epoch in range(num_epochs):
        # valid() puts the model into eval mode. Switch back every epoch so
        # mode-dependent layers (e.g. dropout, if the model has any) behave
        # as in training for all epochs, not just the first.
        model.train()
        print(f"--------- Epoch{epoch:2d} ----------")
        train_loss = []
        for i, (spikes, heldout_spikes, forward_spikes) in enumerate(dataloader):
            masked_spikes, labels = mask_batch(spikes, heldout_spikes, forward_spikes)
            loss, _ = model(masked_spikes, labels)
            optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 200.0)
            optim.step()
            # .item() already returns a detached Python float.
            train_loss.append(loss.item())
            if i % log_interval == 0:
                print(f"Train loss: {sum(train_loss)/len(train_loss)}")
        val_loss, val_score = valid(val_dataloader, model)
        print(f"val loss: {float(val_loss)}")
        print(f"val score: {float(val_score)}")
        print()


def valid(val_dataloader, model):
    """Evaluate ``model`` on ``val_dataloader`` with no random masking.

    For each batch, the observed ``spikes`` are zero-padded to the full input:
    zeros stand in for the held-out neurons (last axis) and the forward time
    bins (axis 1). Labels at those padded positions are set to -100
    ("unmasked_label"). The model's rate output for the held-out neurons
    over the observed time bins is exponentiated and scored with
    ``bits_per_spike``. Its output is presumably a log-rate; TODO confirm
    against the model.

    The caller's train/eval mode is restored on return.

    Returns
    -------
    tuple
        ``(loss_sum, score)``. ``loss_sum`` is the *sum* of the per-batch
        losses returned by the model, not their mean. ``score`` is the
        bits-per-spike of the held-out predictions, as a float.
    """
    was_training = model.training
    model.eval()
    pred_rates = []
    heldout_spikes_full = []
    loss_list = []
    try:
        with torch.no_grad():
            for spikes, heldout_spikes, forward_spikes in val_dataloader:
                no_mask_labels = spikes.clone()
                no_mask_labels = torch.cat([no_mask_labels, torch.zeros_like(heldout_spikes)], -1)
                no_mask_labels = torch.cat([no_mask_labels, torch.zeros_like(forward_spikes)], 1)
                no_mask_labels[:, :, -heldout_spikes.size(-1):] = -100 # unmasked_label
                no_mask_labels[:, -forward_spikes.size(1):,:] = -100 # unmasked_label
                spikes = torch.cat([spikes, torch.zeros_like(heldout_spikes)], -1)
                spikes = torch.cat([spikes, torch.zeros_like(forward_spikes)], 1)
                loss, batch_rates = model(spikes, no_mask_labels)
                pred_rates.append(batch_rates)
                heldout_spikes_full.append(heldout_spikes)
                loss_list.append(loss)
    finally:
        # Restore whatever mode the caller had (train() relies on this).
        model.train(was_training)

    heldout_spikes = torch.cat(heldout_spikes_full, dim=0)
    pred_rates = torch.cat(pred_rates, dim=0)
    n_time = heldout_spikes.size(1)
    n_heldout = heldout_spikes.size(-1)
    # Observed time bins only, held-out neurons only. Move to CPU before
    # .numpy() so this also works when the model runs on GPU.
    eval_rates_heldout = torch.exp(pred_rates[:, :n_time, -n_heldout:]).cpu().numpy().astype('float')
    eval_spikes_heldout = heldout_spikes.cpu().numpy().astype('float')
    # print(eval_rates_heldout.shape)     # (270, 120, 32)
    # print(eval_spikes_heldout.shape)    # (270, 120, 32)
    return sum(loss_list), float(bits_per_spike(eval_rates_heldout, eval_spikes_heldout))

最后，开始训练：

(stndt) D:\STNDT>python main.py
--------- Epoch 0 ----------
Train loss: 1.2486777305603027
Train loss: 0.5138219218878519
Train loss: 0.32351083744589876
val loss: 0.8636534214019775
val score: -0.39136893422272767

--------- Epoch 1 ----------
Train loss: 0.09501783549785614
Train loss: 0.09383604036910194
Train loss: 0.09296295773692248
val loss: 0.8206770420074463
val score: -0.09666108663240561

--------- Epoch 2 ----------
Train loss: 0.09622671455144882
Train loss: 0.09049306774423235
Train loss: 0.08994600358532696
val loss: 0.812911331653595
val score: -0.04202061410637105

--------- Epoch 3 ----------
Train loss: 0.09225568175315857
Train loss: 0.09019481816462108
Train loss: 0.08970968806888999
val loss: 0.8099062442779541
val score: -0.019777008609723395

--------- Epoch 4 ----------
Train loss: 0.08371596038341522
Train loss: 0.08918796905449458
Train loss: 0.0894875490083927
val loss: 0.8083348274230957
val score: -0.008896993842432857

--------- Epoch 5 ----------
Train loss: 0.09019782394170761
Train loss: 0.08884035441137496
Train loss: 0.08963883395602064
val loss: 0.8072853088378906
val score: -0.0026569800293788507

--------- Epoch 6 ----------
Train loss: 0.09667835384607315
Train loss: 0.09060979953833989
Train loss: 0.08956735653848183
val loss: 0.8064565658569336
val score: 0.0003163842262874261

--------- Epoch 7 ----------
Train loss: 0.08744495362043381
Train loss: 0.08888665941499528
Train loss: 0.08930287855427439
val loss: 0.8058080077171326
val score: 0.005321093845270125

--------- Epoch 8 ----------
Train loss: 0.10221674293279648
Train loss: 0.09078312771660942
Train loss: 0.08951869806865366
val loss: 0.8044026494026184
val score: 0.007113516568588765

--------- Epoch 9 ----------
Train loss: 0.09160886704921722
Train loss: 0.08984803798652831
Train loss: 0.0897282888976539
val loss: 0.803226113319397
val score: 0.01217366049067505

--------- Epoch10 ----------
Train loss: 0.09165512025356293
Train loss: 0.08854220310846965
Train loss: 0.08920388268988307
val loss: 0.8014105558395386
val score: 0.015657932109121083

--------- Epoch11 ----------
Train loss: 0.07934647053480148
Train loss: 0.08873837547642845
Train loss: 0.08900632345821799
val loss: 0.7992606163024902
val score: 0.017361369978752348

--------- Epoch12 ----------
Train loss: 0.08641393482685089
Train loss: 0.0893486404702777
Train loss: 0.08927923113834567
val loss: 0.7964036464691162
val score: 0.026846927269458674

--------- Epoch13 ----------
Train loss: 0.08859497308731079
Train loss: 0.08794442635206949
Train loss: 0.08938420000599652
val loss: 0.7929846048355103
val score: 0.033583528051411037

--------- Epoch14 ----------
Train loss: 0.08901184052228928
Train loss: 0.08875668652000882
Train loss: 0.08939630665430208
val loss: 0.7878748178482056
val score: 0.04465469491549107

--------- Epoch15 ----------
Train loss: 0.09487541764974594
Train loss: 0.08885077848320916
Train loss: 0.08909488651083737
val loss: 0.7851467728614807
val score: 0.046395409621300066

--------- Epoch16 ----------
Train loss: 0.0839885026216507
Train loss: 0.08959413000515529
Train loss: 0.08932711874566428
val loss: 0.7806612253189087
val score: 0.05012596379845563

--------- Epoch17 ----------
Train loss: 0.09544813632965088
Train loss: 0.08826960552306402
Train loss: 0.0890249778948179
val loss: 0.7787002325057983
val score: 0.05084565441331739

--------- Epoch18 ----------
Train loss: 0.09305278211832047
Train loss: 0.08740198683171045
Train loss: 0.08877205539767336
val loss: 0.7735776305198669
val score: 0.06808317309022775

--------- Epoch19 ----------
Train loss: 0.08946727961301804
Train loss: 0.0880857486100424
Train loss: 0.08832225821367125
val loss: 0.7722467184066772
val score: 0.0741929715804975

--------- Epoch20 ----------
Train loss: 0.09155283123254776
Train loss: 0.08762263329256148
Train loss: 0.08867140041618812
val loss: 0.774036705493927
val score: 0.06465988606612133

--------- Epoch21 ----------
Train loss: 0.08425123244524002
Train loss: 0.08848933414334342
Train loss: 0.08806171540806933
val loss: 0.7706096768379211
val score: 0.06233272968330965

--------- Epoch22 ----------
Train loss: 0.08672144263982773
Train loss: 0.08736556342669896
Train loss: 0.08800865782470238
val loss: 0.7690156698226929
val score: 0.07570956489538153

--------- Epoch23 ----------
Train loss: 0.09086063504219055
Train loss: 0.0895571896717662
Train loss: 0.08793148053128545
val loss: 0.7725724577903748
val score: 0.045295719065139656

--------- Epoch24 ----------
Train loss: 0.08895140141248703
Train loss: 0.08862598595165071
Train loss: 0.08853605389595032
val loss: 0.7674567103385925
val score: 0.07400126493414798

--------- Epoch25 ----------
Train loss: 0.08059882372617722
Train loss: 0.08788907066697166
Train loss: 0.08830737322568893
val loss: 0.7654385566711426
val score: 0.0783971076192251

--------- Epoch26 ----------
Train loss: 0.0904078260064125
Train loss: 0.08821353883970351
Train loss: 0.08813101125926506
val loss: 0.7648967504501343
val score: 0.06579874206738114

--------- Epoch27 ----------
Train loss: 0.0888797715306282
Train loss: 0.08781595457167853
Train loss: 0.08853465282335514
val loss: 0.765023946762085
val score: 0.06403537205845905

--------- Epoch28 ----------
Train loss: 0.0925334170460701
Train loss: 0.08814156835987455
Train loss: 0.08763645026015073
val loss: 0.7604566216468811
val score: 0.08386773786224676

--------- Epoch29 ----------
Train loss: 0.09102518111467361
Train loss: 0.08881006035066787
Train loss: 0.08800200536483671
val loss: 0.7639309167861938
val score: 0.05987701272594979

--------- Epoch30 ----------
Train loss: 0.08757702261209488
Train loss: 0.08790529945066997
Train loss: 0.08796896276677527
val loss: 0.7679344415664673
val score: 0.04645880716520806

--------- Epoch31 ----------
Train loss: 0.09563669562339783
Train loss: 0.08776313385793141
Train loss: 0.08768014010132813
val loss: 0.7532508969306946
val score: 0.09419951931221196

--------- Epoch32 ----------
Train loss: 0.08262639492750168
Train loss: 0.08920836945374806
Train loss: 0.08818964242208295
val loss: 0.7534663081169128
val score: 0.07980706821661744

--------- Epoch33 ----------
Train loss: 0.09010934829711914
Train loss: 0.08798151392312277
Train loss: 0.08814984251086305
val loss: 0.7573298215866089
val score: 0.0587445179781999

--------- Epoch34 ----------
Train loss: 0.09029105305671692
Train loss: 0.08793160106454577
Train loss: 0.087826013383342
val loss: 0.7541366219520569
val score: 0.04576204364697583

--------- Epoch35 ----------
Train loss: 0.09183177351951599
Train loss: 0.08813220936627615
Train loss: 0.08824214902592868
val loss: 0.7545167803764343
val score: 0.043795136749962035

--------- Epoch36 ----------
Train loss: 0.08738738298416138
Train loss: 0.08769806651842027
Train loss: 0.08801802520344897
val loss: 0.7475957870483398
val score: 0.07046052509968409

--------- Epoch37 ----------
Train loss: 0.08695636689662933
Train loss: 0.08928513243084862
Train loss: 0.08794533206922252
val loss: 0.7405006885528564
val score: 0.08250606459379788

--------- Epoch38 ----------
Train loss: 0.08741921186447144
Train loss: 0.08701477554582414
Train loss: 0.08772314776007722
val loss: 0.7421612739562988
val score: 0.07261544623998699

--------- Epoch39 ----------
Train loss: 0.08897516131401062
Train loss: 0.08884722207273756
Train loss: 0.08827457195375024
val loss: 0.7383261919021606
val score: 0.05041364027920663

--------- Epoch40 ----------
Train loss: 0.08877569437026978
Train loss: 0.08783218938679922
Train loss: 0.08838088319795888
val loss: 0.7311040759086609
val score: 0.05160266134263263

--------- Epoch41 ----------
Train loss: 0.0751330778002739
Train loss: 0.0872439131850288
Train loss: 0.08815818952351082
val loss: 0.723595917224884
val score: 0.08080731948303856

--------- Epoch42 ----------
Train loss: 0.09519665688276291
Train loss: 0.0866984451810519
Train loss: 0.08742059876279133
val loss: 0.7205336689949036
val score: 0.08327377202054256

--------- Epoch43 ----------
Train loss: 0.08966871351003647
Train loss: 0.08703825693754923
Train loss: 0.08704596176380064
val loss: 0.7158994078636169
val score: 0.05753987849499046

--------- Epoch44 ----------
Train loss: 0.08914705365896225
Train loss: 0.08722686128956932
Train loss: 0.08729714445951509
val loss: 0.7021420001983643
val score: 0.08133226152944593

--------- Epoch45 ----------
Train loss: 0.08485537022352219
Train loss: 0.08770599854843956
Train loss: 0.08782925693000235
val loss: 0.705651044845581
val score: 0.07325790592903407

--------- Epoch46 ----------
Train loss: 0.08972616493701935
Train loss: 0.088348921920572
Train loss: 0.08801035510330665
val loss: 0.6982176303863525
val score: 0.06009563284716213

--------- Epoch47 ----------
Train loss: 0.08506552129983902
Train loss: 0.08846274834303629
Train loss: 0.08772453265946085
val loss: 0.684754490852356
val score: 0.10142577749520322

--------- Epoch48 ----------
Train loss: 0.08494629710912704
Train loss: 0.08716638279812676
Train loss: 0.08738453831614518
val loss: 0.6825719475746155
val score: 0.087609587353269

--------- Epoch49 ----------
Train loss: 0.08093467354774475
Train loss: 0.08778195899157297
Train loss: 0.08736045422350489
val loss: 0.6823106408119202
val score: 0.06519610685639747

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值