[Relation Extraction: mre-in-one-pass] Loading the Data (Part 2)

This post continues from the previous one, Loading the Data (Part 1).

In the previous section we left off at

convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer)

Inside this function, the following is called:

loc, mas, e1_mas, e2_mas = prepare_extra_data(mapping_a, example.locations, FLAGS.max_distance)

and prepare_extra_data itself calls two helper functions:

convert_entity_row(mapping, e, max_distance)
find_lo_hi(mapping, lo)

Let's walk through prepare_extra_data step by step:

  • It starts by defining four arrays:
  res = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
  mas = np.zeros([FLAGS.max_seq_length, FLAGS.max_seq_length], dtype=np.int8)
  
  e1_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)
  e2_mas = np.zeros([FLAGS.max_num_relations, FLAGS.max_seq_length], dtype=np.int8)

Before going line by line, here is roughly what each of them holds:
(1) res: the relative-position matrix, a [128, 128] array, where 128 is the maximum sequence length. It records the relative position between each entity and every other token in the sentence.
(2) mas: the entity mask matrix, also [128, 128]; positions covered by an entity are 1, everything else is 0.
(3) e1_mas: the mask of entity 1 for each relation pair, with shape [12, 128], where 12 is the configured maximum number of relations per example (FLAGS.max_num_relations).
(4) e2_mas: the mask of entity 2 for each relation pair, likewise [12, 128].

  • Collect the set of entities that appear in any relation:
entities = set()
for loc in locs:
    entities.add(loc[0])
    entities.add(loc[1])
  • Now for the key part:
  for e in entities:
    (lo, hi) = e
    relative_position, _ = convert_entity_row(mapping, e, max_distance)
    sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
    sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
    if sub_lo1 == 0 and sub_hi1 == 0:
      continue
    if sub_lo2 == 0 and sub_hi2 == 0:
      continue
    # col
    res[:, sub_lo1:sub_hi2+1] = np.expand_dims(relative_position, -1)
    mas[1:, sub_lo1:sub_hi2+1] = 1

Let's look at the output first:

example.text_a = a large database . Traditional information retrieval techniques use a histogram of keywords as the document representation but oral communication may offer additional indices such as the time and is shown on a large database of TV shows . Emotions and other indices
tokens_a = ['a', 'large', 'database', '.', 'traditional', 'information', 'retrieval', 'techniques', 'use', 'a', 'his', '##to', '##gram', 'of', 'key', '##words', 'as', 'the', 'document', 'representation', 'but', 'oral', 'communication', 'may', 'offer', 'additional', 'indices', 'such', 'as', 'the', 'time', 'and', 'is', 'shown', 'on', 'a', 'large', 'database', 'of', 'tv', 'shows', '.', 'emotions', 'and', 'other', 'indices']
mapping_a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 11, 11, 12, 13, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]
example.locations = [((13, 13), (6, 8)), ((19, 20), (24, 24)), ((37, 38), (35, 35))]
entities = {(6, 8), (13, 13), (35, 35), (37, 38), (24, 24), (19, 20)}

For each entity span, relative_position, _ = convert_entity_row(mapping, e, max_distance) is called. The function:

def convert_entity_row(mapping, loc, max_distance):
  """
  convert an entity span (lo, hi) to a relative distance vector of shape [max_seq_length]
  """
  lo, hi = loc
  res = [max_distance] * FLAGS.max_seq_length
  mas = [0] * FLAGS.max_seq_length
  for i in range(FLAGS.max_seq_length):
    if i < len(mapping):
      val = mapping[i]                    # original word index of sub-token i
      if val < lo - max_distance:
        res[i] = max_distance             # far left: clipped to max_distance
      elif val < lo:
        res[i] = lo - val                 # left context: 1 .. max_distance
      elif val <= hi:
        res[i] = 0                        # inside the entity
        mas[i] = 1
      elif val <= hi + max_distance:
        res[i] = val - hi + max_distance  # right context: max_distance+1 .. 2*max_distance
      else:
        res[i] = 2 * max_distance         # far right: clipped to 2*max_distance
    else:
      res[i] = 2 * max_distance           # padding positions
  return res, mas

For the entity span (6, 8), the output is:

lo = 6
hi = 8
res = [4, 4, 4, 3, 2, 1, 0, 0, 0, 5, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
mas = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
relative_position = [4, 4, 4, 3, 2, 1, 0, 0, 0, 5, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]

The maximum distance is set to 4. In res, positions inside the entity get relative position 0. For tokens to the left of the entity, the value is the distance to the entity's left boundary, as long as that distance does not exceed max_distance; anything further left is clipped to max_distance. The right-hand side works the same way, except its values start at max_distance + 1 and are clipped to 2 * max_distance, which keeps left and right context distinguishable. Note also that because of WordPiece splitting, all sub-tokens of the same original word share one value, e.g. the 7, 7, 7 above. If that is hard to picture, the output above makes it concrete.
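To make the clipping rule explicit, here is a minimal standalone sketch of the same behaviour (a toy reimplementation run on a tiny mapping; not the repo function, and without the mask output):

def relative_row(mapping, lo, hi, max_distance, max_seq_length=16):
  """Toy version of convert_entity_row: distance of each sub-token to the span [lo, hi]."""
  res = [2 * max_distance] * max_seq_length        # padding positions
  for i, val in enumerate(mapping[:max_seq_length]):
    if val < lo:                                   # left context, clipped at max_distance
      res[i] = min(lo - val, max_distance)
    elif val <= hi:                                 # inside the entity
      res[i] = 0
    else:                                          # right context, shifted and clipped
      res[i] = min(val - hi, max_distance) + max_distance
  return res

# word 3 is split into two WordPiece pieces, so its index appears twice in the mapping
mapping = [0, 1, 2, 3, 3, 4, 5, 6, 7]
print(relative_row(mapping, lo=3, hi=4, max_distance=2))
# [2, 2, 1, 0, 0, 0, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4]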
As for find_lo_hi:

def find_lo_hi(mapping, value):
  """
  find the boundary of a value in a list
  will return (0,0) if no such value in the list
  """
  try:
    lo = mapping.index(value)
    hi = min(len(mapping) - 1 - mapping[::-1].index(value), FLAGS.max_seq_length)
    return (lo, hi)
  except:
    return (0,0)

Because WordPiece splits a word into several sub-tokens, one original word index can occur multiple times in mapping. find_lo_hi therefore takes the first occurrence as lo and searches the reversed list to get the last occurrence as hi, so the returned span covers every sub-token of the entity.
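A quick illustration of this behaviour (a standalone copy with FLAGS replaced by a plain argument, for demonstration only):

def find_lo_hi_demo(mapping, value, max_seq_length=128):
  """Same logic as find_lo_hi, without FLAGS."""
  try:
    lo = mapping.index(value)                                                 # first sub-token of the word
    hi = min(len(mapping) - 1 - mapping[::-1].index(value), max_seq_length)   # last sub-token
    return (lo, hi)
  except ValueError:
    return (0, 0)

mapping = [0, 1, 2, 3, 3, 3, 4, 5]   # word 3 was split into three sub-tokens
print(find_lo_hi_demo(mapping, 3))   # (3, 5): the full sub-token span of word 3
print(find_lo_hi_demo(mapping, 9))   # (0, 0): word index not in the mapping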

  • Next, this position information is written into the matrices, which is what the two loops below do (a small numpy sketch of the broadcasting follows the printed result):
  for e in entities:
    (lo, hi) = e
    relative_position, _ = convert_entity_row(mapping, e, max_distance)
    sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
    sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
    if sub_lo1 == 0 and sub_hi1 == 0:
      continue
    if sub_lo2 == 0 and sub_hi2 == 0:
      continue
    # col
    res[:, sub_lo1:sub_hi2+1] = np.expand_dims(relative_position, -1)
    mas[1:, sub_lo1:sub_hi2+1] = 1

  for e in entities:
    (lo, hi) = e
    relative_position, _ = convert_entity_row(mapping, e, max_distance)
    sub_lo1, sub_hi1 = find_lo_hi(mapping, lo)
    sub_lo2, sub_hi2 = find_lo_hi(mapping, hi)
    if sub_lo1 == 0 and sub_hi1 == 0:
      continue
    if sub_lo2 == 0 and sub_hi2 == 0:
      continue
    # row
    res[sub_lo1:sub_hi2+1, :] = relative_position
    mas[sub_lo1:sub_hi2+1, 1:] = 1

The result looks like this:

[[0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 4 4 4 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 3 3 3 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 2 2 2 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 [4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 [4 4 4 3 2 1 0 0 0 5 6 7 7 7 8 ... 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 [0 0 0 0 0 0 5 5 5 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 6 6 6 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 7 7 7 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 ...
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 8 8 8 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
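If the broadcasting in the column/row assignments is not obvious, here is a tiny numpy sketch of the same trick on a toy 6x6 matrix (toy sizes and values, not the repo's 128):

import numpy as np

max_len = 6
res = np.zeros([max_len, max_len], dtype=np.int8)

# toy distance vector for an entity occupying positions 2-3
relative_position = np.array([2, 1, 0, 0, 1, 2], dtype=np.int8)
sub_lo, sub_hi = 2, 3

# "col" pass: each of the entity's columns becomes a copy of the distance vector
res[:, sub_lo:sub_hi + 1] = np.expand_dims(relative_position, -1)
# "row" pass: each of the entity's rows becomes a copy of the distance vector
res[sub_lo:sub_hi + 1, :] = relative_position

print(res)
# [[0 0 2 2 0 0]
#  [0 0 1 1 0 0]
#  [2 1 0 0 1 2]
#  [2 1 0 0 1 2]
#  [0 0 1 1 0 0]
#  [0 0 2 2 0 0]]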

Finally, the per-relation entity mask matrices:

  for idx, (e1,e2) in enumerate(locs):
    # e1
    (lo, hi) = e1
    _, mask = convert_entity_row(mapping, e1, max_distance)
    e1_mas[idx] = mask
    # e2
    (lo, hi) = e2
    _, mask = convert_entity_row(mapping, e2, max_distance)
    e2_mas[idx] = mask

Result:

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 ...
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
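Note that e1_mas and e2_mas themselves have shape [12, 128]: one row per relation pair, with 1s on the sub-tokens of that pair's entity. A toy sketch with made-up sizes (not repo code):

import numpy as np

# masks for up to 3 relation pairs over an 8-token sequence (made-up sizes)
max_num_relations, max_seq_length = 3, 8
e1_mas = np.zeros([max_num_relations, max_seq_length], dtype=np.int8)
e2_mas = np.zeros([max_num_relations, max_seq_length], dtype=np.int8)

# relation pair 0: entity 1 covers sub-tokens 1-2, entity 2 covers sub-token 5
e1_mas[0, 1:3] = 1
e2_mas[0, 5] = 1

print(e1_mas)
# [[0 1 1 0 0 0 0 0]
#  [0 0 0 0 0 0 0 0]
#  [0 0 0 0 0 0 0 0]]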
  • Back in the convert_single_example function:
  label_id = [label_map[label] for label in example.labels]
  label_id = label_id + [0] * (FLAGS.max_num_relations - len(label_id))
  cls_mask = [1] * example.num_relations + [0] * (FLAGS.max_num_relations - example.num_relations)

A maximum number of relations per example is defined here: 12. Look at the result first:

labels: 5 5 2 0 0 0 0 0 0 0 0 0
cls_mask:[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In other words, a single sentence can contain entities participating in several relations; label_id is padded out to the maximum of 12 with zeros, and cls_mask marks which of the 12 slots hold real relations.
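A tiny sketch reproducing the padding above (the ids 5, 5, 2 are simply the values from this example):

max_num_relations = 12
label_id = [5, 5, 2]                 # one label id per relation pair in this sentence
num_relations = len(label_id)

label_id = label_id + [0] * (max_num_relations - num_relations)
cls_mask = [1] * num_relations + [0] * (max_num_relations - num_relations)

print(label_id)   # [5, 5, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(cls_mask)   # [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]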
Finally, all of this information is packed into an InputFeatures object and returned.

  • Back to the file_based_convert_examples_to_features function:
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file):
  """Convert a set of `InputExample`s to a TFRecord file."""

  writer = tf.python_io.TFRecordWriter(output_file)

  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

    feature = convert_single_example(ex_index, example, label_list,
                                     max_seq_length, tokenizer)

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    features["loc"] = create_int_feature(feature.loc)
    features["mas"] = create_int_feature(feature.mas)
    features["e1_mas"] = create_int_feature(feature.e1_mas)
    features["e2_mas"] = create_int_feature(feature.e2_mas)
    features["cls_mask"] = create_int_feature(feature.cls_mask)
    features["label_ids"] = create_int_feature(feature.label_id)

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()

There is not much to say here: each feature is converted into int64 Feature protos, packed into a tf.train.Example, serialized, and written to the TFRecord file that training will read from.
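As a sanity check, the written file can be read back with the standard TF1 record iterator. A hedged sketch (the path "train.tf_record" is just a placeholder):

import tensorflow as tf

for serialized in tf.python_io.tf_record_iterator("train.tf_record"):  # placeholder path
  example = tf.train.Example()
  example.ParseFromString(serialized)
  for name in ("input_ids", "cls_mask", "label_ids"):
    values = example.features.feature[name].int64_list.value
    print(name, len(values), list(values)[:12])
  break  # only inspect the first record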
With that, the data-processing part of mre-in-one-pass is done.

Reference code: https://sourcegraph.com/github.com/helloeve/mre-in-one-pass/-/blob/run_classifier.py#L550
