半监督学习文本分类(四)

本文分析了`read_data.py`文件中的`Translator`类和`get_data`函数,用于处理德语文本和俄文文本的翻译及数据加载。`get_data`函数根据数据路径、标记数据数量、无标签数据数量、序列最大长度和预训练模型(如BERT)来划分训练集、测试集和验证集。数据集根据类别数量(如IMDB、DBPedia、Yahoo/AG News)进行选择,并使用随机函数确保数据划分的随机性。最终,函数返回训练集的标记和未标记数据,验证集以及标签数目。
摘要由CSDN通过智能技术生成

2021SC@SDUSC​​​​​​​

read_data.py文件的分析
首先是对于自定义类Translator的分析:

class Translator:

def __init__(self, path, transform_type='BackTranslation'):
    # Pre-processed German data
    with open(path + 'de_1.pkl', 'rb') as f:
        self.de = pickle.load(f)
    # Pre-processed Russian data
    with open(path + 'ru_1.pkl', 'rb') as f:
        self.ru = pickle.load(f)

def __call__(self, ori, idx):
    out1 = self.de[idx]
    out2 = self.ru[idx]
    return out1, out2, ori

如上述代码,首先我们定义的一个translator类来家在翻译的德语文本与俄文文本,为了方便起见,作者是预先将翻译好的文本加载到了pickle文件中,如代码中with open(pate+‘de_1.pkl’,‘rb’)。

其次是对get_data函数进行分析:

def get_data(data_path, n_labeled_per_class, unlabeled_per_class=5000, max_seq_len=256, model=‘bert-base-uncased’, train_aug=False):

此函数的用途是用于训练集、测试集、验证集的加载,data_path即为数据的地址,其分为了train.csv以及test.csv,其中n_labeled_per_class意为每个类中标记数据数目,而默认为标记数目为unlabeled_per_class=5000,最大序列长度为256,加载bert-base-uncased模型来进行训练,最后一个参数表示对于有标签的训练集数据不进行增强。

这里是利用BertTokenizer中的from_pretrained方法来加载基线模型:

tokenizer = BertTokenizer.from_pretrained(model)
训练数据与测试数据进行保存

train_df = pd.read_csv(data_path+'train.csv', header=None)
test_df = pd.read_csv(data_path+'test.csv', header=None)

在做标签时没我们仅仅利用数据主题并且去除了标题
train_labels = np.array([v-1 for v in train_df[0]])
train_text = np.array([v for v in train_df[2]])

test_labels = np.arra
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值