TFN-train

本文介绍了使用张量融合网络(Tensor Fusion Network, TFN,基于PyTorch实现)在MOSI数据集上的数据预处理、训练过程以及性能评估。通过设置epochs、batch_size和patience等参数进行训练与早停,并计算平均绝对误差(MAE)、二分类准确率、精确率、召回率、F1分数、七分类准确率以及与人工标注的相关系数等指标。关键步骤包括数据标准化、模型初始化和验证,并展示了如何处理NaN值以及如何根据验证集损失保存最佳模型。
摘要由CSDN通过智能技术生成

train.py

from __future__ import print_function
from model import TFN
from utils import MultimodalDataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os
import argparse
import torch
import random
import torch.nn as nn
import torch.optim as optim
import numpy as np


def preprocess(options):
    """Load the multimodal dataset and prepare its train/valid/test splits.

    Reads ``dataset``, ``model_path`` and ``max_len`` from *options*.
    Visual features are scaled by the per-feature maximum absolute value
    observed on the training split. The audio and visual modalities of every
    split are then averaged over axis 0, and any NaN left in them is
    replaced with 0. The text modality is left untouched.

    Args:
        options: dict of parsed command-line options.

    Returns:
        (train_set, valid_set, test_set, input_dims), where input_dims is the
        (audio_dim, visual_dim, text_dim) tuple taken from the first training
        example before any averaging.
    """
    dataset = options['dataset']
    max_len = options['max_len']

    # file name reported for the saved model checkpoint
    print("Temp location for saving model: {}".format(
        os.path.join(options['model_path'], "tfn.pt")))

    # prepare the datasets
    print("Currently using {} dataset.".format(dataset))
    mosi = MultimodalDataset(dataset, max_len=max_len)
    train_set, valid_set, test_set = mosi.train_set, mosi.valid_set, mosi.test_set
    splits = (train_set, valid_set, test_set)

    audio_dim = train_set[0][0].shape[1]
    print("Audio feature dimension is: {}".format(audio_dim))
    visual_dim = train_set[0][1].shape[1]
    print("Visual feature dimension is: {}".format(visual_dim))
    text_dim = train_set[0][2].shape[1]
    print("Text feature dimension is: {}".format(text_dim))
    input_dims = (audio_dim, visual_dim, text_dim)

    # Per-feature scale = max |value| over the first two axes of the training
    # visual array. nanmax ignores NaN entries, so a single missing value no
    # longer turns the whole feature into NaN (and then into 0 below) for
    # every example in every split.
    visual_max = np.nanmax(np.nanmax(np.abs(train_set.visual), axis=0), axis=0)
    # Leave all-zero or all-NaN features unscaled instead of dividing by 0/NaN.
    visual_max[(visual_max == 0) | np.isnan(visual_max)] = 1

    for split in splits:
        split.visual = split.visual / visual_max

        # For visual and audio we average across axis 0 (time, per the
        # original author: data assumed to be (max_len, num_examples,
        # feature_dim) -> (1, num_examples, feature_dim); TODO confirm
        # against MultimodalDataset).
        split.visual = np.mean(split.visual, axis=0, keepdims=True)
        split.audio = np.mean(split.audio, axis=0, keepdims=True)

        # remove possible NaN values
        split.visual[np.isnan(split.visual)] = 0
        split.audio[np.isnan(split.audio)] = 0

    return train_set, valid_set, test_set, input_dims

def display(test_loss, test_binacc, test_precision, test_recall, test_f1, test_septacc, test_corr):
    """Print every test-set metric on its own line, in a fixed order.

    Each value is rendered with ``str.format``; nothing is returned.
    """
    report = (
        ("MAE on test set is {}", test_loss),
        ("Binary accuracy on test set is {}", test_binacc),
        ("Precision on test set is {}", test_precision),
        ("Recall on test set is {}", test_recall),
        ("F1 score on test set is {}", test_f1),
        ("Seven-class accuracy on test set is {}", test_septacc),
        ("Correlation w.r.t human evaluation on test set is {}", test_corr),
    )
    for template, metric in report:
        print(template.format(metric))

def _unpack_batch(batch, dtype):
    """Split a DataLoader batch into (audio, visual, text, target) tensors of *dtype*.

    The first three entries are the model inputs and the last is the target.
    Audio/visual are presumably [batch_size, 1, feature_dim] after
    preprocess() (TODO confirm against MultimodalDataset). Only dim 1 is
    squeezed so that a final batch of size 1 keeps its batch dimension;
    a bare ``squeeze()`` would drop it. Targets are reshaped to
    [batch_size, 1] to line up with the model output.
    """
    x = batch[:-1]
    x_a = x[0].float().type(dtype).squeeze(1)
    x_v = x[1].float().type(dtype).squeeze(1)
    x_t = x[2].float().type(dtype)
    y = batch[-1].view(-1, 1).float().type(dtype)
    return x_a, x_v, x_t, y


def _evaluate(model, iterator, criterion, dtype):
    """Run *model* over every batch of *iterator* without tracking gradients.

    Returns:
        (total_loss, predictions, targets): the criterion value summed over
        all batches as a Python float, plus flat numpy arrays of the
        predictions and targets in iteration order.
    """
    model.eval()
    total_loss = 0.0
    preds, targets = [], []
    with torch.no_grad():
        for batch in iterator:
            x_a, x_v, x_t, y = _unpack_batch(batch, dtype)
            output = model(x_a, x_v, x_t)
            total_loss += criterion(output, y).item()
            preds.append(output.cpu().numpy().reshape(-1))
            targets.append(y.cpu().numpy().reshape(-1))
    return total_loss, np.concatenate(preds), np.concatenate(targets)


