# 导入需要的库
import os
import time
import math
import json
import joblib
import random
import argparse
import numpy as np
import tensorflow as tf
# 导入一些工具函数和模块
from tqdm import tqdm # 进度条
from functools import partial # 高阶函数工具
from sklearn.utils import shuffle # 打乱数据
from sklearn.metrics import accuracy_score # 计算准确率
# 从自定义模块中导入函数和类
from opt import adam, warmup_cosine, warmup_linear, warmup_constant # 优化器和学习率调度器
from datasets import rocstories # 数据集加载函数
from analysis import rocstories as rocstories_analysis # 数据分析函数
from text_utils import TextEncoder # 文本编码工具
from utils import encode_dataset, flatten, iter_data, find_trainable_variables, convert_gradient_to_tensor, shape_list, ResultLogger, assign_to_gpu, average_grads, make_path # 一些实用函数
# 定义激活函数
def gelu(x):
# GELU激活函数
return 0.5*x*(1+tf.tanh(math.sqrt(2/math.pi)*(x+0.044715*tf.pow(x, 3))))
def swish(x):
# Swish激活函数
return x*tf.nn.sigmoid(x)
# 定义优化器、激活函数和学习率调度器的字典
opt_fns = {'adam':adam}
act_fns = {'relu':tf.nn.relu, 'swish':swish, 'gelu':gelu}
lr_schedules = {'warmup_cosine':warmup_cosine, 'warmup_linear':warmup_linear, 'warmup_constant':warmup_constant}
# 标准化函数
def _norm(x, g=None, b=None, e=1e-5, axis=[1]):
# ...(函数内部实现省略)
def norm(x, scope, axis=[-1]):
# ...(函数内部实现省略)
# Dropout函数
def dropout(x, pdrop, train):
# ...(函数内部实现省略)
# 注意力权重掩码函数
def mask_attn_weights(w):
# ...(函数内部实现省略)
# 注意力机制函数
def _attn(q, k, v, train=False, scale=False):
# ...(函数内部实现省略)
# 分割和合并状态函数
def split_states(x, n):
# ...(函数内部实现省略)
def merge_states(x):
# ...(函数内部实现省略)
# 分割和合并头函数
def split_heads(x, n, k=False):
# ...(函数内部实现省略)
def merge_heads(x):
# ...(函数内部实现省略)
# 一维卷积函数
def conv1d(x, scope, nf, rf, w_init=tf.random_normal_initializer(stddev=0.02), b_init=tf.constant_initializer(0), pad='VALID', train=False):
# ...(函数内部实现省略)
# 注意力块函数
def attn(x, scope, n_state, n_head, train=False, scale=False):
# ...(函数内部实现省略)
# 多层感知机函数
def mlp(x, scope, n_state, train=False):
# ...(函数内部实现省略)
# Transformer块函数
def block(x, scope, train=False, scale=False):
# ...(函数内部实现省略)
# 嵌入函数
def embed(X, we):
# ...(函数内部实现省略)
# 分类函数
def clf(x, ny, w_init=tf.random_normal_initializer(stddev=0.02), b_init=tf.constant_initializer(0), train=False):
# ...(函数内部实现省略)
# 主模型函数
def model(X, M, Y, train=False, reuse=False):
# ...(函数内部实现省略)
# 多GPU训练函数
def mgpu_train(*xs):
# ...(函数内部实现省略)
# 多GPU预测函数
def mgpu_predict(*xs):
# ...(函数内部实现省略)
# 数据转换函数
def transform_roc(X1, X2, X3):
# ...(函数内部实现省略)
# 迭代应用函数(用于训练和验证)
def iter_apply(Xs, Ms, Ys):
# ...(函数内部实现省略)
# 迭代预测函数
def iter_predict(Xs, Ms):
# ...(函数内部实现省略)
# 保存模型参数函数
def save(path):
# ...(函数内部实现省略)
# 记录日志函数
def log():
# ...(函数内部实现省略)
# 主函数(命令行参数解析和模型训练/预测)
if __name__ == '__main__':
# ...(参数解析和模型初始化省略)
# 训练循环
for i in range(n_iter):
# ...(训练迭代和日志记录省略)
# 提交预测
if submit:
# ...(加载最佳参数,进行预测和分析省略)
# 一些辅助函数和变量定义(如argmax,pred_fns,filenames等)
# ...(省略)
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.
- 98.
- 99.
- 100.
- 101.
- 102.
- 103.
- 104.
- 105.
- 106.
- 107.
- 108.
- 109.
- 110.
- 111.
- 112.
- 113.
- 114.
- 115.
- 116.
- 117.
- 118.
- 119.
- 120.
- 121.
- 122.
- 123.
- 124.
- 125.
- 126.
- 127.
- 128.
- 129.
- 130.
- 131.
- 132.
- 133.
- 134.
- 135.
- 136.
- 137.
- 138.
- 139.
- 140.
- 141.