Simple-STNDT使用Transformer进行Spike信号的表征学习(二)模型结构

1.1 位置编码

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
import math
# Sentinel value in ``mask_labels``: entries equal to this are excluded from
# the reconstruction loss in SpatioTemporalNDT.forward.
UNMASKED_LABEL = -100

class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position table to the input, then apply dropout.

    Args:
        trial_length: number of positions covered by the table.
        d_model: feature size per position; odd sizes are supported.
        dropout: dropout probability applied after the addition.

    The table is registered as the buffer ``pe`` with shape
    ``(trial_length, 1, d_model)``, so in ``forward`` dimension 0 of the input
    indexes positions and the table broadcasts over dimension 1.
    """

    def __init__(self, trial_length, d_model, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Frequencies 10000^(-2i / d_model), one per even feature column.
        even_cols = torch.arange(0, d_model, 2).float()
        freqs = torch.exp(even_cols * (-math.log(10000.0) / d_model))
        steps = torch.arange(0, trial_length, dtype=torch.float).unsqueeze(1)
        angles = steps * freqs  # (trial_length, ceil(d_model / 2))
        table = torch.zeros(trial_length, d_model)
        table[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine column.
        n_cos = d_model // 2
        table[:, 1::2] = torch.cos(angles[:, :n_cos])
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``."""
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len])

1.2 EncoderLayer

model.py
核心编码层:在时间自注意力的基础上加入了空间注意力编码。