def main(options):
    """Train TFN with early stopping on validation MAE, then report test metrics.

    Reads ``cuda``, ``batch_size``, ``patience``, ``epochs`` and
    ``model_path`` from *options* (plus whatever preprocess() reads). Each
    time the summed validation L1 loss improves, the whole model is saved to
    ``<model_path>/tfn.pt``. Training stops early after ``patience`` epochs
    without improvement, or on a NaN loss. If training did not hit NaN, the
    best weights are evaluated on the test set and the metrics are printed
    via display().
    """
    DTYPE = torch.FloatTensor
    train_set, valid_set, test_set, input_dims = preprocess(options)

    model = TFN(input_dims, (4, 16, 128), 64, (0.3, 0.3, 0.3, 0.3), 32)
    if options['cuda']:
        model = model.cuda()
        DTYPE = torch.cuda.FloatTensor
    print("Model initialized")
    # Summed absolute error; per-example averages are computed when printing.
    criterion = nn.L1Loss(reduction='sum')
    # don't optimize the first 2 params, they should be fixed (output_range and shift)
    optimizer = optim.Adam(list(model.parameters())[2:])

    # setup training
    complete = True
    min_valid_loss = float('Inf')
    batch_sz = options['batch_size']
    patience = options['patience']
    epochs = options['epochs']
    # Save to <model_path>/tfn.pt (the location preprocess() announces)
    # instead of writing the checkpoint over the directory path itself.
    model_dir = options['model_path']
    if model_dir and not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    model_path = os.path.join(model_dir, "tfn.pt")

    train_iterator = DataLoader(train_set, batch_size=batch_sz, num_workers=4, shuffle=True)
    # Evaluation order does not affect the metrics, so no shuffling needed.
    valid_iterator = DataLoader(valid_set, batch_size=len(valid_set), num_workers=4, shuffle=False)
    test_iterator = DataLoader(test_set, batch_size=len(test_set), num_workers=4, shuffle=False)

    best_state = None
    curr_patience = patience
    for e in range(epochs):
        model.train()
        train_loss = 0.0
        for batch in train_iterator:
            model.zero_grad()
            x_a, x_v, x_t, y = _unpack_batch(batch, DTYPE)
            output = model(x_a, x_v, x_t)
            loss = criterion(output, y)
            loss.backward()
            # .item(): indexing a 0-dim loss with [0] fails on PyTorch >= 0.4
            train_loss += loss.item() / len(train_set)
            optimizer.step()

        print("Epoch {} complete! Average Training loss: {}".format(e, train_loss))

        # Terminate the training process if run into NaN
        if np.isnan(train_loss):
            print("Training got into NaN values...\n\n")
            complete = False
            break

        # On validation set we don't have to compute metrics other than MAE and accuracy
        valid_loss, output_valid, y_valid = _evaluate(model, valid_iterator, criterion, DTYPE)

        if np.isnan(valid_loss):
            print("Training got into NaN values...\n\n")
            complete = False
            break

        valid_binacc = accuracy_score(y_valid >= 0, output_valid >= 0)

        print("Validation loss is: {}".format(valid_loss / len(valid_set)))
        print("Validation binary accuracy is: {}".format(valid_binacc))

        if valid_loss < min_valid_loss:
            curr_patience = patience
            min_valid_loss = valid_loss
            torch.save(model, model_path)
            # Also keep the best weights in memory so the final evaluation
            # does not depend on unpickling the saved module.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print("Found new best model, saving to disk...")
        else:
            curr_patience -= 1

        if curr_patience <= 0:
            break
        print("\n\n")

    if complete and best_state is not None:
        model.load_state_dict(best_state)
        test_loss, output_test, y = _evaluate(model, test_iterator, criterion, DTYPE)

        test_binacc = accuracy_score(y >= 0, output_test >= 0)
        test_precision, test_recall, test_f1, _ = precision_recall_fscore_support(
            y >= 0, output_test >= 0, average='binary')
        test_septacc = (output_test.round() == y.round()).mean()

        # compute the correlation between true and predicted scores
        test_corr = np.corrcoef([output_test, y])[0][1]  # corrcoef returns a matrix
        test_loss = test_loss / len(test_set)

        display(test_loss, test_binacc, test_precision, test_recall, test_f1, test_septacc, test_corr)
    return

def str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` treats every non-empty string (including
    "False") as True, which made ``--cuda False`` enable CUDA.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == "__main__":
    OPTIONS = argparse.ArgumentParser()
    OPTIONS.add_argument('--dataset', dest='dataset',
                         type=str, default='MOSI')
    OPTIONS.add_argument('--epochs', dest='epochs', type=int, default=50)
    OPTIONS.add_argument('--batch_size', dest='batch_size', type=int, default=32)
    OPTIONS.add_argument('--patience', dest='patience', type=int, default=20)
    # Accepts "--cuda True/False" as before, and a bare "--cuda" means True.
    OPTIONS.add_argument('--cuda', dest='cuda', type=str2bool, nargs='?',
                         const=True, default=False)
    OPTIONS.add_argument('--model_path', dest='model_path',
                         type=str, default='models')
    OPTIONS.add_argument('--max_len', dest='max_len', type=int, default=20)
    PARAMS = vars(OPTIONS.parse_args())
    main(PARAMS)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值