from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_uniform, xavier_normal, orthogonal
class SubNet(nn.Module):
    '''
    The subnetwork that is used in TFN for video and audio in the pre-fusion stage
    '''

    def __init__(self, in_size, hidden_size, dropout):
        '''
        Args:
            in_size: input dimension
            hidden_size: hidden layer dimension
            dropout: dropout probability
        Output:
            (return value in forward) a tensor of shape (batch_size, hidden_size)
        '''
        super(SubNet, self).__init__()
        self.norm = nn.BatchNorm1d(in_size)
        self.drop = nn.Dropout(p=dropout)
        self.linear_1 = nn.Linear(in_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        '''
        Args:
            x: tensor of shape (batch_size, in_size)
        '''
        normed = self.norm(x)
        dropped = self.drop(normed)
        y_1 = F.relu(self.linear_1(dropped))
        y_2 = F.relu(self.linear_2(y_1))
        y_3 = F.relu(self.linear_3(y_2))
        return y_3
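
# a minimal shape-check sketch for SubNet (the sizes here are made up for
# illustration, not taken from the paper or the original repo):
#
#     subnet = SubNet(in_size=5, hidden_size=32, dropout=0.15)
#     out = subnet(torch.randn(16, 5))   # -> torch.Size([16, 32])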
class TextSubNet(nn.Module):
    '''
    The LSTM-based subnetwork that is used in TFN for text
    '''

    def __init__(self, in_size, hidden_size, out_size, num_layers=1, dropout=0.2, bidirectional=False):
        '''
        Args:
            in_size: input dimension
            hidden_size: hidden layer dimension
            out_size: output dimension of the final linear projection
            num_layers: specify the number of layers of LSTMs.
            dropout: dropout probability
            bidirectional: specify usage of bidirectional LSTM
        Output:
            (return value in forward) a tensor of shape (batch_size, out_size)
        '''
        super(TextSubNet, self).__init__()
        self.rnn = nn.LSTM(in_size, hidden_size, num_layers=num_layers, dropout=dropout, bidirectional=bidirectional, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.linear_1 = nn.Linear(hidden_size, out_size)

    def forward(self, x):
        '''
        Args:
            x: tensor of shape (batch_size, sequence_len, in_size)
        '''
        # we only use the final hidden state of the LSTM as the utterance representation
        _, final_states = self.rnn(x)
        # final_states[0] has shape (num_layers * num_directions, batch_size, hidden_size);
        # squeeze only dim 0 so that a batch of size 1 keeps its batch dimension
        h = self.dropout(final_states[0].squeeze(0))
        y_1 = self.linear_1(h)
        return y_1
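
# a minimal shape-check sketch for TextSubNet (again, the sizes are made up
# for illustration only):
#
#     text_net = TextSubNet(in_size=300, hidden_size=64, out_size=32)
#     out = text_net(torch.randn(8, 20, 300))   # -> torch.Size([8, 32])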
class TFN(nn.Module):
    '''
    Implements the Tensor Fusion Networks for multimodal sentiment analysis as described in:
    Zadeh, Amir, et al. "Tensor fusion network for multimodal sentiment analysis." EMNLP 2017 Oral.
    '''

    def __init__(self, input_dims, hidden_dims, text_out, dropouts, post_fusion_dim):
        '''
        Args:
            input_dims - a length-3 tuple, contains (audio_dim, video_dim, text_dim)
            hidden_dims - another length-3 tuple, similar to input_dims
            text_out - int, specifying the resulting dimension of the text subnetwork
            dropouts - a length-4 tuple, contains (audio_dropout, video_dropout, text_dropout, post_fusion_dropout)
            post_fusion_dim - int, specifying the hidden size of the post-fusion layers
        Output:
            (return value in forward) a scalar value between -3 and 3
        '''
        super(TFN, self).__init__()

        # dimensions are specified in the order of audio, video and text
        self.audio_in = input_dims[0]
        self.video_in = input_dims[1]
        self.text_in = input_dims[2]

        self.audio_hidden = hidden_dims[0]
        self.video_hidden = hidden_dims[1]
        self.text_hidden = hidden_dims[2]
        self.text_out = text_out
        self.post_fusion_dim = post_fusion_dim

        self.audio_prob = dropouts[0]
        self.video_prob = dropouts[1]
        self.text_prob = dropouts[2]
        self.post_fusion_prob = dropouts[3]

        # define the pre-fusion subnetworks
        self.audio_subnet = SubNet(self.audio_in, self.audio_hidden, self.audio_prob)
        self.video_subnet = SubNet(self.video_in, self.video_hidden, self.video_prob)
        self.text_subnet = TextSubNet(self.text_in, self.text_hidden, self.text_out, dropout=self.text_prob)

        # define the post-fusion layers
        self.post_fusion_dropout = nn.Dropout(p=self.post_fusion_prob)
        self.post_fusion_layer_1 = nn.Linear((self.text_out + 1) * (self.video_hidden + 1) * (self.audio_hidden + 1), self.post_fusion_dim)
        self.post_fusion_layer_2 = nn.Linear(self.post_fusion_dim, self.post_fusion_dim)
        self.post_fusion_layer_3 = nn.Linear(self.post_fusion_dim, 1)

        # in TFN we are doing a regression with a constrained output range: (-3, 3), hence we'll apply sigmoid to the output
        # to shrink it to (0, 1), and then scale/shift it back to the range (-3, 3)
        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, audio_x, video_x, text_x):
        '''
        Args:
            audio_x: tensor of shape (batch_size, audio_in)
            video_x: tensor of shape (batch_size, video_in)
            text_x: tensor of shape (batch_size, sequence_len, text_in)
        '''
        audio_h = self.audio_subnet(audio_x)
        video_h = self.video_subnet(video_x)
        text_h = self.text_subnet(text_x)
        batch_size = audio_h.data.shape[0]

        # next we perform "tensor fusion", which is essentially appending 1s to the tensors and taking the Kronecker product
        if audio_h.is_cuda:
            DTYPE = torch.cuda.FloatTensor
        else:
            DTYPE = torch.FloatTensor
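
        # note: Variable and .type(DTYPE) below are PyTorch 0.3-era idioms kept
        # from the original code; on PyTorch >= 0.4 a plain
        # torch.ones(batch_size, 1, device=audio_h.device) would do the same job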
        _audio_h = torch.cat((Variable(torch.ones(batch_size, 1).type(DTYPE), requires_grad=False), audio_h), dim=1)
        _video_h = torch.cat((Variable(torch.ones(batch_size, 1).type(DTYPE), requires_grad=False), video_h), dim=1)
        _text_h = torch.cat((Variable(torch.ones(batch_size, 1).type(DTYPE), requires_grad=False), text_h), dim=1)

        # _audio_h has shape (batch_size, audio_hidden + 1), _video_h has shape (batch_size, video_hidden + 1)
        # we want to perform an outer product between the two batches, hence we unsqueeze them to get
        # (batch_size, audio_hidden + 1, 1) X (batch_size, 1, video_hidden + 1)
        # fusion_tensor will have shape (batch_size, audio_hidden + 1, video_hidden + 1)
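        # a tiny worked example with hypothetical sizes (not the paper's settings):
        # with audio_hidden = 4, video_hidden = 8, text_out = 16, _audio_h is (B, 5),
        # _video_h is (B, 9) and _text_h is (B, 17); the fusion below yields (B, 5, 9),
        # reshaped to (B, 45, 1), then (B, 45, 17), and finally a flat (B, 765) vector,
        # which matches the in_features of post_fusion_layer_1: 17 * 9 * 5 = 765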
        fusion_tensor = torch.bmm(_audio_h.unsqueeze(2), _video_h.unsqueeze(1))

        # next we do a Kronecker product between fusion_tensor and _text_h. This is even trickier:
        # we have to reshape the fusion tensor during the computation;
        # in the end we don't keep the 3-D tensor, instead we flatten it
        fusion_tensor = fusion_tensor.view(-1, (self.audio_hidden + 1) * (self.video_hidden + 1), 1)
        fusion_tensor = torch.bmm(fusion_tensor, _text_h.unsqueeze(1)).view(batch_size, -1)
        post_fusion_dropped = self.post_fusion_dropout(fusion_tensor)
        post_fusion_y_1 = F.relu(self.post_fusion_layer_1(post_fusion_dropped))
        post_fusion_y_2 = F.relu(self.post_fusion_layer_2(post_fusion_y_1))
        # torch.sigmoid (F.sigmoid is deprecated) squashes the output to (0, 1);
        # output_range and output_shift then map it to (-3, 3)
        post_fusion_y_3 = torch.sigmoid(self.post_fusion_layer_3(post_fusion_y_2))
        output = post_fusion_y_3 * self.output_range + self.output_shift
        return output
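
# a minimal smoke test, runnable as a script; the dimensions below are made up
# for illustration (they are NOT the settings from the TFN paper or the
# original repo), they only need to be internally consistent
if __name__ == '__main__':
    model = TFN(input_dims=(5, 20, 300), hidden_dims=(4, 16, 64), text_out=32,
                dropouts=(0.2, 0.2, 0.2, 0.2), post_fusion_dim=32)
    audio_x = torch.randn(8, 5)
    video_x = torch.randn(8, 20)
    text_x = torch.randn(8, 10, 300)  # (batch_size, sequence_len, text_in)
    output = model(audio_x, video_x, text_x)
    print(output.shape)  # expected: torch.Size([8, 1])
    print(output.min().item(), output.max().item())  # both should lie in (-3, 3)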