2021SC@SDUSC
read_data.py文件的分析
首先是对于自定义类Translator的分析:
class Translator:
def __init__(self, path, transform_type='BackTranslation'):
# Pre-processed German data
with open(path + 'de_1.pkl', 'rb') as f:
self.de = pickle.load(f)
# Pre-processed Russian data
with open(path + 'ru_1.pkl', 'rb') as f:
self.ru = pickle.load(f)
def __call__(self, ori, idx):
out1 = self.de[idx]
out2 = self.ru[idx]
return out1, out2, ori
如上述代码,首先我们定义的一个translator类来家在翻译的德语文本与俄文文本,为了方便起见,作者是预先将翻译好的文本加载到了pickle文件中,如代码中with open(pate+‘de_1.pkl’,‘rb’)。
其次是对get_data函数进行分析:
def get_data(data_path, n_labeled_per_class, unlabeled_per_class=5000, max_seq_len=256, model=‘bert-base-uncased’, train_aug=False):
此函数的用途是用于训练集、测试集、验证集的加载,data_path即为数据的地址,其分为了train.csv以及test.csv,其中n_labeled_per_class意为每个类中标记数据数目,而默认为标记数目为unlabeled_per_class=5000,最大序列长度为256,加载bert-base-uncased模型来进行训练,最后一个参数表示对于有标签的训练集数据不进行增强。
这里是利用BertTokenizer中的from_pretrained方法来加载基线模型:
tokenizer = BertTokenizer.from_pretrained(model)
训练数据与测试数据进行保存
train_df = pd.read_csv(data_path+'train.csv', header=None)
test_df = pd.read_csv(data_path+'test.csv', header=None)
在做标签时没我们仅仅利用数据主题并且去除了标题
train_labels = np.array([v-1 for v in train_df[0]])
train_text = np.array([v for v in train_df[2]])
test_labels = np.arra