import copy
from typing import Optional, Any, Tuple
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module
from torch.nn import MultiheadAttention
from torch.nn import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn import Dropout


class TIF(nn.Module):
    r"""Feature-fusion block: reweights ``x`` with channel attention derived from ``y``."""

    def __init__(self, in_dim):
        super(TIF, self).__init__()
        self.channel_in = in_dim
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_dim * 2, in_dim, kernel_size=1, stride=1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=3),
            nn.BatchNorm2d(in_dim),
            nn.ReLU(inplace=True),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.linear1 = nn.Conv2d(in_dim, in_dim // 6, 1, bias=False)
        self.linear2 = nn.Conv2d(in_dim // 6, in_dim, 1, bias=False)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.activation = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout()

    def forward(self, x, y):
        # Channel weights from y via a squeeze-and-excitation style bottleneck.
        ww = self.linear2(self.dropout(self.activation(self.linear1(self.avg_pool(self.conv2(y))))))
        # Fuse the two streams with a 1x1 projection, scaled by the channel weights.
        weight = self.conv1(torch.cat((x, y), 1)) * ww
        return x + self.gamma * weight * x
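

# A minimal usage sketch for TIF (the shapes are illustrative assumptions, not values
# fixed by this file): both inputs are feature maps of identical shape, and the spatial
# side must be at least 3 so the stride-3 convolution produces a non-empty output.
#
#   >>> tif = TIF(in_dim=192)
#   >>> x = torch.rand(2, 192, 7, 7)   # features to be reweighted
#   >>> y = torch.rand(2, 192, 7, 7)   # features providing the channel statistics
#   >>> out = tif(x, y)                # same shape as x: (2, 192, 7, 7)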


class Transformertime(Module):
    r"""Transformer variant with an auxiliary input stream ``srcc``: the encoder layers
    attend between ``srcc`` and ``src`` and fuse intermediate features with a
    :class:`TIF` block. ``forward`` returns the encoder memory and the decoder output.
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 384, dropout: float = 0.1,
                 activation: str = "relu", custom_encoder: Optional[Any] = None,
                 custom_decoder: Optional[Any] = None) -> None:
        super(Transformertime, self).__init__()
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            encoder_norm = nn.LayerNorm(d_model)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            decoder_norm = nn.LayerNorm(d_model)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def forward(self, src: Tensor, srcc: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None,
                tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        if src.size(1) != tgt.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")
        if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")
        memory = self.encoder(src, srcc, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, srcc, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return memory, output

    def generate_square_subsequent_mask(self, sz: int) -> Tensor:
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
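
    # Example: generate_square_subsequent_mask(3) yields
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # i.e. query position i may attend only to key positions j <= i.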

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers.
    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=192, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(49, 32, 192)
        >>> srcc = torch.rand(49, 32, 192)
        >>> out = transformer_encoder(src, srcc)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, srcc: Tensor, mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.
        Args:
            src: the sequence to the encoder (required).
            srcc: the auxiliary sequence passed to every encoder layer (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        output = src
        for mod in self.layers:
            output = mod(output, srcc, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output


class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers.
    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).
    Examples::
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> srcc = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, srcc, memory)
    """
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, srcc: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layers in turn.
        Args:
            tgt: the sequence to the decoder (required).
            srcc: the auxiliary sequence forwarded to every decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        output = tgt
        for mod in self.layers:
            output = mod(output, srcc, memory, tgt_mask=tgt_mask,
                         memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_key_padding_mask,
                         memory_key_padding_mask=memory_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output


class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of cross-stream attention, self-attention and TIF fusion.
    This encoder layer is adapted from the standard layer in the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application. Note that the TIF fusion requires
    d_model == dim_feedforward // 2 and a source length that is a perfect square.
    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=384).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=192, nhead=8)
        >>> src = torch.rand(49, 32, 192)
        >>> srcc = torch.rand(49, 32, 192)
        >>> out = encoder_layer(src, srcc)
    """

    def __init__(self, d_model, nhead, dim_feedforward=384, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn1 = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.self_attn2 = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.self_attn3 = MultiheadAttention(d_model, nhead, dropout=dropout)
        # The TIF block is applied to d_model-channel feature maps in forward(), so this
        # assumes d_model == dim_feedforward // 2.
        channel = dim_feedforward // 2
        self.modulation = TIF(channel)
        self.cross_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm0 = nn.LayerNorm(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, srcc: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # src is (S, B, E); S must be a perfect square for the spatial reshape below.
        b, c, s = src.permute(1, 2, 0).size()
        # Cross-stream attention: queries from srcc, keys/values from src.
        src1 = self.self_attn1(srcc, src, src, attn_mask=src_mask,
                               key_padding_mask=src_key_padding_mask)[0]
        srcs1 = src + self.dropout1(src1)
        srcs1 = self.norm1(srcs1)
        # Self-attention on the fused features.
        src2 = self.self_attn2(srcs1, srcs1, srcs1, attn_mask=src_mask,
                               key_padding_mask=src_key_padding_mask)[0]
        srcs2 = srcs1 + self.dropout2(src2)
        srcs2 = self.norm2(srcs2)
        # Reshape both intermediates to square spatial grids and fuse them with TIF.
        src = self.modulation(srcs2.view(b, c, int(s ** 0.5), int(s ** 0.5)),
                              srcs1.contiguous().view(b, c, int(s ** 0.5), int(s ** 0.5))).view(b, c, -1).permute(2, 0, 1)
        # Final self-attention on the modulated features.
        src2 = self.self_attn3(src, src, src, attn_mask=src_mask,
                               key_padding_mask=src_key_padding_mask)[0]
        srcs1 = src + self.dropout3(src2)
        srcs1 = self.norm3(srcs1)
        return srcs1


class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.
    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=384).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
    Examples::
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> srcc = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, srcc, memory)
    """

    def __init__(self, d_model, nhead, dim_feedforward=384, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn1 = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn2 = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)
        self.dropout4 = Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    def forward(self, tgt: Tensor, srcc: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.
        Args:
            tgt: the sequence to the decoder layer (required).
            srcc: the auxiliary sequence (accepted for interface consistency; unused here).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt12 = self.multihead_attn1(tgt, memory, memory, attn_mask=memory_mask,
                                     key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt12)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm4(tgt)
        return tgt


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
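

# A minimal, self-contained smoke test. The hyperparameters below are illustrative
# assumptions, not values fixed elsewhere in this file: the encoder's TIF reshape
# requires d_model == dim_feedforward // 2 (here 384 // 2 = 192) and a perfect-square
# source length (here 49 = 7 * 7).
if __name__ == "__main__":
    model = Transformertime(d_model=192, nhead=8, num_encoder_layers=2,
                            num_decoder_layers=2, dim_feedforward=384)
    src = torch.rand(49, 2, 192)    # (source length, batch, d_model)
    srcc = torch.rand(49, 2, 192)   # auxiliary stream, same shape as src
    tgt = torch.rand(20, 2, 192)    # (target length, batch, d_model)
    memory, output = model(src, srcc, tgt)
    print(memory.shape, output.shape)  # torch.Size([49, 2, 192]) torch.Size([20, 2, 192])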