Reading the Social LSTM Source Code

The main training function

Main steps

For each sequence in the batch (20 frames, a number of pedestrians per frame, and per-pedestrian data, e.g. 20x3x3):
Grab x_seq, _, d_seq, numPedsList_seq, PedsList_seq.
Use dataloader.convert_proper_array to turn x_seq into a form the network can consume; concretely it just maps each PedId to an array index 0, 1, 2, ...
Convert the absolute coordinates in x_seq to relative coordinates: for every PedId, subtract the position where that pedestrian first appears from all of its coordinates (a sketch of this step follows the excerpt below).

            for sequence in range(dataloader.batch_size):
                # Get the data corresponding to the current sequence
                x_seq ,_ , d_seq, numPedsList_seq, PedsList_seq = x[sequence], y[sequence], d[sequence], numPedsList[sequence], PedsList[sequence]
                # dense vector creation: returns a (20, maxPed, 2) tensor for the 20 frames and a lookup dict pedId -> array index in range(maxPed)
                x_seq, lookup_seq = dataloader.convert_proper_array(x_seq, numPedsList_seq, PedsList_seq)
                target_id_values = x_seq[0][lookup_seq[target_id], 0:2]
                #grid mask calculation and storage depending on grid parameter
                if(args.grid):
                    # TODO: changed here
                    if(epoch == 0):
                        grid_seq = getSequenceGridMask(x_seq, dataset_data, PedsList_seq,args.neighborhood_size, args.grid_size, args.use_cuda)
                        grids[dataloader.dataset_pointer].append(grid_seq)
                    else:
                        grid_seq = grids[dataloader.dataset_pointer][(num_batch*dataloader.batch_size)+sequence]
                else:
                    grid_seq = getSequenceGridMask(x_seq, dataset_data, PedsList_seq,args.neighborhood_size, args.grid_size, args.use_cuda)
                # vectorize trajectories in sequence: for each pedestrian, subtract its first-frame position from every position to get relative coordinates
                x_seq, _ = vectorize_seq(x_seq, PedsList_seq, lookup_seq)        
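To make the vectorization step concrete, here is a minimal, self-contained sketch of what `vectorize_seq` does (the helper name, shapes and return form are illustrative, not the repo's exact code): each pedestrian's trajectory is shifted so that its first observed position becomes the origin, and those first positions are returned so the trajectory can later be shifted back with something like `revert_seq`.

```python
import torch

def vectorize_seq_sketch(x_seq, peds_list_seq, lookup):
    """Shift every pedestrian's trajectory so its first observed position is (0, 0).

    x_seq         : (seq_length, numNodes, 2) tensor of absolute coordinates
    peds_list_seq : list (length seq_length) of ped ids present in each frame
    lookup        : dict ped id -> row index in x_seq
    Returns the shifted tensor plus a dict ped id -> first absolute position,
    which revert_seq-style code can later use to undo the shift.
    """
    vectorized = x_seq.clone()
    first_values = {}
    for frame, peds in zip(vectorized, peds_list_seq):
        for ped in peds:
            idx = lookup[ped]
            if ped not in first_values:          # remember the first time this ped appears
                first_values[ped] = frame[idx].clone()
            frame[idx] = frame[idx] - first_values[ped]
    return vectorized, first_values

# tiny usage example: one pedestrian (id 7) observed over three frames
x = torch.tensor([[[2.0, 3.0]], [[2.5, 3.5]], [[3.0, 4.0]]])
rel, first = vectorize_seq_sketch(x, [[7], [7], [7]], {7: 0})
print(rel[:, 0])   # rows (0,0), (0.5,0.5), (1.0,1.0)
print(first[7])    # tensor([2., 3.])
```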

The code below is still inside the `for sequence in range(dataloader.batch_size):` loop.
numNodes: the number of distinct pedestrians appearing in this sequence (len(lookup_seq)).
hidden_states and cell_states: zero tensors of shape (numNodes, args.rnn_size).
x_seq and grid_seq are then fed into the Social LSTM model; how the model processes them is explained in detail further down, and how grid_seq itself is built is sketched right below.
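Here is a rough, self-contained sketch of what the grid mask behind `getSequenceGridMask` looks like for a single frame. This only illustrates the idea; the repo's `grid.py` additionally uses the dataset dimensions for normalisation, indexes by the sequence lookup and supports CUDA tensors. Each pedestrian gets a `neighborhood_size x neighborhood_size` window split into `grid_size x grid_size` cells, and the mask records which neighbours fall into which cell.

```python
import torch

def grid_mask_frame_sketch(frame, peds, lookup, neighborhood_size, grid_size):
    """Occupancy mask for one frame, shape (numPeds, numPeds, grid_size*grid_size).

    mask[i, j, c] == 1 when pedestrian j stands in cell c of the
    neighborhood_size x neighborhood_size window centred on pedestrian i
    (rows i, j follow the order of `peds` in this sketch).
    frame  : (numNodes, 2) tensor of positions for this frame
    peds   : list of ped ids present in this frame
    lookup : dict ped id -> row index in frame
    """
    num = len(peds)
    mask = torch.zeros(num, num, grid_size * grid_size)
    cell_size = neighborhood_size / grid_size
    for i, ped_i in enumerate(peds):
        cx, cy = frame[lookup[ped_i]]
        lo_x, lo_y = cx - neighborhood_size / 2, cy - neighborhood_size / 2
        for j, ped_j in enumerate(peds):
            if ped_i == ped_j:
                continue                              # a pedestrian is not its own neighbour
            ox, oy = frame[lookup[ped_j]]
            gx = int((ox - lo_x) / cell_size)         # column inside the window
            gy = int((oy - lo_y) / cell_size)         # row inside the window
            if 0 <= gx < grid_size and 0 <= gy < grid_size:
                mask[i, j, gy * grid_size + gx] = 1
    return mask

# 3 pedestrians in one frame, neighbourhood 32 split into 4x4 cells
m = grid_mask_frame_sketch(torch.tensor([[0.0, 0.0], [5.0, 5.0], [100.0, 100.0]]),
                           [1, 2, 3], {1: 0, 2: 1, 3: 2}, 32, 4)
print(m.shape, int(m.sum()))   # torch.Size([3, 3, 16]); only peds 1 and 2 see each other -> 2
```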

if args.use_cuda:                    
    x_seq = x_seq.cuda()
#number of peds in this sequence per frame
numNodes = len(lookup_seq)


hidden_states = Variable(torch.zeros(numNodes, args.rnn_size))
if args.use_cuda:                    
    hidden_states = hidden_states.cuda()

cell_states = Variable(torch.zeros(numNodes, args.rnn_size))
if args.use_cuda:                    
    cell_states = cell_states.cuda()

# Zero out gradients
net.zero_grad()
optimizer.zero_grad()


# Forward prop (explained in detail below)
outputs, _, _ = net(x_seq, grid_seq, hidden_states, cell_states, PedsList_seq,numPedsList_seq ,dataloader, lookup_seq)
# Compute loss
loss = Gaussian2DLikelihood(outputs, x_seq, PedsList_seq, lookup_seq)
loss_batch += loss.item()

# Compute gradients
loss.backward()


# Clip gradients
torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

# Update parameters
optimizer.step()
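The five values per pedestrian that the network outputs are the parameters of a bivariate Gaussian over the next (relative) position, and `Gaussian2DLikelihood` is its negative log-likelihood. Below is a compact sketch of that computation under the usual parameterisation (exponential for the standard deviations, tanh for the correlation, as `getCoef` does); the repo's version additionally masks out pedestrians that are absent from a frame, which is omitted here.

```python
import math
import torch

def gaussian_2d_nll_sketch(outputs, targets, eps=1e-20):
    """Negative log-likelihood of targets under the predicted bivariate Gaussians.

    outputs : (..., 5) raw network outputs -> (mux, muy, sx_raw, sy_raw, corr_raw)
    targets : (..., 2) ground-truth (x, y) positions
    """
    mux, muy, sx_raw, sy_raw, corr_raw = outputs.unbind(-1)
    sx, sy = torch.exp(sx_raw), torch.exp(sy_raw)   # standard deviations must be positive
    corr = torch.tanh(corr_raw)                     # correlation lies in (-1, 1)

    nx = (targets[..., 0] - mux) / sx
    ny = (targets[..., 1] - muy) / sy
    one_minus_rho2 = 1 - corr ** 2

    z = nx ** 2 + ny ** 2 - 2 * corr * nx * ny
    density = torch.exp(-z / (2 * one_minus_rho2)) / (
        2 * math.pi * sx * sy * torch.sqrt(one_minus_rho2))
    # clamp to avoid log(0), then average the negative log-likelihood
    return -torch.log(torch.clamp(density, min=eps)).mean()

# the model's outputs for one sequence are (seq_length, numNodes, 5), e.g.
loss = gaussian_2d_nll_sketch(torch.randn(20, 3, 5), torch.randn(20, 3, 2))
```

During validation the same five parameters are passed to `sample_gaussian_2d`, which draws the next position from this distribution instead of scoring the ground truth.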

Now let's look closely at the SocialModel itself (the code below is simplified).

        for framenum, frame in enumerate(input_data):  # enumerate() yields (index, data) pairs

            # compute the social tensor (3x2048); see the figure and sketch below for details
            social_tensor = self.getSocialTensor(grid_current, hidden_states_current)
            # Embed inputs: (3x64) = (3x2) x (2x64); input_embedding_layer(in_features=2, out_features=64), embedding_size = 64
            input_embedded = self.dropout(self.relu(self.input_embedding_layer(nodes_current)))
            # Embed the social tensor: (3x64) = (3x2048) x (2048x64)
            tensor_embedded = self.dropout(self.relu(self.tensor_embedding_layer(social_tensor)))
            # Concatenate the two embeddings (3x128)
            concat_embedded = torch.cat((input_embedded, tensor_embedded), 1)

            # One step of the LSTM (or GRU) gives the new hidden and cell states
            if not self.gru:
                # One-step of the LSTM
                h_nodes, c_nodes = self.cell(concat_embedded, (hidden_states_current, cell_states_current))
            else:
                h_nodes = self.cell(concat_embedded, (hidden_states_current))

            # Pass (3x128) through the output layer to get (3x5) = (3x128) x (128 x output_size)
            # and store every frame's result in outputs; after one sequence outputs is (60x5), i.e. 20 blocks of (3x5)
            outputs[framenum*numNodes + corr_index.data] = self.output_layer(h_nodes)

            # Update hidden and cell states; these are fed to the next frame
            hidden_states[corr_index.data] = h_nodes
            if not self.gru:
                cell_states[corr_index.data] = c_nodes

        # Reshape outputs
        outputs_return = Variable(torch.zeros(self.seq_length, numNodes, self.output_size))
        if self.use_cuda:
            outputs_return = outputs_return.cuda()
        for framenum in range(self.seq_length):
            for node in range(numNodes):
                outputs_return[framenum, node, :] = outputs[framenum*numNodes + node, :]
        # final outputs: outputs (20x3x5), hidden_states (3x128), cell_states (3x128)
        return outputs_return, hidden_states, cell_states

(Figure: how the social tensor is built from the grid mask and the neighbours' hidden states; image omitted.)
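The (3x2048) social tensor referenced above is the social-pooling step: for each pedestrian, every cell of its 4x4 grid accumulates the hidden states (rnn_size = 128) of the neighbours standing in that cell, giving 4*4*128 = 2048 values per pedestrian. Here is a minimal sketch under that reading, written as one batched matmul; the model may compute it pedestrian by pedestrian, but the resulting sums are the same.

```python
import torch

def get_social_tensor_sketch(grid, hidden_states, grid_size=4, rnn_size=128):
    """Social pooling sketch.

    grid          : (numNodes, numNodes, grid_size**2) occupancy mask for one frame
    hidden_states : (numNodes, rnn_size) hidden states of all pedestrians
    returns       : (numNodes, grid_size**2 * rnn_size) social tensor, where for
                    pedestrian i and cell c the slice holds the sum of the hidden
                    states of every pedestrian j located in cell c of i's grid.
    """
    num_nodes = hidden_states.size(0)
    # (numNodes, grid**2, numNodes) @ (numNodes, rnn_size) -> (numNodes, grid**2, rnn_size)
    pooled = torch.matmul(grid.transpose(1, 2), hidden_states)
    # flatten the grid cells into one long vector per pedestrian
    return pooled.reshape(num_nodes, grid_size * grid_size * rnn_size)

# 3 pedestrians, 4x4 grid, 128-dim hidden state -> torch.Size([3, 2048])
print(get_social_tensor_sketch(torch.zeros(3, 3, 16), torch.randn(3, 128)).shape)
```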

import torch
import numpy as np
from torch.autograd import Variable
import argparse
import os
import time
import pickle
import subprocess
from model import SocialModel
from utils import DataLoader
from grid import getSequenceGridMask
from helper import *


def main():
    parser = argparse.ArgumentParser()
    # RNN size parameter (dimension of the output/hidden state)
    parser.add_argument('--input_size', type=int, default=2)
    parser.add_argument('--output_size', type=int, default=5)
    # RNN size parameter (dimension of the output/hidden state)
    parser.add_argument('--rnn_size', type=int, default=128,
                        help='size of RNN hidden state')
    # Size of each batch parameter
    parser.add_argument('--batch_size', type=int, default=5,
                        help='minibatch size')
    # Length of sequence to be considered parameter
    parser.add_argument('--seq_length', type=int, default=20,
                        help='RNN sequence length')
    parser.add_argument('--pred_length', type=int, default=12,
                        help='prediction length')
    # Number of epochs parameter
    parser.add_argument('--num_epochs', type=int, default=30,
                        help='number of epochs')
    # Frequency at which the model should be saved parameter
    parser.add_argument('--save_every', type=int, default=400,
                        help='save frequency')
    # TODO: (resolve) Clipping gradients for now. No idea whether we should
    # Gradient value at which it should be clipped (note: not sure what this parameter is actually for)
    parser.add_argument('--grad_clip', type=float, default=10.,
                        help='clip gradients at this value')
    # Learning rate parameter
    parser.add_argument('--learning_rate', type=float, default=0.003,
                        help='learning rate')
    # Decay rate for the learning rate parameter
    parser.add_argument('--decay_rate', type=float, default=0.95,
                        help='decay rate for rmsprop')
    # Dropout not implemented.
    # Dropout probability parameter
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='dropout probability')
    # Dimension of the embeddings parameter
    parser.add_argument('--embedding_size', type=int, default=64,
                        help='Embedding dimension for the spatial coordinates')
    # Size of neighborhood to be considered parameter
    parser.add_argument('--neighborhood_size', type=int, default=32,
                        help='Neighborhood size to be considered for social grid')
    # Size of the social grid parameter
    parser.add_argument('--grid_size', type=int, default=4,
                        help='Grid size of the social grid')
    # Maximum number of pedestrians to be considered
    parser.add_argument('--maxNumPeds', type=int, default=27,
                        help='Maximum Number of Pedestrians')

    # Lambda regularization parameter (L2)
    parser.add_argument('--lambda_param', type=float, default=0.0005,
                        help='L2 regularization parameter')
    # Cuda parameter
    parser.add_argument('--use_cuda', action="store_true", default=False,
                        help='Use GPU or not')
    # GRU parameter
    parser.add_argument('--gru', action="store_true", default=False,
                        help='True : GRU cell, False: LSTM cell')
    # drive option
    parser.add_argument('--drive', action="store_true", default=False,
                        help='Use Google drive or not')
    # number of validation will be used
    parser.add_argument('--num_validation', type=int, default=2,
                        help='Total number of validation dataset for validate accuracy')
    # frequency of validation
    parser.add_argument('--freq_validation', type=int, default=1,
                        help='Frequency number(epoch) of validation using validation data')
    # frequency of optimazer learning decay
    parser.add_argument('--freq_optimizer', type=int, default=8,
                        help='Frequency number(epoch) of learning decay for optimizer')
    # store grids in epoch 0 and use further.2 times faster -> Intensive memory use around 12 GB
    parser.add_argument('--grid', action="store_true", default=True,
                        help='Whether store grids and use further epoch')
    
    args = parser.parse_args()
    
    train(args)


def train(args):
    origin = (0,0)
    reference_point = (0,1)
    validation_dataset_executed = False
  
    prefix = ''
    f_prefix = '.'
    if args.drive is True:
      prefix='drive/semester_project/social_lstm_final/'
      f_prefix = 'drive/semester_project/social_lstm_final'

    
    if not os.path.isdir("log/"):
      print("Directory creation script is running...")
      subprocess.call([f_prefix+'/make_directories.bat'])
    #
    args.freq_validation = np.clip(args.freq_validation, 0, args.num_epochs)
    validation_epoch_list = list(range(args.freq_validation, args.num_epochs+1, args.freq_validation))
    validation_epoch_list[-1]-=1



    # Create the data loader object. This object would preprocess the data in terms of
    # batches each of size args.batch_size, of length args.seq_length
    dataloader = DataLoader(f_prefix, args.batch_size, args.seq_length, args.num_validation, forcePreProcess=True)

    model_name = "LSTM"
    method_name = "SOCIALLSTM"
    save_tar_name = method_name+"_lstm_model_"
    if args.gru:
        model_name = "GRU"
        save_tar_name = method_name+"_gru_model_"


    # Log directory
    log_directory = os.path.join(prefix, 'log/')  # directory where logs are stored
    plot_directory = os.path.join(prefix, 'plot/', method_name, model_name)  # directory for plot files
    plot_train_file_directory = 'validation'  # subdirectory of the plot directory used to save training results

    # Logging files: record the training curves, per-epoch loss values, etc.
    log_file_curve = open(os.path.join(log_directory, method_name, model_name, 'log_curve.txt'), 'w+')
    log_file = open(os.path.join(log_directory, method_name, model_name, 'val.txt'), 'w+')  # validation log, e.g. best validation loss

    # model directory
    save_directory = os.path.join(prefix, 'model/')
    
    # Save the arguments int the config file
    with open(os.path.join(save_directory, method_name, model_name,'config.pkl'), 'wb') as f:
        pickle.dump(args, f)  # serialize the arguments and write them to f in binary form

    # Path to store the checkpoint file
    def checkpoint_path(x):
        return os.path.join(save_directory, method_name, model_name, save_tar_name+str(x)+'.tar')

    # model creation
    net = SocialModel(args)
    if args.use_cuda:
        net = net.cuda()
    # Adagrad is currently used as the optimizer
    #optimizer = torch.optim.RMSprop(net.parameters(), lr=args.learning_rate)
    optimizer = torch.optim.Adagrad(net.parameters(), weight_decay=args.lambda_param)
    #optimizer = torch.optim.Adam(net.parameters(), weight_decay=args.lambda_param)

    learning_rate = args.learning_rate

    best_val_loss = 100
    best_val_data_loss = 100

    smallest_err_val = 100000
    smallest_err_val_data = 100000


    best_epoch_val = 0
    best_epoch_val_data = 0

    best_err_epoch_val = 0
    best_err_epoch_val_data = 0

    all_epoch_results = []
    grids = []
    num_batch = 0
    dataset_pointer_ins_grid = -1

    [grids.append([]) for dataset in range(dataloader.get_len_of_dataset())]

    # Training
    for epoch in range(args.num_epochs):
        print('****************Training epoch beginning******************')
        if dataloader.additional_validation and (epoch-1) in validation_epoch_list:
            dataloader.switch_to_dataset_type(True)
        dataloader.reset_batch_pointer(valid=False)
        loss_epoch = 0

        # For each batch
        for batch in range(dataloader.num_batches):
            start = time.time()

            # Get batch data; shape is (batch_size=5, seq_length=20, per-frame data)
            x, y, d , numPedsList, PedsList ,target_ids= dataloader.next_batch()
            loss_batch = 0
            
            #if we are in a new dataset, zero the counter of batch
            # TODO: changed here
            if dataset_pointer_ins_grid != dataloader.dataset_pointer and epoch != 0:
                num_batch = 0
                dataset_pointer_ins_grid = dataloader.dataset_pointer


            # For each sequence
            for sequence in range(dataloader.batch_size):
                # Get the data corresponding to the current sequence
                x_seq ,_ , d_seq, numPedsList_seq, PedsList_seq = x[sequence], y[sequence], d[sequence], numPedsList[sequence], PedsList[sequence]
                target_id = target_ids[sequence]

                #get processing file name and then get dimensions of file
                folder_name = dataloader.get_directory_name_with_pointer(d_seq)
                dataset_data = dataloader.get_dataset_dimension(folder_name)

                # dense vector creation: returns a (20, maxPed, 2) tensor for the 20 frames and a lookup dict pedId -> array index in range(maxPed)
                x_seq, lookup_seq = dataloader.convert_proper_array(x_seq, numPedsList_seq, PedsList_seq)
                target_id_values = x_seq[0][lookup_seq[target_id], 0:2]

                #grid mask calculation and storage depending on grid parameter
                if(args.grid):
                    # TODO: changed here
                    if(epoch == 0):  
                        grid_seq = getSequenceGridMask(x_seq, dataset_data, PedsList_seq,args.neighborhood_size, args.grid_size, args.use_cuda)
                        grids[dataloader.dataset_pointer].append(grid_seq)
                    else:
                        grid_seq = grids[dataloader.dataset_pointer][(num_batch*dataloader.batch_size)+sequence]
                else:
                    grid_seq = getSequenceGridMask(x_seq, dataset_data, PedsList_seq,args.neighborhood_size, args.grid_size, args.use_cuda)

                # vectorize trajectories in sequence: for each pedestrian, subtract its first-frame position from every position to get relative coordinates
                x_seq, _ = vectorize_seq(x_seq, PedsList_seq, lookup_seq)

                if args.use_cuda:                    
                    x_seq = x_seq.cuda()


                #number of peds in this sequence per frame
                numNodes = len(lookup_seq)
                hidden_states = Variable(torch.zeros(numNodes, args.rnn_size))
                if args.use_cuda:                    
                    hidden_states = hidden_states.cuda()
                cell_states = Variable(torch.zeros(numNodes, args.rnn_size))
                if args.use_cuda:                    
                    cell_states = cell_states.cuda()
                # Zero out gradients
                net.zero_grad()
                optimizer.zero_grad()             
                # Forward prop
                outputs, _, _ = net(x_seq, grid_seq, hidden_states, cell_states, PedsList_seq,numPedsList_seq ,dataloader, lookup_seq)                
                # Compute loss
                loss = Gaussian2DLikelihood(outputs, x_seq, PedsList_seq, lookup_seq)
                loss_batch += loss.item()

                # Compute gradients
                loss.backward()
                

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)

                # Update parameters
                optimizer.step()

            end = time.time()
            loss_batch = loss_batch / dataloader.batch_size
            loss_epoch += loss_batch
            num_batch+=1

            print('{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(epoch * dataloader.num_batches + batch,
                                                                                    args.num_epochs * dataloader.num_batches,
                                                                                    epoch,
                                                                                    loss_batch, end - start))

        loss_epoch /= dataloader.num_batches
        # Log loss values
        log_file_curve.write("Training epoch: "+str(epoch)+" loss: "+str(loss_epoch)+'\n')

        if dataloader.valid_num_batches > 0:
            print('****************Validation epoch beginning******************')

            # Validation
            dataloader.reset_batch_pointer(valid=True)
            loss_epoch = 0
            err_epoch = 0

            # For each batch
            for batch in range(dataloader.valid_num_batches):
                # Get batch data
                x, y, d , numPedsList, PedsList ,target_ids= dataloader.next_valid_batch()

                # Loss for this batch
                loss_batch = 0
                err_batch = 0


                # For each sequence
                for sequence in range(dataloader.batch_size):
                    # Get data corresponding to the current sequence
                    x_seq ,_ , d_seq, numPedsList_seq, PedsList_seq = x[sequence], y[sequence], d[sequence], numPedsList[sequence], PedsList[sequence]
                    target_id = target_ids[sequence]

                    #get processing file name and then get dimensions of file
                    folder_name = dataloader.get_directory_name_with_pointer(d_seq)
                    dataset_data = dataloader.get_dataset_dimension(folder_name)
                    
                    #dense vector creation
                    x_seq, lookup_seq = dataloader.convert_proper_array(x_seq, numPedsList_seq, PedsList_seq)

                    target_id_values = x_seq[0][lookup_seq[target_id], 0:2]
                    
                    #get grid mask
                    grid_seq = getSequenceGridMask(x_seq, dataset_data, PedsList_seq, args.neighborhood_size, args.grid_size, args.use_cuda)

                    x_seq, first_values_dict = vectorize_seq(x_seq, PedsList_seq, lookup_seq)


                    # <---------------------- Experimental block ----------------------->
                    # x_seq = translate(x_seq, PedsList_seq, lookup_seq ,target_id_values)
                    # angle = angle_between(reference_point, (x_seq[1][lookup_seq[target_id], 0].data.numpy(), x_seq[1][lookup_seq[target_id], 1].data.numpy()))
                    # x_seq = rotate_traj_with_target_ped(x_seq, angle, PedsList_seq, lookup_seq)
                    # grid_seq = getSequenceGridMask(x_seq, dataset_data, PedsList_seq, args.neighborhood_size, args.grid_size, args.use_cuda)
                    # x_seq, first_values_dict = vectorize_seq(x_seq, PedsList_seq, lookup_seq)


                    if args.use_cuda:                    
                        x_seq = x_seq.cuda()

                    #number of peds in this sequence per frame
                    numNodes = len(lookup_seq)

                    hidden_states = Variable(torch.zeros(numNodes, args.rnn_size))
                    if args.use_cuda:                    
                        hidden_states = hidden_states.cuda()
                    cell_states = Variable(torch.zeros(numNodes, args.rnn_size))
                    if args.use_cuda:                    
                        cell_states = cell_states.cuda()

                    # Forward prop
                    outputs, _, _ = net(x_seq[:-1], grid_seq[:-1], hidden_states, cell_states, PedsList_seq[:-1], numPedsList_seq , dataloader, lookup_seq)

                    # Compute loss
                    loss = Gaussian2DLikelihood(outputs, x_seq[1:], PedsList_seq[1:], lookup_seq)
                    # Extract the mean, std and corr of the bivariate Gaussian
                    mux, muy, sx, sy, corr = getCoef(outputs)
                    # Sample from the bivariate Gaussian
                    next_x, next_y = sample_gaussian_2d(mux.data, muy.data, sx.data, sy.data, corr.data, PedsList_seq[-1], lookup_seq)
                    next_vals = torch.FloatTensor(1,numNodes,2)
                    next_vals[:,:,0] = next_x
                    next_vals[:,:,1] = next_y
                    err = get_mean_error(next_vals, x_seq[-1].data[None, : ,:], [PedsList_seq[-1]], [PedsList_seq[-1]], args.use_cuda, lookup_seq)

                    loss_batch += loss.item()
                    err_batch += err

                loss_batch = loss_batch / dataloader.batch_size
                err_batch = err_batch / dataloader.batch_size
                loss_epoch += loss_batch
                err_epoch += err_batch

            if dataloader.valid_num_batches != 0:            
                loss_epoch = loss_epoch / dataloader.valid_num_batches
                err_epoch = err_epoch / dataloader.num_batches


                # Update best validation loss until now
                if loss_epoch < best_val_loss:
                    best_val_loss = loss_epoch
                    best_epoch_val = epoch

                if err_epoch<smallest_err_val:
                    smallest_err_val = err_epoch
                    best_err_epoch_val = epoch

                print('(epoch {}), valid_loss = {:.3f}, valid_err = {:.3f}'.format(epoch, loss_epoch, err_epoch))
                print('Best epoch', best_epoch_val, 'Best validation loss', best_val_loss, 'Best error epoch',best_err_epoch_val, 'Best error', smallest_err_val)
                log_file_curve.write("Validation epoch: "+str(epoch)+" loss: "+str(loss_epoch)+" err: "+str(err_epoch)+'\n')


        # Validation dataset
        if dataloader.additional_validation and (epoch) in validation_epoch_list:
            dataloader.switch_to_dataset_type()
            print('****************Validation with dataset epoch beginning******************')
            dataloader.reset_batch_pointer(valid=False)
            dataset_pointer_ins = dataloader.dataset_pointer
            validation_dataset_executed = True

            loss_epoch = 0
            err_epoch = 0
            f_err_epoch = 0
            num_of_batch = 0
            smallest_err = 100000

            #results of one epoch for all validation datasets
            epoch_result = []
            #results of one validation dataset
            results = []



            # For each batch
            for batch in range(dataloader.num_batches):
                # Get batch data
                x, y, d , numPedsList, PedsList ,target_ids = dataloader.next_batch()
                # TODO: changed here

                if dataset_pointer_ins != dataloader.dataset_pointer:
                    # TODO: changed here
                    if dataloader.dataset_pointer != 0:
                        print('Finished prosessed file : ', dataloader.get_file_name(-1),' Avarage error : ', err_epoch/num_of_batch)
                        num_of_batch = 0
                        epoch_result.append(results)

                    dataset_pointer_ins = dataloader.dataset_pointer
                    results = []



                # Loss for this batch
                loss_batch = 0
                err_batch = 0
                f_err_batch = 0

                # For each sequence
                for sequence in range(dataloader.batch_size):
                    # Get data corresponding to the current sequence
                    x_seq ,_ , d_seq, numPedsList_seq, PedsList_seq = x[sequence], y[sequence], d[sequence], numPedsList[sequence], PedsList[sequence]
                    target_id = target_ids[sequence]

                    #get processing file name and then get dimensions of file
                    folder_name = dataloader.get_directory_name_with_pointer(d_seq)
                    dataset_data = dataloader.get_dataset_dimension(folder_name)
                    
                    #dense vector creation
                    x_seq, lookup_seq = dataloader.convert_proper_array(x_seq, numPedsList_seq, PedsList_seq)
                    
                    #will be used for error calculation
                    orig_x_seq = x_seq.clone() 
                    
                    target_id_values = orig_x_seq[0][lookup_seq[target_id], 0:2]
                    
                    #grid mask calculation
                    grid_seq = getSequenceGridMask(x_seq, dataset_data, PedsList_seq, args.neighborhood_size, args.grid_size, args.use_cuda)
                    
                    if args.use_cuda:
                        x_seq = x_seq.cuda()
                        orig_x_seq = orig_x_seq.cuda()

                    #vectorize datapoints
                    x_seq, first_values_dict = vectorize_seq(x_seq, PedsList_seq, lookup_seq)
                    ret_x_seq, loss = sample_validation_data(x_seq, PedsList_seq, grid_seq, args, net, lookup_seq, numPedsList_seq, dataloader)

                    #revert the points back to original space
                    ret_x_seq = revert_seq(ret_x_seq, PedsList_seq, lookup_seq, first_values_dict)

                    # <---------------------- Experimental block revert----------------------->
                    # Revert the calculated coordinates back to original space:
                    # 1) Convert point from vectors to absolute coordinates
                    # 2) Rotate all trajectories in reverse angle
                    # 3) Translate all trajectories back to original space by adding the first frame value of target ped trajectory
                    
                    # *It works without problems which mean that it reverts a trajectory back completely
                    
                    # Possible problems:
                    # *Algoritmical errors caused by first experimental block -> High possiblity
                    # <------------------------------------------------------------------------>

                    # ret_x_seq = revert_seq(ret_x_seq, PedsList_seq, lookup_seq, first_values_dict)

                    # ret_x_seq = rotate_traj_with_target_ped(ret_x_seq, -angle, PedsList_seq, lookup_seq)

                    # ret_x_seq = translate(ret_x_seq, PedsList_seq, lookup_seq ,-target_id_values)

                    #get mean and final error
                    err = get_mean_error(ret_x_seq.data, orig_x_seq.data, PedsList_seq, PedsList_seq, args.use_cuda, lookup_seq)
                    f_err = get_final_error(ret_x_seq.data, orig_x_seq.data, PedsList_seq, PedsList_seq, lookup_seq)
                    
                    loss_batch += loss.item()
                    err_batch += err
                    f_err_batch += f_err
                    print('Current file : ', dataloader.get_file_name(0),' Batch : ', batch+1, ' Sequence: ', sequence+1, ' Sequence mean error: ', err,' Sequence final error: ',f_err,' time: ', end - start)
                    results.append((orig_x_seq.data.cpu().numpy(), ret_x_seq.data.cpu().numpy(), PedsList_seq, lookup_seq, dataloader.get_frame_sequence(args.seq_length), target_id))

                loss_batch = loss_batch / dataloader.batch_size
                err_batch = err_batch / dataloader.batch_size
                f_err_batch = f_err_batch / dataloader.batch_size
                num_of_batch += 1
                loss_epoch += loss_batch
                err_epoch += err_batch
                f_err_epoch += f_err_batch

            epoch_result.append(results)
            all_epoch_results.append(epoch_result)


            if dataloader.num_batches != 0:            
                loss_epoch = loss_epoch / dataloader.num_batches
                err_epoch = err_epoch / dataloader.num_batches
                f_err_epoch = f_err_epoch / dataloader.num_batches
                avarage_err = (err_epoch + f_err_epoch)/2

                # Update best validation loss until now
                if loss_epoch < best_val_data_loss:
                    best_val_data_loss = loss_epoch
                    best_epoch_val_data = epoch

                if avarage_err<smallest_err_val_data:
                    smallest_err_val_data = avarage_err
                    best_err_epoch_val_data = epoch

                print('(epoch {}), valid_loss = {:.3f}, valid_mean_err = {:.3f}, valid_final_err = {:.3f}'.format(epoch, loss_epoch, err_epoch, f_err_epoch))
                print('Best epoch', best_epoch_val_data, 'Best validation loss', best_val_data_loss, 'Best error epoch',best_err_epoch_val_data, 'Best error', smallest_err_val_data)
                log_file_curve.write("Validation dataset epoch: "+str(epoch)+" loss: "+str(loss_epoch)+" mean_err: "+str(err_epoch)+'final_err: '+str(f_err_epoch)+'\n')

            optimizer = time_lr_scheduler(optimizer, epoch, lr_decay_epoch = args.freq_optimizer)


        # Save the model after each epoch
        print('Saving model')
        torch.save({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, checkpoint_path(epoch))




    if dataloader.valid_num_batches != 0:        
        print('Best epoch', best_epoch_val, 'Best validation Loss', best_val_loss, 'Best error epoch',best_err_epoch_val, 'Best error', smallest_err_val)
        # Log the best epoch and best validation loss
        log_file.write('Validation Best epoch:'+str(best_epoch_val)+','+' Best validation Loss: '+str(best_val_loss))

    if dataloader.additional_validation:
        print('Best epoch acording to validation dataset', best_epoch_val_data, 'Best validation Loss', best_val_data_loss, 'Best error epoch',best_err_epoch_val_data, 'Best error', smallest_err_val_data)
        log_file.write("Validation dataset Best epoch: "+str(best_epoch_val_data)+','+' Best validation Loss: '+str(best_val_data_loss)+'\n')
        #dataloader.write_to_plot_file(all_epoch_results[best_epoch_val_data], plot_directory)

    #elif dataloader.valid_num_batches != 0:
    #    dataloader.write_to_plot_file(all_epoch_results[best_epoch_val], plot_directory)

    #else:
    if validation_dataset_executed:
        dataloader.switch_to_dataset_type(load_data=False)
        create_directories(plot_directory, [plot_train_file_directory])
        dataloader.write_to_plot_file(all_epoch_results[len(all_epoch_results)-1], os.path.join(plot_directory, plot_train_file_directory))

    # Close logging files
    log_file.close()
    log_file_curve.close()



if __name__ == '__main__':
    main()
    

The DataLoader class

The three key functions are:
frame_preprocess(self, data_dirs, data_file, validation_set=False): walks the given directories, processes every dataset file into all_frame_data, frameList_data, numPeds_data, valid_numPeds_data, valid_frame_data, pedsList_data, valid_pedsList_data, target_ids, orig_data, and pickles the result to a file.
load_preprocessed(self, data_file, validation_set=False): loads the pickled data back into the DataLoader object.
next_batch(self): returns one batch from the DataLoader: x (frames 0-19), y (frames 1-20), d (dataset index), numPedsList (number of pedestrians in each frame of the sequence), PedsList (the pedestrian ids present in each of the 20 frames), and target_ids; see the sketch after this list.
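To make the slicing concrete, here is an illustrative stand-in for `next_batch` restricted to a single dataset (the function name and generator form are mine, not the repo's; the real method also advances dataset_pointer across files and tracks target_ids): the frame list is cut into non-overlapping windows of seq_length frames, and y is simply x shifted forward by one frame.

```python
import numpy as np

def cut_into_sequences(frame_data, seq_length=20, batch_size=5):
    """Illustrative stand-in for DataLoader.next_batch for a single dataset.

    frame_data is a list of per-frame arrays of shape (numPeds_in_frame, 3)
    holding (pedID, x, y). Yields (x, y) batches where every y sequence is the
    corresponding x sequence shifted forward by one frame."""
    sequences = []
    for start in range(0, len(frame_data) - seq_length, seq_length):
        x_seq = frame_data[start:start + seq_length]           # frames 0..19 of this sequence
        y_seq = frame_data[start + 1:start + seq_length + 1]   # frames 1..20, the next-step targets
        sequences.append((x_seq, y_seq))
    for i in range(0, len(sequences) - batch_size + 1, batch_size):
        batch = sequences[i:i + batch_size]
        yield [b[0] for b in batch], [b[1] for b in batch]

# e.g. 200 frames, each with a random number of pedestrians
frames = [np.random.rand(np.random.randint(1, 5), 3) for _ in range(200)]
x, y = next(cut_into_sequences(frames))
print(len(x), len(x[0]))   # 5 sequences per batch, 20 frames per sequence
```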

class DataLoader():

    def __init__(self,f_prefix, batch_size=5, seq_length=20, num_of_validation = 0, forcePreProcess=False, infer=False, generate = False):
        '''
        Initialiser function for the DataLoader class
        params:
        batch_size : Size of the mini-batch
        seq_length : Sequence length to be considered
        num_of_validation : number of validation dataset will be used
        infer : flag for test mode
        generate : flag for data generation mode
        forcePreProcess : Flag to forcefully preprocess the data again from csv files
        '''
        # base test files
        base_test_dataset=  ['/data/test/biwi/biwi_eth.txt', 
                        '/data/test/crowds/crowds_zara01.txt',
                        '/data/test/crowds/uni_examples.txt', 
                        '/data/test/stanford/coupa_0.txt',
                         '/data/test/stanford/coupa_1.txt', '/data/test/stanford/gates_2.txt','/data/test/stanford/hyang_0.txt','/data/test/stanford/hyang_1.txt','/data/test/stanford/hyang_3.txt','/data/test/stanford/hyang_8.txt',
                          '/data/test/stanford/little_0.txt','/data/test/stanford/little_1.txt','/data/test/stanford/little_2.txt','/data/test/stanford/little_3.txt','/data/test/stanford/nexus_5.txt','/data/test/stanford/nexus_6.txt',
                          '/data/test/stanford/quad_0.txt','/data/test/stanford/quad_1.txt','/data/test/stanford/quad_2.txt','/data/test/stanford/quad_3.txt'
                          ]
        #base train files
        base_train_dataset = ['/data/train/biwi/biwi_hotel.txt', 
                        #'/data/train/crowds/arxiepiskopi1.txt','/data/train/crowds/crowds_zara02.txt',
                        #'/data/train/crowds/crowds_zara03.txt','/data/train/crowds/students001.txt','/data/train/crowds/students003.txt',
                        #'/data/train/mot/PETS09-S2L1.txt',
                        #'/data/train/stanford/bookstore_0.txt','/data/train/stanford/bookstore_1.txt','/data/train/stanford/bookstore_2.txt','/data/train/stanford/bookstore_3.txt','/data/train/stanford/coupa_3.txt','/data/train/stanford/deathCircle_0.txt','/data/train/stanford/deathCircle_1.txt','/data/train/stanford/deathCircle_2.txt','/data/train/stanford/deathCircle_3.txt',
                        #'/data/train/stanford/deathCircle_4.txt','/data/train/stanford/gates_0.txt','/data/train/stanford/gates_1.txt','/data/train/stanford/gates_3.txt','/data/train/stanford/gates_4.txt','/data/train/stanford/gates_5.txt','/data/train/stanford/gates_6.txt','/data/train/stanford/gates_7.txt','/data/train/stanford/gates_8.txt','/data/train/stanford/hyang_4.txt',
                        #'/data/train/stanford/hyang_5.txt','/data/train/stanford/hyang_6.txt','/data/train/stanford/hyang_9.txt','/data/train/stanford/nexus_0.txt','/data/train/stanford/nexus_1.txt','/data/train/stanford/nexus_2.txt','/data/train/stanford/nexus_3.txt','/data/train/stanford/nexus_4.txt','/data/train/stanford/nexus_7.txt','/data/train/stanford/nexus_8.txt','/data/train/stanford/nexus_9.txt'
                        ]
        # dimensions of each file set
        self.dataset_dimensions = {'biwi':[720, 576], 'crowds':[720, 576], 'stanford':[595, 326], 'mot':[768, 576]}
        
        # List of data directories where raw data resides
        self.base_train_path = 'data/train/'
        self.base_test_path = 'data/test/'
        self.base_validation_path = 'data/validation/'

        # check infer flag, if true choose test directory as base directory
        if infer is False:
            self.base_data_dirs = base_train_dataset
        else:
            self.base_data_dirs = base_test_dataset

        # get all files using python os and base directories
        self.train_dataset = self.get_dataset_path(self.base_train_path, f_prefix)
        self.test_dataset = self.get_dataset_path(self.base_test_path, f_prefix)
        self.validation_dataset = self.get_dataset_path(self.base_validation_path, f_prefix)


        # if generate mode, use directly train base files
        if generate:
            self.train_dataset = [os.path.join(f_prefix, dataset[1:]) for dataset in base_train_dataset]

        #request of use of validation dataset
        if num_of_validation>0:
            self.additional_validation = True
        else:
            self.additional_validation = False

        # check validation dataset availibility and clip the reuqested number if it is bigger than available validation dataset
        if self.additional_validation:
            if len(self.validation_dataset) == 0:
                print("There is no validation dataset.Aborted.")
                self.additional_validation = False
            else:
                num_of_validation = np.clip(num_of_validation, 0, len(self.validation_dataset))
                self.validation_dataset = random.sample(self.validation_dataset, num_of_validation)

        # if not infer mode, use train dataset
        if infer is False:
            self.data_dirs = self.train_dataset
        else:
            # use validation dataset
            if self.additional_validation:
                self.data_dirs = self.validation_dataset
            # use test dataset
            else:
                self.data_dirs = self.test_dataset


        self.infer = infer
        self.generate = generate

        # Number of datasets
        self.numDatasets = len(self.data_dirs)

        # array for keepinng target ped ids for each sequence
        self.target_ids = []

        # Data directory where the pre-processed pickle file resides
        self.train_data_dir = os.path.join(f_prefix, self.base_train_path)
        self.test_data_dir = os.path.join(f_prefix, self.base_test_path)
        self.val_data_dir = os.path.join(f_prefix, self.base_validation_path)


        # Store the arguments
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.orig_seq_lenght = seq_length

        # Validation arguments
        self.val_fraction = 0

        # Define the path in which the process data would be stored
        self.data_file_tr = os.path.join(self.train_data_dir, "trajectories_train.cpkl")        
        self.data_file_te = os.path.join(self.base_test_path, "trajectories_test.cpkl")
        self.data_file_vl = os.path.join(self.val_data_dir, "trajectories_val.cpkl")


        # for creating a dict key: folder names, values: files in this folder
        self.create_folder_file_dict()

        if self.additional_validation:
        # If the file doesn't exist or forcePreProcess is true
            if not(os.path.exists(self.data_file_vl)) or forcePreProcess:
                print("Creating pre-processed validation data from raw data")
                # Preprocess the data from the csv files of the datasets
                # Note that this data is processed in frames
                self.frame_preprocess(self.validation_dataset, self.data_file_vl, self.additional_validation)

        if self.infer:  # infer (test) mode
        # if infer mode, and no additional files -> test preprocessing
            if not self.additional_validation:
                if not(os.path.exists(self.data_file_te)) or forcePreProcess:
                    print("Creating pre-processed test data from raw data")
                    # Preprocess the data from the csv files of the datasets
                    # Note that this data is processed in frames
                    print("Working on directory: ", self.data_file_te)
                    self.frame_preprocess(self.data_dirs, self.data_file_te)
            # if infer mode, and there are additional validation files -> validation dataset visualization
            else:
                print("Validation visualization file will be created")
        
        # if not infer mode -> training mode
        else:
            # If the file doesn't exist or forcePreProcess is true -> training pre-process
            if not(os.path.exists(self.data_file_tr)) or forcePreProcess:
                print("Creating pre-processed training data from raw data")
                # Preprocess the data from the csv files of the datasets
                # Note that this data is processed in frames
                self.frame_preprocess(self.data_dirs, self.data_file_tr)

        if self.infer:
            # Load the processed data from the pickle file
            if not self.additional_validation: #test mode
                self.load_preprocessed(self.data_file_te)
            else:  # validation mode
                self.load_preprocessed(self.data_file_vl, True)

        else: # training mode
            self.load_preprocessed(self.data_file_tr)

        # Reset all the data pointers of the dataloader object
        self.reset_batch_pointer(valid=False)
        self.reset_batch_pointer(valid=True)

    def frame_preprocess(self, data_dirs, data_file, validation_set = False):
        '''
        Function that will pre-process the pixel_pos.csv files of each dataset
        into data with occupancy grid that can be used
        params:
        data_dirs : List of directories where raw data resides
        data_file : The file into which all the pre-processed data needs to be stored
        validation_set: true when a dataset is in validation set
        '''
        # all_frame_data would be a list of list of numpy arrays corresponding to each dataset
        # Each numpy array will correspond to a frame and would be of size (numPeds, 3) each row
        # containing pedID, x, y, i.e. [frame, numPeds, 3]
        all_frame_data = []
        # Validation frame data
        valid_frame_data = []
        # frameList_data would be a list of lists corresponding to each dataset
        # Each list would contain the frameIds of all the frames in the dataset
        frameList_data = []
        valid_numPeds_data= []
        # numPeds_data would be a list of lists corresponding to each dataset
        # Ech list would contain the number of pedestrians in each frame in the dataset
        numPeds_data = []
        

        #each list includes ped ids of this frame
        pedsList_data = []
        valid_pedsList_data = []
        # target ped ids for each sequence
        target_ids = []
        orig_data = []



        # Index of the current dataset
        dataset_index = 0

        # For each dataset
        for directory in data_dirs:

            # Load the data from the txt file
            print("Now processing: ", directory)
            column_names = ['frame_num','ped_id','y','x']

            # if training mode, read train file to pandas dataframe and process
            if self.infer is False:
                df = pd.read_csv(directory, dtype={'frame_num':'int','ped_id':'int' }, delimiter = ' ',  header=None, names=column_names)
                self.target_ids = np.array(df.drop_duplicates(subset=['ped_id'], keep='first', inplace=False)['ped_id'])


            else:
                # if validation mode, read validation file to pandas dataframe and process
                if self.additional_validation:
                    df = pd.read_csv(directory, dtype={'frame_num':'int','ped_id':'int' }, delimiter = ' ',  header=None, names=column_names)
                    self.target_ids = np.array(df.drop_duplicates(subset={'ped_id'}, keep='first', inplace=False)['ped_id'])

                # if test mode, read test file to pandas dataframe and process
                else:
                    column_names = ['frame_num','ped_id','y','x']
                    df = pd.read_csv(directory, dtype={'frame_num':'int','ped_id':'int' }, delimiter = ' ',  header=None, names=column_names, converters = {c:lambda x: float('nan') if x == '?' else float(x) for c in ['y','x']})
                    self.target_ids = np.array(df[df['y'].isnull()].drop_duplicates(subset='ped_id', keep='first', inplace=False)['ped_id'])

            # convert pandas -> numpy array
            data = np.array(df)

            # keep original copy of file
            orig_data.append(data)

            #swap x and y points (in txt file it is like -> y,x)
            data = np.swapaxes(data,0,1)  #(2900,4)->(4,2900)
            
            # get frame numbers
            frameList = data[0, :].tolist()


            # Number of frames
            numFrames = len(frameList)

            # Add the list of frameIDs to the frameList_data
            frameList_data.append(frameList)
            # Initialize the list of numPeds for the current dataset
            numPeds_data.append([])
            valid_numPeds_data.append([])

            # Initialize the list of numpy arrays for the current dataset
            all_frame_data.append([])
            # Initialize the list of numpy arrays for the current dataset
            valid_frame_data.append([])

            # list of peds for each frame
            pedsList_data.append([])
            valid_pedsList_data.append([])

            target_ids.append(self.target_ids)



            for ind, frame in enumerate(frameList):


                # Extract all pedestrians in current frame
                pedsInFrame = data[: , data[0, :] == frame]
                #print("peds in %d: %s"%(frame,str(pedsInFrame)))



                # Extract peds list
                pedsList = pedsInFrame[1, :].tolist()

                # Add number of peds in the current frame to the stored data


                # Initialize the row of the numpy array
                pedsWithPos = []

                # For each ped in the current frame
                for ped in pedsList:
                    # Extract their x and y positions
                    current_x = pedsInFrame[3, pedsInFrame[1, :] == ped][0]
                    current_y = pedsInFrame[2, pedsInFrame[1, :] == ped][0]

                    # Add their pedID, x, y to the row of the numpy array
                    pedsWithPos.append([ped, current_x, current_y])

                # At inference time, data generation and if dataset is a validation dataset, no validation data
                if (ind >= numFrames * self.val_fraction) or (self.infer) or (self.generate) or (validation_set):
                    # Add the details of all the peds in the current frame to all_frame_data
                    all_frame_data[dataset_index].append(np.array(pedsWithPos))
                    pedsList_data[dataset_index].append(pedsList)
                    numPeds_data[dataset_index].append(len(pedsList))


                else:
                    valid_frame_data[dataset_index].append(np.array(pedsWithPos))
                    valid_pedsList_data[dataset_index].append(pedsList)
                    valid_numPeds_data[dataset_index].append(len(pedsList))



            dataset_index += 1
        # Save the arrays in the pickle file
        f = open(data_file, "wb")
        pickle.dump((all_frame_data, frameList_data, numPeds_data, valid_numPeds_data, valid_frame_data, pedsList_data, valid_pedsList_data, target_ids, orig_data), f, protocol=2)
        f.close()


    def load_preprocessed(self, data_file, validation_set = False):
        '''
        Function to load the pre-processed data into the DataLoader object
        params:
        data_file : the path to the pickled data file
        validation_set : flag for validation dataset
        '''
        # Load data from the pickled file
        if(validation_set):
            print("Loading validaton datasets: ", data_file)
        else:
            print("Loading train or test dataset: ", data_file)

        f = open(data_file, 'rb')
        self.raw_data = pickle.load(f)
        f.close()

        # Get all the data from the pickle file
        self.data = self.raw_data[0]
        self.frameList = self.raw_data[1]
        self.numPedsList = self.raw_data[2]
        self.valid_numPedsList = self.raw_data[3]
        self.valid_data = self.raw_data[4]
        self.pedsList = self.raw_data[5]
        self.valid_pedsList = self.raw_data[6]
        self.target_ids = self.raw_data[7]
        self.orig_data = self.raw_data[8]



        counter = 0
        valid_counter = 0
        print('Sequence size(frame) ------>',self.seq_length)
        print('One batch size (frame)--->-', self.batch_size*self.seq_length)

        # For each dataset
        for dataset in range(len(self.data)):
            # get the frame data for the current dataset
            all_frame_data = self.data[dataset]
            valid_frame_data = self.valid_data[dataset]
            dataset_name = self.data_dirs[dataset].split('/')[-1]
            # calculate number of sequence 
            num_seq_in_dataset = int(len(all_frame_data) / (self.seq_length))
            num_valid_seq_in_dataset = int(len(valid_frame_data) / (self.seq_length))
            if not validation_set:
                print('Training data from training dataset(name, # frame, #sequence)--> ', dataset_name, ':', len(all_frame_data),':', (num_seq_in_dataset))
                print('Validation data from training dataset(name, # frame, #sequence)--> ', dataset_name, ':', len(valid_frame_data),':', (num_valid_seq_in_dataset))
            else: 
                print('Validation data from validation dataset(name, # frame, #sequence)--> ', dataset_name, ':', len(all_frame_data),':', (num_seq_in_dataset))

            # Increment the counter with the number of sequences in the current dataset
            counter += num_seq_in_dataset
            valid_counter += num_valid_seq_in_dataset

        # Calculate the number of batches
        self.num_batches = int(counter/self.batch_size)
        self.valid_num_batches = int(valid_counter/self.batch_size)


        if not validation_set:
            print('Total number of training batches:', self.num_batches)
            print('Total number of validation batches:', self.valid_num_batches)
        else:
            print('Total number of validation batches:', self.num_batches)

        # self.valid_num_batches = self.valid_num_batches * 2
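The listing above stops before convert_proper_array, which the training loop calls on every sequence. The following is an illustrative sketch of what it has to do, under the description given earlier (the exact repo implementation may differ): collect the unique pedestrian ids of the sequence, build the pedId-to-row-index lookup, and scatter the (pedID, x, y) rows into a dense (seq_length, numUniquePeds, 2) tensor, leaving zeros where a pedestrian is absent.

```python
import numpy as np
import torch

def convert_proper_array_sketch(x_seq, peds_list_seq):
    """Illustrative version of DataLoader.convert_proper_array.

    x_seq         : list of per-frame arrays with rows (pedID, x, y)
    peds_list_seq : list of ped ids present in each frame
    Returns a dense (seq_length, numUniquePeds, 2) tensor plus a lookup dict
    pedID -> row index; frames where a pedestrian is absent stay at zero."""
    unique_ids = sorted({ped for frame_peds in peds_list_seq for ped in frame_peds})
    lookup = {ped: i for i, ped in enumerate(unique_ids)}
    seq = torch.zeros(len(x_seq), len(unique_ids), 2)
    for t, frame in enumerate(x_seq):
        for ped_id, x, y in frame:                  # each row is (pedID, x, y)
            seq[t, lookup[ped_id], 0] = float(x)
            seq[t, lookup[ped_id], 1] = float(y)
    return seq, lookup

# two frames: ped 3 appears in both, ped 8 only in the second
f0 = np.array([[3.0, 1.0, 2.0]])
f1 = np.array([[3.0, 1.2, 2.1], [8.0, 5.0, 5.0]])
seq, lookup = convert_proper_array_sketch([f0, f1], [[3.0], [3.0, 8.0]])
print(seq.shape, lookup)   # torch.Size([2, 2, 2]) {3.0: 0, 8.0: 1}
```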