【NLP】N2. EmbeddingBag & Embedding

>- **🍨 This post is a learning-record blog for the [🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg)**
>- **🍖 Original author: [K同学啊](https://mtyjkh.blog.csdn.net/)**

Task: load a txt file and use EmbeddingBag & Embedding to produce word embeddings.
1. Import the necessary packages and data
import pandas as pd
import torch
from torch import nn
import jieba as jb  # jieba: Chinese word segmentation

data_dir = "/kaggle/input/text-data1/.txt"
data = pd.read_csv(data_dir, header=None)  # each line of the txt becomes one row of a single-column DataFrame
data

Result: (DataFrame preview omitted)
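One thing to watch here: pd.read_csv splits on the ASCII comma by default, so a line of text that happens to contain one would be broken into several columns. A minimal alternative (my own sketch, not from the original post; the variable name texts is made up) reads the raw lines instead:

# Read the txt file line by line, keeping non-empty lines intact,
# so any commas inside a sentence are left alone.
with open(data_dir, encoding="utf-8") as f:
    texts = [line.strip() for line in f if line.strip()]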

2. Tokenize with jieba:
tokenized_texts = [list(jb.cut(text)) for text in data[0]]  # segment every line into a list of words
print(tokenized_texts)

Tokenization output: (screenshot omitted)

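As a quick illustration of what jieba returns (the sample sentence is made up, and the exact segmentation depends on jieba's dictionary):

# jb.lcut is a convenience wrapper that returns the segments directly as a list.
print(jb.lcut("我爱自然语言处理"))  # e.g. ['我', '爱', '自然语言', '处理']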
3. Build the vocabulary:
word_index = {}  # word -> integer id
index_word = {}  # integer id -> word

# Collect the unique words across all tokenized texts and assign each an id.
for i, word in enumerate(set([word for text in tokenized_texts for word in text])):
    word_index[word] = i
    index_word[i] = word

vocab_size = len(word_index)

print(f"{vocab_size} / {word_index}")

Vocabulary size: 87

Vocabulary output: (screenshot omitted)
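Note that iterating over a Python set yields an arbitrary order, so the word ids above change from run to run. If reproducible ids matter, a small variant (my own tweak, not from the original post) is to sort the unique words before numbering them:

# Sorting makes the word -> id assignment identical across runs.
vocab = sorted(set(word for text in tokenized_texts for word in text))
word_index = {word: i for i, word in enumerate(vocab)}
index_word = {i: word for i, word in enumerate(vocab)}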

4. Convert the text to integer sequences:
# Map each token of the first two sentences to its integer id
sequences1 = torch.tensor([word_index[word] for word in tokenized_texts[0]], dtype=torch.long)
sequences2 = torch.tensor([word_index[word] for word in tokenized_texts[1]], dtype=torch.long)
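This direct dictionary lookup raises a KeyError for any word outside the vocabulary. A common pattern (a sketch of my own; the reserved unknown-word id is an assumption, and the Embedding layer below would then need vocab_size + 1 rows) falls back to a dedicated id:

# Hypothetical: reserve one id past the known words for out-of-vocabulary tokens.
unk_index = vocab_size

def to_ids(words):
    # dict.get falls back to unk_index instead of raising KeyError.
    return torch.tensor([word_index.get(w, unk_index) for w in words], dtype=torch.long)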
5. Use Embedding to turn the sentences into word-embedding vectors:
embedding = nn.Embedding(vocab_size, 4)  # one trainable 4-dim vector per vocabulary word

embedding_sequences1 = embedding(sequences1)  # shape: (len(sequences1), 4)
embedding_sequences2 = embedding(sequences2)  # shape: (len(sequences2), 4)

print(embedding_sequences1)
print(embedding_sequences2)
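Under the hood, nn.Embedding is just a trainable lookup table of shape (vocab_size, embedding_dim), initialized from a standard normal distribution; calling it selects rows by index. A quick check, reusing the layer defined above:

print(embedding.weight.shape)  # torch.Size([87, 4]): one row per vocabulary word
# The lookup for a token id is exactly the corresponding weight row.
print(torch.equal(embedding(sequences1[:1])[0], embedding.weight[sequences1[0]]))  # True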

Embedding result:

tensor([[-2.2421,  0.4768,  1.5981, -0.1372],
        [ 0.2662,  0.6506, -0.3813,  0.3465],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [ 0.1479,  0.4894, -0.3237, -2.0757],
        [-0.5640, -1.1852,  0.3003,  1.1728],
        [ 0.4459, -0.7934, -0.1811,  1.3251],
        [-1.4543, -1.3130,  1.0980,  0.8700],
        [ 0.2644, -0.8209,  2.3713, -1.3098],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [ 1.9804, -0.0224, -1.0361,  0.1449],
        [-0.8875,  1.1185,  0.5886, -1.3458],
        [-0.0645,  1.1644, -0.7693, -0.4122],
        [ 0.7435, -0.4162,  1.2579, -1.7981],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [ 1.4095, -0.1287, -0.2783,  0.1564],
        [ 1.2758,  0.7741, -1.6845,  0.0374],
        [ 0.7291,  2.4709, -1.0184,  0.9997],
        [ 0.3255,  0.1719, -1.2049, -0.7580],
        [-1.1748,  1.2276, -0.2742,  1.5316],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [ 0.4989,  0.2639, -1.3275,  1.4525],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [-1.6770,  0.5937, -0.4984, -0.0159],
        [-0.1202,  0.6479,  0.8591, -1.0248],
        [ 0.1757, -1.5443,  0.4225, -0.1653],
        [-0.4791, -0.4607, -1.1054, -1.2024],
        [-0.0297,  0.5840,  0.5421,  0.1201],
        [ 0.6148, -1.5387,  1.0232,  0.2355],
        [-1.2716, -0.9989,  1.1541, -1.3458],
        [ 0.7014,  0.5741, -0.2984,  0.7311],
        [-0.1746, -1.4164,  0.0412,  0.3107],
        [ 0.6628, -0.1708,  0.2175, -1.2665],
        [ 0.3255,  0.1719, -1.2049, -0.7580],
        [-1.1748,  1.2276, -0.2742,  1.5316],
        [-0.0645,  1.1644, -0.7693, -0.4122],
        [ 1.2490, -0.8628,  0.3805, -0.6328],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [-0.3621, -0.1040,  0.9854,  0.6555],
        [ 0.1479,  0.4894, -0.3237, -2.0757],
        [-0.9348, -0.9895,  1.6729, -0.0273],
        [ 1.2758,  0.7741, -1.6845,  0.0374],
        [ 0.4989,  0.2639, -1.3275,  1.4525],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [ 0.5002,  0.2031,  0.5255,  0.2732],
        [-0.9373, -0.8043, -0.7819,  0.6784],
        [ 0.1671, -0.9603, -0.9574, -1.0333],
        [-0.4921,  1.0120, -1.5466, -0.4126],
        [-0.0189, -0.6197,  0.2134,  0.6700],
        [ 0.3798,  1.1016,  0.9914, -1.7548],
        [ 1.6996,  1.1439,  0.3383, -1.5814],
        [ 0.6642, -0.8293, -0.0363, -0.6464],
        [-1.1748,  1.2276, -0.2742,  1.5316],
        [-0.2030,  0.4043, -0.4269,  2.0898],
        [-0.9348, -0.9895,  1.6729, -0.0273],
        [ 0.6830,  1.4578, -1.4581, -1.2760],
        [ 0.0985,  2.2091,  0.4579,  0.0645],
        [-0.8245,  0.7995,  1.2056, -0.8458],
        [ 0.2116,  0.3476,  0.3539, -0.6649],
        [-1.1390, -1.7131,  0.4298, -0.5274],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [ 1.8577,  0.3560,  0.3905,  0.1583],
        [-0.3970,  0.4146,  0.0542,  0.7691],
        [-0.8477, -0.1442, -1.3092,  0.3221],
        [-1.1390, -1.7131,  0.4298, -0.5274],
        [ 0.1671, -0.9603, -0.9574, -1.0333],
        [-0.5640, -1.1852,  0.3003,  1.1728],
        [-0.4246,  0.8655, -0.5637,  0.8559],
        [-0.9348, -0.9895,  1.6729, -0.0273],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [-0.3495, -2.5871, -0.7908, -1.3156],
        [-0.4246,  0.8655, -0.5637,  0.8559],
        [ 1.7937,  1.3631, -0.0138, -0.6104],
        [-2.0114,  1.8538,  0.2406, -0.0205],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [-0.0645,  1.1644, -0.7693, -0.4122]], grad_fn=<EmbeddingBackward0>)
tensor([[-0.3867, -0.0058, -0.7481,  0.4434],
        [-0.4892,  1.1163,  1.5090,  0.9548],
        [-0.3621, -0.1040,  0.9854,  0.6555],
        [ 0.4989,  0.2639, -1.3275,  1.4525],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [ 1.9659,  0.6739, -0.7447,  2.0967],
        [-0.8151,  0.5109, -0.5044,  0.2174],
        [ 1.7830, -0.6902,  0.3569,  0.8761],
        [-1.3112,  0.5679,  0.2629,  0.6237],
        [ 1.1008,  0.3024,  2.0047,  1.1738],
        [ 0.9640, -0.2232,  1.7577,  0.1672],
        [ 0.6152,  0.7374,  0.0465, -0.4756],
        [-0.1858,  0.1699, -0.6322, -0.4958],
        [ 0.1092, -1.2058, -1.2382,  1.7523],
        [ 1.5649, -1.1317,  0.6957,  0.4157],
        [ 0.9640, -0.2232,  1.7577,  0.1672],
        [ 1.7874,  1.4637, -1.2369,  0.4733],
        [-0.0645,  1.1644, -0.7693, -0.4122],
        [ 1.7830, -0.6902,  0.3569,  0.8761],
        [-1.3112,  0.5679,  0.2629,  0.6237],
        [ 1.1008,  0.3024,  2.0047,  1.1738],
        [ 0.9640, -0.2232,  1.7577,  0.1672],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [-0.4225,  0.2887,  1.0175,  0.1776],
        [-0.3821,  0.1820,  0.1650, -0.6111],
        [-0.5640, -1.1852,  0.3003,  1.1728],
        [ 0.4481,  0.4344, -0.2955, -0.9457],
        [ 1.5302,  1.3884,  0.0516, -0.4676],
        [-1.1748,  1.2276, -0.2742,  1.5316],
        [ 0.8415, -0.1924, -0.7211, -0.2267],
        [-2.1970,  0.7031, -0.1178, -0.4499],
        [ 1.2758,  0.7741, -1.6845,  0.0374],
        [ 1.5492, -1.2600, -0.5016,  0.2411],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [ 0.2468,  0.4953,  0.2391, -0.4277],
        [-1.8871,  0.4976,  0.3504,  1.2424],
        [ 1.2758,  0.7741, -1.6845,  0.0374],
        [-0.6616, -1.0267, -0.0144,  1.3585],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [-0.7786,  0.8371,  0.0190,  1.0264],
        [-0.6832, -0.7319, -0.8668,  0.6007],
        [ 0.1757, -1.5443,  0.4225, -0.1653],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [ 0.3108,  0.0142,  0.3639, -0.9848],
        [-0.6616, -1.0267, -0.0144,  1.3585],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [-0.7786,  0.8371,  0.0190,  1.0264],
        [-0.6832, -0.7319, -0.8668,  0.6007],
        [ 0.6202, -0.6972,  0.9012,  0.0262],
        [-0.0645,  1.1644, -0.7693, -0.4122],
        [ 0.6907,  0.2946, -0.8080, -1.1677],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [ 1.5302,  1.3884,  0.0516, -0.4676],
        [-1.1748,  1.2276, -0.2742,  1.5316],
        [-0.2030,  0.4043, -0.4269,  2.0898],
        [ 0.5002,  0.2031,  0.5255,  0.2732],
        [ 1.5477, -2.2033, -0.3226,  0.1326],
        [-0.2892, -2.0367,  0.9881, -0.8522],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [-0.4246,  0.8655, -0.5637,  0.8559],
        [-0.9348, -0.9895,  1.6729, -0.0273],
        [ 0.0985,  2.2091,  0.4579,  0.0645],
        [-0.8245,  0.7995,  1.2056, -0.8458],
        [ 0.2116,  0.3476,  0.3539, -0.6649],
        [-1.1390, -1.7131,  0.4298, -0.5274],
        [-0.0645,  1.1644, -0.7693, -0.4122],
        [ 0.7435, -0.4162,  1.2579, -1.7981],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [ 1.4095, -0.1287, -0.2783,  0.1564],
        [ 0.3255,  0.1719, -1.2049, -0.7580],
        [-1.1748,  1.2276, -0.2742,  1.5316],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [ 0.0812,  1.0305,  1.2780,  0.4558],
        [-0.2672,  0.4423, -1.7335, -2.1201],
        [-1.6770,  0.5937, -0.4984, -0.0159],
        [-0.0591, -1.0665, -0.5901,  0.2102],
        [ 1.1762,  0.1607,  1.5622,  0.0729],
        [ 0.2144,  0.6271,  1.4468,  0.2856],
        [ 1.7830, -0.6902,  0.3569,  0.8761],
        [-1.3112,  0.5679,  0.2629,  0.6237],
        [ 1.1008,  0.3024,  2.0047,  1.1738],
        [ 0.9640, -0.2232,  1.7577,  0.1672],
        [ 0.9006,  0.0673,  0.6713, -1.2377]], grad_fn=<EmbeddingBackward0>)
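Each row above is the 4-dimensional vector of one token, so the output shape is (sequence length, 4), and identical rows (for example rows 2 and 8 of the first tensor) come from the same token id. A quick sanity check on the variables above:

print(embedding_sequences1.shape)  # torch.Size([75, 4])
# Repeated token ids always map to the same embedding row.
assert torch.equal(embedding_sequences1[2], embedding_sequences1[8])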
6. Use EmbeddingBag's mean mode to turn the sentences into embedding vectors:
embedding_bag = nn.EmbeddingBag(vocab_size, 30, mode="mean")  # averages the vectors inside each bag

input_sequences = torch.cat([sequences1, sequences2])  # both sentences flattened into one 1-D tensor
offsets = torch.tensor([0, len(sequences1)], dtype=torch.long)  # start position of each sentence (bag)

eb_seq = embedding_bag(input_sequences, offsets)  # one 30-dim vector per sentence
print(eb_seq)

EmbeddingBag result: (a 2 × 30 tensor; screenshot omitted)
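In "mean" mode, EmbeddingBag is equivalent to looking the tokens up in an nn.Embedding with the same weights and averaging them, which a small cross-check (my own sketch, not from the original post) confirms:

# With a shared weight matrix, EmbeddingBag("mean") equals per-token lookup + mean.
emb = nn.Embedding(vocab_size, 30)
bag = nn.EmbeddingBag(vocab_size, 30, mode="mean")
bag.weight = emb.weight  # share the same weights

manual = emb(sequences1).mean(dim=0)      # average the per-token vectors
bagged = bag(sequences1.unsqueeze(0))[0]  # a single bag holding the whole sentence
print(torch.allclose(manual, bagged))     # True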

7. Summary:

Combining this exercise with the previous one, I have now learned word-embedding operations in addition to integer sequences, and I understand several methods for converting text into numbers. I hope to apply these methods fluently in the upcoming tasks and deepen my understanding of them.
