本篇博客是基于“Character-level Convolutional Networks for Text Classification”论文的模型创作的,同时也参考了同模型其他人的代码,经过个人的一些修改,希望能够更python,如果有不足的地方欢迎指点。(由于显示问题,个别代码的缩进可能会异常)
首先,为了方便统一管理参数,本人把目录路径、文件路径、提示信息、全局变量、超参放在一个xml里面,并且编写了my_profile.py来读取xml做初始化,由于该部分代码与主题无关,因此不加介绍。
为了准备数据,生成mini-batch,编写了pre_train来完成该工作。初始化函数如下:
class PreparationForTraining: def __init__(self): self.alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}\n" self.size_of_alphabet = len(self.alphabet) self.char_to_index = dict((c, i) for i, c in enumerate(self.alphabet)) self.index_to_char = dict((i, c) for i, c in enumerate(self.alphabet)) self.character_skipped = 0 self.sample_buffer = [] self.label_buffer = []
初始化工作主要把包括换行符的70个字符生成字符到索引和索引到字符的两个字典。
@staticmethod def load_csv(file_path, delimiter=',', quotechar='"'): data = [] with open(file_path, 'r') as csv_file: reader = csv.reader(csv_file, delimiter=delimiter, quotechar=quotechar) for class_index, title, text in reader: data.append((int(class_index), text.replace('\\n', '\n'))) return data
静态方法load_csv主要是读取训练文件,三个双引号中的内容分别是类别、标题和正文,我们只取类别和正文作为训练数据。
def x_one_hot_encode_by_alphabet(self, text, length0): assert isinstance(text, str) text = text.lower() x_input = np.zeros([length0, self.size_of_alphabet], dtype=np.float32) try: for index, char in enumerate(text): if char in self.char_to_index: x_input[index, self.char_to_index[char]] = 1 else: self.character_skipped += 1 pass except IndexError: pass logging.getLogger(__name__).debug(MyInit.message_about(__file__, 'encode_success', self.character_skipped)) return x_input
该方法主要是把正文内容变成one_hot编码,其中x_input默认是二维数组全0数组,其中行是截取长度,列为字符集的个数。
@staticmethod def y_one_hot_encode_by_alphabet(class_index, n_class): y_input = np.zeros([n_class]) y_input[int(class_index) - 1] = 1 return y_input
该方法主要是把类别变成one_hot编码
def build_batch(self, text_list, class_list, length0, n_class): x_batch = np.expand_dims( np.concatenate( tuple([np.expand_dims(self.x_one_hot_encode_by_alphabet(text, length0), axis=0) for text in text_list]) ), axis=-1 ) y_batch = np.concatenate( tuple([np.expand_dims( self.y_one_hot_encode_by_alphabet(class_index, n_class), axis=0) for class_index in class_list]) ) return x_batch, y_batch def get_batch_from_file(self, file_path, length0, n_class, batch_size): assert isinstance(file_path, str) if os.path.isfile(file_path): assert len(self.sample_buffer) == len(self.label_buffer)