class STNTransformerEncoderLayer(TransformerEncoderLayer):
    r"""Encoder layer combining temporal self-attention with spatial attention.

    ``forward`` runs three stages:

    1. Temporal: a pre-norm self-attention + feed-forward block built from the
       parent ``TransformerEncoderLayer`` submodules (``norm1``, ``self_attn``,
       ``norm2``, ``linear1``/``linear2`` and the parent dropouts).
    2. Spatial: ``spatial_self_attn`` over ``spatial_src``. Only its attention
       weights are used; the attended output is discarded.
    3. Mixture: the spatial attention weights are multiplied (``torch.bmm``)
       into the temporal output. A residual connection and a second
       feed-forward block (the ``ts_*`` submodules) follow.

    Args:
        d_model: feature size (last dimension) of the temporal input.
        d_model_s: embedding size of the spatial self-attention, i.e. the last
            dimension of ``spatial_src``.
        num_heads: number of heads for both attention modules.
        dim_feedforward: hidden size of both feed-forward blocks.
        dropout: probability used by every dropout layer.
        activation: passed to the parent; ``self.activation`` is applied in
            both feed-forward blocks.

    NOTE(review): neither attention module sets ``batch_first``, so PyTorch's
    default sequence-first layout ``(seq, batch, feature)`` applies to both
    inputs.
    """
    def __init__(self, d_model, d_model_s, num_heads=2,  dim_feedforward=128, dropout=0.1, 
                 activation='relu'):
        super().__init__(
            d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation
        )
        self.num_heads = num_heads
        self.num_input = d_model
        self.d_model_s = d_model_s      # d_model_s: number of time steps (e.g. 160); embedding size of the spatial self-attention
        self.spatial_self_attn = MultiheadAttention(embed_dim=d_model_s, num_heads=num_heads)
        self.spatial_norm1 = nn.LayerNorm(d_model_s)
        # Norms / feed-forward / dropouts for the spatio-temporal mixture stage.
        self.ts_norm1 = nn.LayerNorm(d_model)
        self.ts_norm2 = nn.LayerNorm(d_model)
        self.ts_linear1 = nn.Linear(d_model, dim_feedforward)
        self.ts_linear2 = nn.Linear(dim_feedforward, d_model)
        self.ts_dropout1 = nn.Dropout(dropout)
        self.ts_dropout2 = nn.Dropout(dropout)
        self.ts_dropout3 = nn.Dropout(dropout)
    
    def attend(self, src, context_mask=None, **kwargs):
        r"""Temporal self-attention over ``src``.

        Args:
            src: temporal input (used as query, key and value).
            context_mask: passed as ``attn_mask`` to ``self.self_attn``.
            **kwargs: forwarded to ``self.self_attn`` (e.g. ``key_padding_mask``).

        Returns:
            ``(attn_output, attn_weights, zero)``. ``zero`` is a float scalar
            tensor on ``src.device``; callers in this file discard it.
        """
        attn_res = self.self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        return (*attn_res, torch.tensor(0, device=src.device, dtype=torch.float))
    def spatial_attend(self, src, context_mask=None, **kwargs):
        r"""
        Spatial self-attention over ``src`` using ``self.spatial_self_attn``.

        Args:
            src: spatial view of the neural population input (used as query,
                key and value).
            context_mask: passed as ``attn_mask`` (spatial context mask).
            **kwargs: forwarded to ``self.spatial_self_attn``.

        Returns:
            ``(attn_output, attn_weights, zero)``. ``zero`` is a float scalar
            tensor on ``src.device``. ``forward`` only uses ``attn_weights``.
        """
        attn_res = self.spatial_self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        return (*attn_res, torch.tensor(0, device=src.device, dtype=torch.float))
    
    def forward(self, src, spatial_src, src_mask=None, spatial_src_mask=None, src_key_padding_mask=None):
        r"""Apply temporal attention, spatial attention, and their mixture.

        Args:
            src: temporal input; under the default sequence-first layout this
                is ``(seq_len, batch, d_model)``.
            spatial_src: spatial input, ``(spatial_len, batch, d_model_s)``
                under the same layout. The ``torch.bmm`` below requires
                ``spatial_len == d_model``.
            src_mask: attention mask for the temporal self-attention.
            spatial_src_mask: attention mask for the spatial self-attention.
            src_key_padding_mask: key padding mask, applied to the temporal
                self-attention only.

        Returns:
            Tensor with the same shape as ``src``.
        """
        # temporal: pre-norm self-attention + residual, then pre-norm FFN + residual
        residual = src
        src = self.norm1(src)
        t_out, t_weights, _ = self.attend(src, context_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = residual + self.dropout1(t_out)
        residual = src
        src = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = residual + self.dropout2(src2)

        # spatial: only the attention weights are kept (spatial_out is unused).
        # NOTE(review): the weights are assumed to be 3-D (batch, L, L), i.e.
        # averaged over heads by MultiheadAttention's defaults — confirm for
        # the PyTorch version in use.
        spatial_src = self.spatial_norm1(spatial_src)
        spatial_out, spatial_weights, _ = self.spatial_attend(spatial_src,context_mask=spatial_src_mask, key_padding_mask=None)

        # spatio-temporal feature mixture: src (S, B, E) -> (B, E, S), mixed over
        # E by the spatial weights, then permuted back to (S, B, E).
        ts_residual = src
        src = self.ts_norm1(src)
        ts_out = torch.bmm(spatial_weights, src.permute(1, 2, 0)).permute(2, 0, 1)
        ts_out = ts_residual + self.ts_dropout1(ts_out)
        ts_residual = ts_out
        ts_out = self.ts_norm2(ts_out)
        ts_out = self.ts_linear2(self.ts_dropout2(self.activation(self.ts_linear1(ts_out))))
        ts_out = ts_residual + self.ts_dropout3(ts_out)
        
        return ts_out

1.3 Encoder

model.py

class STNTransformerEncoder(TransformerEncoder):
    """Stack of spatio-temporal encoder layers with an optional final norm.

    The first layer receives the caller-supplied ``spatial_src``. Every later
    layer receives the previous layer's output, permuted with
    ``permute(2, 1, 0)``, as its spatial input.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__(encoder_layer, num_layers, norm)

    def forward(self, src, spatial_src, mask=None, spatial_mask=None):
        """Run all layers in order, then apply ``self.norm`` if it is set.

        Args:
            src: temporal input passed to the first layer.
            spatial_src: spatial input for the first layer only.
            mask: forwarded to every layer as ``src_mask``.
            spatial_mask: forwarded to every layer as ``spatial_src_mask``.
        """
        output = src
        spatial_input = spatial_src
        for layer in self.layers:
            output = layer(output, spatial_input,
                           src_mask=mask, spatial_src_mask=spatial_mask)
            # Later layers derive their spatial view from the running output.
            spatial_input = output.permute(2, 1, 0)
        return output if self.norm is None else self.norm(output)

1.4 STNDT

model.py

class SpatioTemporalNDT(nn.Module):
    r"""Spatio-temporal NDT that reconstructs masked spike entries.

    The input is viewed two ways: temporally (time steps as tokens, neurons as
    features) and spatially (neurons as tokens, time steps as features).
    Each view gets a sinusoidal positional encoding. Both views then pass
    through the ``STNTransformerEncoder`` stack. The output is decoded by a
    linear layer and scored with ``nn.PoissonNLLLoss`` on masked entries only.

    Args:
        trial_length: number of time steps per trial.
        num_neurons: number of neurons (feature size of the temporal view).
        temperature, c_lambda: stored as ``self.temperature`` and
            ``self.contrast_lambda``; not used by ``forward``.
        dropout: dropout applied to the encoder output before decoding.
        pos_drop: dropout inside both positional encoders.
        enc_layers: number of encoder layers.
        log_rates: forwarded as ``log_input`` to ``nn.PoissonNLLLoss``.
        enc_heads, enc_dff, enc_drop: heads, feed-forward size and dropout of
            each encoder layer.
        context_forward: number of future time steps each query step may
            attend to (default 13, the previously hard-coded value).
        context_backward: number of past time steps each query step may
            attend to (default 79, the previously hard-coded value).
    """

    def __init__(self, trial_length, num_neurons, temperature=0.1, c_lambda=0.3,
                 dropout=0.2, pos_drop=0.1, enc_layers=1, log_rates=True,
                 enc_heads=2, enc_dff=128, enc_drop=0.1,
                 context_forward=13, context_backward=79,
                 ) -> None:
        super().__init__()

        # Cached additive temporal attention mask, built lazily by _get_mask.
        self.src_mask = None
        # Temporal attention window (in time steps) used by _get_mask.
        self.context_forward = context_forward
        self.context_backward = context_backward
        self.num_input = num_neurons
        self.num_spatial_input = trial_length
        self.embedder = nn.Identity()
        self.spatial_embedder = nn.Identity()
        self.scale = math.sqrt(num_neurons)
        self.spatial_scale = math.sqrt(trial_length)
        self.src_pos_encoder = PositionalEncoding(trial_length, num_neurons, pos_drop)
        self.spatial_pos_encoder = PositionalEncoding(num_neurons, trial_length, pos_drop)

        # Not used by forward (kept for the contrastive objective's interface).
        self.projector = nn.Identity()
        self.spatial_projector = nn.Identity()
        self.n_views = 2
        self.temperature = temperature
        self.contrast_lambda = c_lambda
        self.cel = nn.CrossEntropyLoss(reduction='none')
        self.mse = nn.MSELoss(reduction='mean')

        encoder_layer = STNTransformerEncoderLayer(
            d_model=self.num_input,
            d_model_s=self.num_spatial_input,
            num_heads=enc_heads,
            dim_feedforward=enc_dff,
            dropout=enc_drop
        )
        self.transformer_encoder = STNTransformerEncoder(encoder_layer, enc_layers, nn.LayerNorm(self.num_input))

        self.rate_dropout = nn.Dropout(dropout)
        self.src_decoder = nn.Linear(num_neurons, self.num_input)
        self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=log_rates)

    def _get_mask(self, src, do_convert=True):
        r"""Return the additive temporal attention mask for ``src``.

        ``src.size(0)`` is taken as the sequence length. Query step ``i`` may
        attend to key steps ``j`` with
        ``i - context_backward <= j <= i + context_forward``. Allowed entries
        are ``0.0`` and blocked entries are ``-inf``.

        The mask is cached on ``self.src_mask``. It is rebuilt when its last two
        dimensions no longer match the sequence length, and moved to
        ``src.device`` when needed. Previously the first mask was reused
        forever and always lived on CPU. ``do_convert`` is accepted for
        backward compatibility and is ignored.
        """
        size = src.size(0)
        cached = self.src_mask
        if cached is not None and tuple(cached.shape[-2:]) == (size, size):
            if cached.device != src.device:
                cached = cached.to(src.device)
                self.src_mask = cached
            return cached
        ones = torch.ones(size, size, device=src.device)
        # forward_ok[i, j]: j <= i + context_forward
        forward_ok = (torch.triu(ones, diagonal=-self.context_forward) == 1).transpose(0, 1)
        # backward_ok[i, j]: j >= i - context_backward
        backward_ok = torch.triu(ones, diagonal=-self.context_backward) == 1
        mask = (forward_ok & backward_ok).float()
        self.src_mask = binary_mask_to_attn_mask(mask)
        return self.src_mask

    def forward(self, src: torch.Tensor, mask_labels: torch.Tensor):
        r"""Encode spikes, decode rates and score the masked entries.

        Args:
            src: spike tensor. The permutes below treat it as
                ``(batch, time, neurons)``. For the positional-encoding tables
                built in ``__init__`` to broadcast, the time axis must match
                ``trial_length`` and the neuron axis ``num_neurons``.
            mask_labels: targets with the same shape as the decoded rates.
                Entries equal to ``UNMASKED_LABEL`` are excluded from the loss.

        Returns:
            ``(masked_decoder_loss, decoder_rates)``: the mean Poisson NLL over
            masked entries, and the decoder output permuted back to the input
            layout.
        """
        src = src.float()
        # Spatial view: neurons become the token axis.
        spatial_src = src.permute(2, 0, 1)
        spatial_src = self.spatial_embedder(spatial_src) * self.spatial_scale
        spatial_src = self.spatial_pos_encoder(spatial_src)
        # Temporal view: time becomes the token axis.
        src = src.permute(1, 0, 2)
        src = self.embedder(src) * self.scale
        src = self.src_pos_encoder(src)
        src_mask = self._get_mask(src)
        spatial_src_mask = None
        encoder_output = self.transformer_encoder(src, spatial_src, src_mask, spatial_src_mask)
        encoder_output = self.rate_dropout(encoder_output)
        decoder_output = self.src_decoder(encoder_output)

        decoder_rates = decoder_output.permute(1, 0, 2)
        decoder_loss = self.classifier(decoder_rates, mask_labels)
        masked_decoder_loss = decoder_loss[mask_labels != UNMASKED_LABEL]
        # NOTE(review): if no entry is masked this is the mean of an empty
        # tensor, which is NaN.
        masked_decoder_loss = masked_decoder_loss.mean()

        return masked_decoder_loss, decoder_rates


def binary_mask_to_attn_mask(x):
    """Convert a 0/1 mask into an additive attention mask.

    Entries equal to 0 become ``-inf`` (blocked) and entries equal to 1 become
    ``0.0`` (allowed). Any other value keeps its float value. Returns a new
    float tensor; ``x`` itself is not modified.
    """
    blocked = x == 0
    allowed = x == 1
    additive = x.float().masked_fill(blocked, float('-inf'))
    return additive.masked_fill(allowed, 0.0)

下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906391

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值