C3D Source Code Walkthrough

Paper: http://vlg.cs.dartmouth.edu/c3d/c3d_video.pdf

Code: https://github.com/jfzhang95/pytorch-video-recognition

1. Getting the Source Code

git clone --recursive https://github.com/jfzhang95/pytorch-video-recognition.git

When the clone finishes you have the C3D source tree.

2. Source Layout

| File | Purpose |
| --- | --- |
| train.py | training entry script |
| mypath.py | configures the dataset and pretrained-model paths |
| dataset.py | data loading and preprocessing |
| C3D_model.py | builds the C3D network architecture |
| ucf101-caffe.pth | pretrained weights |

The sections below walk through the important files one by one, tracing the data flow and the call relationships between functions.

3. Source Analysis (Preparation)

3.1 Data Preparation

dataset.py reads the raw dataset, preprocesses each video into a folder of frame images, and writes a text file that maps every action class to a numeric label.

It defines a class VideoDataset that handles the raw data. The class inherits from torch.utils.data.Dataset (custom dataset classes in PyTorch generally subclass torch.utils.data.Dataset) and overrides __init__ and __getitem__ to load samples.
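
Before the repo-specific details, here is a minimal, self-contained sketch of that pattern (illustrative only; the class and names below are not from the repo):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyClipDataset(Dataset):
    """Minimal Dataset skeleton following the same pattern as VideoDataset."""
    def __init__(self, num_samples=8):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        clip = torch.randn(3, 16, 112, 112)  # [C, frames, H, W], the shape C3D expects
        label = torch.tensor(index % 2)      # dummy class index
        return clip, label

loader = DataLoader(ToyClipDataset(), batch_size=4, shuffle=True)
```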

(1) The __init__ function

__init__ does three things: 1. initializes VideoDataset and sets parameters and their defaults; 2. generates the frame dataset for every video; 3. writes the txt file of action labels. (The method is somewhat tangled; it could usefully be split into helpers.) It also relies on several helper methods, covered step by step below.

Part 1: initialize VideoDataset and set parameters and their defaults.

    def __init__(self, dataset='ucf101', split='train', clip_len=16, preprocess=False):
        self.root_dir, self.output_dir = Path.db_dir(dataset)  # source path of the raw dataset and output path for the frames
        folder = os.path.join(self.output_dir, split)  # folder of this split
        self.clip_len = clip_len  # number of frames per clip (16)
        self.split = split  # one of train / val / test

        # The following three parameters are chosen as described in the paper section 4.1
        # frame size pipeline: h*w --> 128*171 --> 112*112
        self.resize_height = 128
        self.resize_width = 171
        self.crop_size = 112

Part 2: generate the frame dataset for every video.

        # check_integrity() verifies that the raw dataset path exists; if not, raise
        if not self.check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You need to download it from official website.')
        # check_preprocess() checks whether the output path exists; if it does not (or preprocess=True), preprocess() creates it and fills it with the extracted frame images
        if (not self.check_preprocess()) or preprocess:
            print('Preprocessing of {} dataset, this will take long, but it will be done only once.'.format(dataset))
            self.preprocess() 

Part 3: write the txt file of action labels.

        # Obtain all the filenames of files inside all the class folders,
        # going through each class folder one at a time
        # fnames --> paths of every video (frame folder) across all classes; labels --> the matching class names
        self.fnames, labels = [], []
        for label in sorted(os.listdir(folder)):
            for fname in os.listdir(os.path.join(folder, label)):
                self.fnames.append(os.path.join(folder, label, fname))
                labels.append(label)

        assert len(labels) == len(self.fnames)
        print('Number of {} videos: {:d}'.format(split, len(self.fnames)))

        # Prepare a mapping between the label names (strings) and indices (ints)
        self.label2index = {label: index for index, label in enumerate(sorted(set(labels)))}
        # Convert the list of label names into an array of label indices
        self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)
        # write the '<index> <action>' txt file
        if dataset == "ucf101":
            if not os.path.exists('dataloaders/ucf_labels.txt'):
                with open('dataloaders/ucf_labels.txt', 'w') as f:
                    for id, label in enumerate(sorted(self.label2index)):
                        f.writelines(str(id+1) + ' ' + label + '\n')

        elif dataset == 'hmdb51':
            if not os.path.exists('dataloaders/hmdb_labels.txt'):
                with open('dataloaders/hmdb_labels.txt', 'w') as f:
                    for id, label in enumerate(sorted(self.label2index)):
                        f.writelines(str(id+1) + ' ' + label + '\n')
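
Since the class names are iterated in sorted order, the generated dataloaders/ucf_labels.txt should start like this for UCF101 (alphabetical class names, 1-based indices):

```
1 ApplyEyeMakeup
2 ApplyLipstick
3 Archery
...
```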

Next, the other important methods of VideoDataset:

(2) The __len__ function:
    # returns the total number of videos
    def __len__(self):
        return len(self.fnames)
(3) The __getitem__ function:
    def __getitem__(self, index):
        # Loading and preprocessing.
        buffer = self.load_frames(self.fnames[index])  # load one video's frames: [frames, h, w, 3] --> [frames, 128, 171, 3]
        buffer = self.crop(buffer, self.clip_len, self.crop_size)  # random 16-frame window and 112x112 patch --> [16, 112, 112, 3]
        labels = np.array(self.label_array[index])  # class index as a numpy array

        if self.split == 'test':
            # Perform data augmentation
            buffer = self.randomflip(buffer)  # random horizontal flip (note: the repo applies this on the test split)
        buffer = self.normalize(buffer)  # subtract the per-channel mean
        buffer = self.to_tensor(buffer)  # [3, 16, 112, 112]
        return torch.from_numpy(buffer), torch.from_numpy(labels)  # return as torch tensors
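
A quick sanity check of the returned shapes (a sketch; it assumes the preprocessed UCF101 frame dataset already exists):

```python
# Sketch: fetch one sample and inspect its shapes.
train_data = VideoDataset(dataset='ucf101', split='train', clip_len=16)
clip, label = train_data[0]
print(clip.shape)   # torch.Size([3, 16, 112, 112])
print(label.shape)  # torch.Size([]) -- a scalar class index
```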
(4) The check_integrity function:
    # returns False if the raw dataset path does not exist
    def check_integrity(self):
        if not os.path.exists(self.root_dir):
            return False
        else:
            return True
(5) The check_preprocess function:
    # returns False if the output path is missing or if sampled frame images are not the expected 128x171 size
    def check_preprocess(self):
        # TODO: Check image size in output_dir
        if not os.path.exists(self.output_dir):
            return False
        elif not os.path.exists(os.path.join(self.output_dir, 'train')):
            return False

        for ii, video_class in enumerate(os.listdir(os.path.join(self.output_dir, 'train'))):
            for video in os.listdir(os.path.join(self.output_dir, 'train', video_class)):
                video_name = os.path.join(os.path.join(self.output_dir, 'train', video_class, video),
                                    sorted(os.listdir(os.path.join(self.output_dir, 'train', video_class, video)))[0])
                image = cv2.imread(video_name)
                if np.shape(image)[0] != 128 or np.shape(image)[1] != 171:
                    return False
                else:
                    break

            if ii == 10:
                break

        return True
(6) The preprocess function:
    def preprocess(self):
        # create the output folder and one subfolder per split
        if not os.path.exists(self.output_dir):
            os.mkdir(self.output_dir)
            os.mkdir(os.path.join(self.output_dir, 'train'))
            os.mkdir(os.path.join(self.output_dir, 'val'))
            os.mkdir(os.path.join(self.output_dir, 'test'))

        # Split into train/val/test: 20% test, then 20% of the remainder as val, i.e. 0.64/0.16/0.20
        for file in os.listdir(self.root_dir):
            file_path = os.path.join(self.root_dir, file)
            video_files = [name for name in os.listdir(file_path)]

            train_and_valid, test = train_test_split(video_files, test_size=0.2, random_state=42)
            train, val = train_test_split(train_and_valid, test_size=0.2, random_state=42)

            train_dir = os.path.join(self.output_dir, 'train', file)
            val_dir = os.path.join(self.output_dir, 'val', file)
            test_dir = os.path.join(self.output_dir, 'test', file)

            if not os.path.exists(train_dir):
                train_dir = train_dir.replace("\\", "/")  # normalize Windows path separators; unnecessary on Linux
                print("train "+train_dir)
                os.mkdir(train_dir)

            if not os.path.exists(val_dir):
                val_dir = val_dir.replace("\\", "/")
                os.mkdir(val_dir)
            if not os.path.exists(test_dir):
                test_dir = test_dir.replace("\\", "/")
                os.mkdir(test_dir)

            for video in train:
                self.process_video(video, file, train_dir)  # extract and save resized frames for this video

            for video in val:
                self.process_video(video, file, val_dir)

            for video in test:
                self.process_video(video, file, test_dir)

        print('Preprocessing finished.')
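
After preprocess() finishes, output_dir should look roughly like this (a sketch; folder names depend on the dataset, shown here for UCF101):

```
data/output/ucf101/
├── train/
│   ├── ApplyEyeMakeup/
│   │   ├── v_ApplyEyeMakeup_g01_c01/
│   │   │   ├── 00000.jpg
│   │   │   ├── 00001.jpg
│   │   │   └── ...
│   │   └── ...
│   └── ...
├── val/
└── test/
```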
(7) The process_video function:
    def process_video(self, video, action_name, save_dir):
        # Initialize a VideoCapture object to read video data into a numpy array
        video_filename = video.split('.')[0]  # video name without the extension
        if not os.path.exists(os.path.join(save_dir, video_filename)):
            os.mkdir(os.path.join(save_dir, video_filename))  # create a folder for this video's frames
        # open the video
        capture = cv2.VideoCapture(os.path.join(self.root_dir, action_name, video))
        # read the frame count, width and height
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Make sure the sampled clip keeps more than 16 frames
        EXTRACT_FREQUENCY = 4
        if frame_count // EXTRACT_FREQUENCY <= 16:
            EXTRACT_FREQUENCY -= 1
            if frame_count // EXTRACT_FREQUENCY <= 16:
                EXTRACT_FREQUENCY -= 1
                if frame_count // EXTRACT_FREQUENCY <= 16:
                    EXTRACT_FREQUENCY -= 1

        count = 0
        i = 0
        retaining = True
        # resize each sampled frame to 128x171 and save it as a numbered .jpg
        while (count < frame_count and retaining):
            retaining, frame = capture.read()
            if frame is None:
                continue

            if count % EXTRACT_FREQUENCY == 0:
                if (frame_height != self.resize_height) or (frame_width != self.resize_width):
                    frame = cv2.resize(frame, (self.resize_width, self.resize_height))
                cv2.imwrite(filename=os.path.join(save_dir, video_filename, '{:05d}.jpg'.format(i)), img=frame)  # zero-padded so sorted() in load_frames keeps frame order (the repo's '0000{}' misorders frame 10 and above)
                i += 1
            count += 1

        # Release the VideoCapture once it is no longer needed
        capture.release()
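
The nested ifs above simply lower the sampling stride until more than 16 frames survive; the same logic written as a loop (an equivalent sketch, not the repo's code):

```python
# Equivalent to the nested EXTRACT_FREQUENCY ifs above.
frame_count = 50                      # example value
extract_frequency = 4
while extract_frequency > 1 and frame_count // extract_frequency <= 16:
    extract_frequency -= 1
print(extract_frequency)              # 2: 50//4=12 <= 16, 50//3=16 <= 16, 50//2=25 > 16
```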
(8) The randomflip function:
def randomflip(self, buffer):
    """Horizontally flip the given image and ground truth randomly with a probability of 0.5."""
    # flip the whole clip horizontally with probability 0.5 (data augmentation)
    if np.random.random() < 0.5:
        for i, frame in enumerate(buffer):
            # note: the repo flips each frame twice here, which cancels out;
            # a single flip is what the docstring intends
            buffer[i] = cv2.flip(frame, flipCode=1)

    return buffer
(9) The normalize function:
def normalize(self, buffer):
    # subtract the per-channel mean (frames were read with cv2, so the channel order is BGR)
    for i, frame in enumerate(buffer):
        frame -= np.array([[[90.0, 98.0, 102.0]]])
        buffer[i] = frame

    return buffer
(10) The to_tensor function:
# axis permutation [0,1,2,3] --> [3,0,1,2], i.e. [frames, H, W, C] --> [C, frames, H, W]
def to_tensor(self, buffer):
    return buffer.transpose((3, 0, 1, 2))
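
A one-line check of that axis permutation (sketch):

```python
import numpy as np

buf = np.zeros((16, 112, 112, 3), dtype=np.float32)  # [frames, H, W, C]
print(buf.transpose((3, 0, 1, 2)).shape)             # (3, 16, 112, 112)
```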
(11) The load_frames function:
# load all the saved frames of one video: [frames, h, w, 3] --> [frames, 128, 171, 3]
def load_frames(self, file_dir):
    frames = sorted([os.path.join(file_dir, img) for img in os.listdir(file_dir)])
    frame_count = len(frames)
    buffer = np.empty((frame_count, self.resize_height, self.resize_width, 3), np.dtype('float32'))
    for i, frame_name in enumerate(frames):
        frame = np.array(cv2.imread(frame_name)).astype(np.float64)
        buffer[i] = frame

    return buffer
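
__getitem__ also calls self.crop, which is not excerpted above. In the repo it picks a random clip_len-frame window and a random crop_size x crop_size patch; roughly like this (a sketch of that logic written as a free function; check the repo for the exact method):

```python
import numpy as np

def crop(buffer, clip_len, crop_size):
    # random temporal window and random spatial patch
    time_index = np.random.randint(buffer.shape[0] - clip_len)
    height_index = np.random.randint(buffer.shape[1] - crop_size)
    width_index = np.random.randint(buffer.shape[2] - crop_size)
    return buffer[time_index:time_index + clip_len,
                  height_index:height_index + crop_size,
                  width_index:width_index + crop_size, :]
```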

Finally, the path helper that __init__ relies on:

(12) mypath.py:
class Path(object):
    @staticmethod
    def db_dir(database):
        if database == 'ucf101':
            # folder that contains class labels
            root_dir = 'data/UCF-101'  # source path of the raw dataset

            # Save preprocess data into output_dir
            output_dir = 'data/output/ucf101'  # output path for the extracted frame dataset

            return root_dir, output_dir
        elif database == 'hmdb51':
            # folder that contains class labels
            root_dir = '/Path/to/hmdb-51'

            output_dir = '/path/to/VAR/hmdb51'

            return root_dir, output_dir
        else:
            print('Database {} not available.'.format(database))
            raise NotImplementedError

    @staticmethod
    def model_dir():
        return 'models/ucf101-caffe.pth'  # path to the pretrained weights

3.2 Model Design

C3D_model.py prepares the model for training. It uses C3D as the base network and replaces the last fully connected layer so that the output size matches our number of classes. The key functions:
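
Conceptually, "replacing the last fully connected layer" just means sizing the classifier head to the dataset (a sketch; in this repo the head is named fc8 and the fc7 feature width is 4096):

```python
import torch.nn as nn

num_classes = 101                    # e.g. UCF101
fc8 = nn.Linear(4096, num_classes)   # maps 4096-d fc7 features to class scores
```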

(1) The __load_pretrained_weights function:
    def __load_pretrained_weights(self):
        """Initialize the network from the pretrained checkpoint."""
        # corresp_name maps each checkpoint key to the corresponding layer parameters (weights and biases) of this model
        corresp_name = {
                        # Conv1
                        "features.0.weight": "conv1.weight",
                        "features.0.bias": "conv1.bias",
                        # Conv2
                        "features.3.weight": "conv2.weight",
                        "features.3.bias": "conv2.bias",
                        # Conv3a
                        "features.6.weight": "conv3a.weight",
                        "features.6.bias": "conv3a.bias",
                        # Conv3b
                        "features.8.weight": "conv3b.weight",
                        "features.8.bias": "conv3b.bias",
                        # Conv4a
                        "features.11.weight": "conv4a.weight",
                        "features.11.bias": "conv4a.bias",
                        # Conv4b
                        "features.13.weight": "conv4b.weight",
                        "features.13.bias": "conv4b.bias",
                        # Conv5a
                        "features.16.weight": "conv5a.weight",
                        "features.16.bias": "conv5a.bias",
                         # Conv5b
                        "features.18.weight": "conv5b.weight",
                        "features.18.bias": "conv5b.bias",
                        # fc6
                        "classifier.0.weight": "fc6.weight",
                        "classifier.0.bias": "fc6.bias",
                        # fc7
                        "classifier.3.weight": "fc7.weight",
                        "classifier.3.bias": "fc7.bias",
                        }
        # copy every matching pretrained tensor into this model's state dict, then load it
        p_dict = torch.load(Path.model_dir())
        s_dict = self.state_dict()
        for name in p_dict:
            if name not in corresp_name:
                continue
            s_dict[corresp_name[name]] = p_dict[name]
        self.load_state_dict(s_dict)
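
Why the mapping is needed: ucf101-caffe.pth appears to have been converted from Caffe and stores its tensors under nn.Sequential-style keys (features.N.*, classifier.N.*), while this model uses named layers (conv1, fc6, ...). A quick way to see that (sketch):

```python
import torch

# Sketch: peek at the key names stored in the pretrained checkpoint.
p_dict = torch.load('models/ucf101-caffe.pth', map_location='cpu')
print(list(p_dict.keys())[:4])
# e.g. ['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias']
```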
(2) The __init_weight function:
# Conv3d layers get Kaiming initialization; BatchNorm3d weights are filled with 1 and biases with 0.
def __init_weight(self):
    for m in self.modules():
        if isinstance(m, nn.Conv3d):
            # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            # m.weight.data.normal_(0, math.sqrt(2. / n))
            torch.nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm3d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
(3) The get_1x_lr_params function:
def get_1x_lr_params(model):
    """
    This generator returns the parameters of all conv layers and the first two fc layers of the net.
    """
    b = [model.conv1, model.conv2, model.conv3a, model.conv3b, model.conv4a, model.conv4b,
         model.conv5a, model.conv5b, model.fc6, model.fc7]
    for i in range(len(b)):
        for k in b[i].parameters():
            if k.requires_grad:
                yield k  # yield makes this function a generator

(4) The get_10x_lr_params function:

def get_10x_lr_params(model):
    """
    This generator returns the parameters of the last fc layer (fc8) of the net.
    """
    b = [model.fc8]
    for j in range(len(b)):
        for k in b[j].parameters():
            if k.requires_grad:
                yield k

4. Source Analysis (Training)

Now let's go through the important parts of train.py.

import timeit
from datetime import datetime
import socket
import os
import glob
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.autograd import Variable

from dataloaders.dataset import VideoDataset
from network import C3D_model, R2Plus1D_model, R3D_model

# Use GPU if available else revert to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device being used:", device)

nEpochs = 10  # Number of training epochs
resume_epoch = 0  # Default is 0; set it to a checkpoint's epoch to resume training
useTest = True  # Evaluate on the test set during training
nTestInterval = 20  # Run on the test set every nTestInterval epochs
snapshot = 50  # Save a checkpoint every `snapshot` epochs (note: with nEpochs = 10, neither testing nor saving ever triggers; lower these intervals or raise nEpochs)
lr = 1e-3  # Learning rate

dataset = 'ucf101' # Options: hmdb51 or ucf101

if dataset == 'hmdb51':
    num_classes = 51
elif dataset == 'ucf101':
    num_classes = 101  # UCF101 has 101 action classes
else:
    print('We only implemented hmdb and ucf datasets.')
    raise NotImplementedError
# os.path.abspath(__file__) is this file's absolute path; os.path.dirname strips the last component
save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
exp_name = os.path.dirname(os.path.abspath(__file__)).split('/')[-1]

# Pick the run/run_<id> folder used for checkpoints and TensorBoard logs: when resuming, reuse the latest run; otherwise start a new one (e.g. existing run_0, run_1 --> new run_2)
if resume_epoch != 0:
    runs = sorted(glob.glob(os.path.join(save_dir_root, 'run', 'run_*')))
    run_id = int(runs[-1].split('_')[-1]) if runs else 0
else:
    runs = sorted(glob.glob(os.path.join(save_dir_root, 'run', 'run_*')))
    run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0

save_dir = os.path.join(save_dir_root, 'run', 'run_' + str(run_id))
modelName = 'C3D'  # Options: C3D or R2Plus1D or R3D
saveName = modelName + '-' + dataset

# The core of the script:
def train_model(dataset=dataset, save_dir=save_dir, num_classes=num_classes, lr=lr,
                num_epochs=nEpochs, save_epoch=snapshot, useTest=useTest, test_interval=nTestInterval):
    """
        Args:
            num_classes (int): Number of classes in the data
            num_epochs (int, optional): Number of epochs to train for.
    """

    if modelName == 'C3D':
        # build C3D and load its pretrained weights
        model = C3D_model.C3D(num_classes=num_classes, pretrained=True)
        # two optimizer parameter groups: pretrained layers at lr, the newly initialized fc8 at 10 * lr
        train_params = [{'params': C3D_model.get_1x_lr_params(model), 'lr': lr},
                        {'params': C3D_model.get_10x_lr_params(model), 'lr': lr * 10}]
    elif modelName == 'R2Plus1D':
        model = R2Plus1D_model.R2Plus1DClassifier(num_classes=num_classes, layer_sizes=(2, 2, 2, 2))
        train_params = [{'params': R2Plus1D_model.get_1x_lr_params(model), 'lr': lr},
                        {'params': R2Plus1D_model.get_10x_lr_params(model), 'lr': lr * 10}]
    elif modelName == 'R3D':
        model = R3D_model.R3DClassifier(num_classes=num_classes, layer_sizes=(2, 2, 2, 2))
        train_params = model.parameters()
    else:
        print('We only implemented C3D, R2Plus1D and R3D models.')
        raise NotImplementedError
    criterion = nn.CrossEntropyLoss()  # standard crossentropy loss for classification
    # SGD over the parameter groups defined above
    optimizer = optim.SGD(train_params, lr=lr, momentum=0.9, weight_decay=5e-4)
    # learning-rate schedule
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10,
                                          gamma=0.1)  # the scheduler divides the lr by 10 every 10 epochs
    # optionally resume: restore model and optimizer state from the checkpoint of epoch resume_epoch - 1
    if resume_epoch == 0:
        print("Training {} from scratch...".format(modelName))
    else:
        checkpoint = torch.load(os.path.join(save_dir, 'models', saveName + '_epoch-' + str(resume_epoch - 1) + '.pth.tar'),
                       map_location=lambda storage, loc: storage)   # Load all tensors onto the CPU
        print("Initializing weights from: {}...".format(
            os.path.join(save_dir, 'models', saveName + '_epoch-' + str(resume_epoch - 1) + '.pth.tar')))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['opt_dict'])

    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))
    model.to(device)
    criterion.to(device)

    log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)

    print('Training model on {} dataset...'.format(dataset))
    # build the DataLoaders (batch_size=1 here; raise it if GPU memory allows)
    train_dataloader = DataLoader(VideoDataset(dataset=dataset, split='train',clip_len=16), batch_size=1, shuffle=True, num_workers=4)
    val_dataloader   = DataLoader(VideoDataset(dataset=dataset, split='val',  clip_len=16), batch_size=1, num_workers=4)
    test_dataloader  = DataLoader(VideoDataset(dataset=dataset, split='test', clip_len=16), batch_size=1, num_workers=4)

    trainval_loaders = {'train': train_dataloader, 'val': val_dataloader}
    trainval_sizes = {x: len(trainval_loaders[x].dataset) for x in ['train', 'val']}
    test_size = len(test_dataloader.dataset)

    for epoch in range(resume_epoch, num_epochs):
        # each epoch has a training and validation step
        for phase in ['train', 'val']:
            start_time = timeit.default_timer()

            # reset the running loss and corrects
            running_loss = 0.0
            running_corrects = 0.0

            # set model to train() or eval() mode depending on whether it is trained
            # or being validated. Primarily affects layers such as BatchNorm or Dropout.
            if phase == 'train':
                # scheduler.step() is to be called once every epoch during training
                model.train()
            else:
                model.eval()

            for inputs, labels in tqdm(trainval_loaders[phase]):
                # move inputs and labels to the device the training is taking place on
                # (Variable is a no-op since PyTorch 0.4; inputs.to(device) alone would do)
                inputs = Variable(inputs, requires_grad=True).to(device)
                labels = Variable(labels).to(device)
                optimizer.zero_grad()

                if phase == 'train':
                    outputs = model(inputs)
                else:
                    with torch.no_grad():
                        outputs = model(inputs)

                probs = nn.Softmax(dim=1)(outputs)
                preds = torch.max(probs, 1)[1]
                loss = criterion(outputs, labels.long())

                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                    scheduler.step()  # note: stepping per batch decays the LR every 10 iterations, not every 10 epochs as the comment above intends; move this outside the batch loop

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / trainval_sizes[phase]
            epoch_acc = running_corrects.double() / trainval_sizes[phase]

            if phase == 'train':
                writer.add_scalar('data/train_loss_epoch', epoch_loss, epoch)
                writer.add_scalar('data/train_acc_epoch', epoch_acc, epoch)
            else:
                writer.add_scalar('data/val_loss_epoch', epoch_loss, epoch)
                writer.add_scalar('data/val_acc_epoch', epoch_acc, epoch)

            print("[{}] Epoch: {}/{} Loss: {} Acc: {}".format(phase, epoch+1, nEpochs, epoch_loss, epoch_acc))
            stop_time = timeit.default_timer()
            print("Execution time: " + str(stop_time - start_time) + "\n")
        # save a checkpoint every save_epoch (= snapshot) epochs
        if epoch % save_epoch == (save_epoch - 1):
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'opt_dict': optimizer.state_dict(),
            }, os.path.join(save_dir, 'models', saveName + '_epoch-' + str(epoch) + '.pth.tar'))
            print("Save model at {}\n".format(os.path.join(save_dir, 'models', saveName + '_epoch-' + str(epoch) + '.pth.tar')))

        if useTest and epoch % test_interval == (test_interval - 1):
            model.eval()
            start_time = timeit.default_timer()

            running_loss = 0.0
            running_corrects = 0.0

            for inputs, labels in tqdm(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                with torch.no_grad():
                    outputs = model(inputs)
                probs = nn.Softmax(dim=1)(outputs)
                preds = torch.max(probs, 1)[1]
                loss = criterion(outputs, labels)

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / test_size
            epoch_acc = running_corrects.double() / test_size

            writer.add_scalar('data/test_loss_epoch', epoch_loss, epoch)
            writer.add_scalar('data/test_acc_epoch', epoch_acc, epoch)

            print("[test] Epoch: {}/{} Loss: {} Acc: {}".format(epoch+1, nEpochs, epoch_loss, epoch_acc))
            stop_time = timeit.default_timer()
            print("Execution time: " + str(stop_time - start_time) + "\n")

    writer.close()


if __name__ == "__main__":
    train_model()
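
To train, run `python train.py`. Checkpoints and TensorBoard event files are written under `run/run_<id>/`, so (assuming TensorBoard is installed) `tensorboard --logdir run` lets you follow the loss and accuracy curves logged above.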



5. IPO

![IPO (input-process-output) diagram](https://img-blog.csdnimg.cn/20200925174210609.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2JldHRlcl9ib3k=,size_16,color_FFFFFF,t_70#pic_center)
