Exploring the TF-Hub CORD-19 Swivel Embeddings


These embeddings were trained on the titles, authors, abstracts, body texts, and reference titles of articles in the CORD-19 dataset.

In this Colab, we will:

  • Analyze semantically similar words in the embedding space
  • Train a classifier on the SciCite dataset using the CORD-19 embeddings

Analyze the embeddings:
Let’s start by analyzing the embeddings: we calculate and plot a correlation matrix between different terms. If the embedding has successfully captured the meaning of the different words, the embedding vectors of semantically similar terms should be close together. Let’s take a look at some COVID-19-related terms.

We can see that the embedding successfully captured the meaning of the different terms. Each word is similar to the other words in its cluster (“coronavirus” correlates highly with “SARS” and “MERS”), while it differs from terms in other clusters (the similarity between “SARS” and “Spain” is close to 0).

SciCite: Citation Intent Classification:
This section shows how one can use the embedding for downstream tasks such as text classification. We’ll use the SciCite dataset from TensorFlow Datasets to classify citation intents in academic papers: given a sentence containing a citation from an academic paper, classify whether the main intent of the citation is providing background information, describing the use of methods, or comparing results.

We can see that for this random sample, the model predicts the correct label most of the time, indicating that it can embed scientific sentences quite well.

import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import tensorflow as tf

import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tqdm import trange


# Use the inner product between two embedding vectors as the similarity measure
def plot_correlation(labels, features):
    corr = np.inner(features, features)
    corr /= np.max(corr)
    sns.heatmap(corr, xticklabels=labels, yticklabels=labels)
    plt.show()
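
Dividing the inner products by the single largest value is only a rough normalization. For comparison, here is a minimal sketch of true cosine similarity, with each vector scaled to unit norm first (plot_cosine_similarity is a hypothetical helper, not part of the original notebook):

# Alternative: true cosine similarity, with every vector scaled to unit norm
def plot_cosine_similarity(labels, features):
    vecs = np.asarray(features)
    vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
    sim = np.inner(vecs, vecs)  # bounded in [-1, 1]
    sns.heatmap(sim, xticklabels=labels, yticklabels=labels, vmin=-1, vmax=1)
    plt.show()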


# Generate embeddings for some terms
queries = [
    # Related viruses
    'coronavirus', 'SARS', 'MERS',
    # Regions
    'Italy', 'Spain', 'Europe',
    # Symptoms
    'cough', 'fever', 'throat'
]

module = hub.load('https://tfhub.dev/tensorflow/cord-19/swivel-128d/3')
embeddings = module(queries)

plot_correlation(queries, embeddings)
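
As a quick numeric spot check of the cluster structure, a small sketch that reuses queries and embeddings from above to print two pairwise cosine similarities (pair_similarity is a hypothetical helper):

# Spot-check individual pairs; indices follow the order of `queries` above
vecs = np.asarray(embeddings)
vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

def pair_similarity(a, b):
    return float(np.dot(vecs[queries.index(a)], vecs[queries.index(b)]))

print('coronavirus ~ SARS :', pair_similarity('coronavirus', 'SARS'))
print('SARS ~ Spain       :', pair_similarity('SARS', 'Spain'))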

builder = tfds.builder(name='scicite')
builder.download_and_prepare()
train_data, validation_data, test_data = builder.as_dataset(
    split=('train', 'validation', 'test'),
    as_supervised=True)
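
Before training, it can help to know how large each split is. A short sketch reading the standard builder.info.splits metadata:

# Print the number of examples per split from the dataset metadata
for split_name in ('train', 'validation', 'test'):
    print(split_name, builder.info.splits[split_name].num_examples)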

NUM_EXAMPLES = 10

TEXT_FEATURE_NAME = builder.info.supervised_keys[0]
LABEL_NAME = builder.info.supervised_keys[1]


def label2str(numeric_label):
    m = builder.info.features[LABEL_NAME].names
    return m[numeric_label]


data = next(iter(train_data.batch(NUM_EXAMPLES)))

print(pd.DataFrame({
    TEXT_FEATURE_NAME: [ex.numpy().decode('utf8') for ex in data[0]],
    LABEL_NAME: [label2str(x) for x in data[1]]
}))
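
To see how balanced the three intent classes are, a small sketch that counts labels across the full train split (this iterates the whole split once, so it takes a moment):

import collections

# Count how many training examples fall into each intent class
label_counts = collections.Counter(int(label) for _, label in train_data)
for numeric_label, count in sorted(label_counts.items()):
    print(label2str(numeric_label), count)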

EMBEDDING = 'https://tfhub.dev/tensorflow/cord-19/swivel-128d/3'
TRAINABLE_MODULE = False

hub_layer = hub.KerasLayer(EMBEDDING, input_shape=[],
                           dtype=tf.string, trainable=TRAINABLE_MODULE)

model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(3))
model.summary()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
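
The final Dense layer outputs raw logits, which is why the loss is built with from_logits=True. To turn predictions into class probabilities, apply a softmax explicitly; a minimal sketch (example_sentence is a hypothetical input, and before training the probabilities will be near-uniform):

# Convert logits to probabilities with an explicit softmax
example_sentence = ['The method of [1] was used to preprocess the data.']  # hypothetical
logits = model.predict(example_sentence)
probabilities = tf.nn.softmax(logits, axis=-1).numpy()
print(probabilities)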

EPOCHS = 35
BATCH_SIZE = 32

history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    verbose=1)
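
Training for a fixed 35 epochs can overfit, especially with TRAINABLE_MODULE = True. One common hedge, sketched here as an alternative to the fit call above (not a second training run), is early stopping on validation loss:

# Optional alternative: stop once validation loss stops improving
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True)
history = model.fit(train_data.shuffle(10000).batch(BATCH_SIZE),
                    epochs=EPOCHS,
                    validation_data=validation_data.batch(BATCH_SIZE),
                    callbacks=[early_stop],
                    verbose=1)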


def display_training_curves(training, validation, title, subplot):
    if subplot % 10 == 1:  # set up the subplots on the first call
        plt.subplots(figsize=(10, 10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model ' + title)
    ax.set_ylabel(title)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])


display_training_curves(history.history['accuracy'], history.history['val_accuracy'], 'accuracy', 211)
display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 212)
plt.show()

results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
    print('%s: %.3f' % (name, value))

prediction_dataset = next(iter(test_data.batch(20)))

prediction_texts = [ex.numpy().decode('utf8') for ex in prediction_dataset[0]]
prediction_labels = [label2str(x) for x in prediction_dataset[1]]

predictions = [
    label2str(x) for x in np.argmax(model.predict(prediction_texts), axis=-1)]

print(pd.DataFrame({
    TEXT_FEATURE_NAME: prediction_texts,
    LABEL_NAME: prediction_labels,
    'prediction': predictions
}))
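
Beyond eyeballing a random sample, a confusion matrix over the full test split gives a clearer per-class picture. A sketch using tf.math.confusion_matrix and the model and helpers defined above:

# Collect true and predicted labels over the whole test split
y_true, y_pred = [], []
for texts, labels in test_data.batch(512):
    logits = model.predict(texts, verbose=0)
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(logits, axis=-1))

cm = tf.math.confusion_matrix(y_true, y_pred).numpy()
class_names = builder.info.features[LABEL_NAME].names
print(pd.DataFrame(cm, index=class_names, columns=class_names))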
