import gc
import torch
import numpy as np
from torch import nn, optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
Define the data
sentences = [ "jack like dog" , "jack like cat" , "jack like animal" ,
"dog cat animal" , "banana apple cat dog like" , "dog fish milk like" ,
"dog cat animal like" , "jack like apple" , "apple like" , "jack like banana" ,
"apple banana jack movie book music like" ]
word_sequence = " " . join( sentences) . split( )
vocab = list ( set ( word_sequence) )
word2idx = { w: i for i, w in enumerate ( vocab) }
print ( "*" * 85 )
print ( "word_sequence:" , word_sequence)
print ( "*" * 85 )
print ( "vocab:" , vocab)
print ( "*" * 85 )
print ( "word2idx:" , word2idx)
*************************************************************************************
word_sequence: ['jack', 'like', 'dog', 'jack', 'like', 'cat', 'jack', 'like', 'animal', 'dog', 'cat', 'animal', 'banana', 'apple', 'cat', 'dog', 'like', 'dog', 'fish', 'milk', 'like', 'dog', 'cat', 'animal', 'like', 'jack', 'like', 'apple', 'apple', 'like', 'jack', 'like', 'banana', 'apple', 'banana', 'jack', 'movie', 'book', 'music', 'like']
*************************************************************************************
vocab: ['dog', 'cat', 'movie', 'jack', 'fish', 'milk', 'music', 'book', 'apple', 'banana', 'like', 'animal']
*************************************************************************************
word2idx: {'dog': 0, 'cat': 1, 'movie': 2, 'jack': 3, 'fish': 4, 'milk': 5, 'music': 6, 'book': 7, 'apple': 8, 'banana': 9, 'like': 10, 'animal': 11}
Data preprocessing
batch_size = 4
embedding_size = 2
window = 2
vocab_size = len(vocab)

# Build (center, context) index pairs for the skip-gram model
skip_grams = []
for idx in range(window, len(word_sequence) - window):
    center = word2idx[word_sequence[idx]]
    context_idx = list(range(idx - window, idx)) + list(range(idx + 1, idx + window + 1))
    context = [word2idx[word_sequence[i]] for i in context_idx]
    for w in context:
        skip_grams.append([center, w])
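As a quick sanity check, the first few pairs can be decoded back into words (a minimal sketch using only `skip_grams` and `vocab` defined above):

for c, o in skip_grams[:6]:
    # c is the center-word index, o is one of its context-word indices
    print(vocab[c], "->", vocab[o])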
def make_data(skip_grams):
    # One-hot encode the center word; the context-word index is the target label
    input_data = []
    output_data = []
    for i in range(len(skip_grams)):
        input_data.append(np.eye(vocab_size)[skip_grams[i][0]])
        output_data.append(skip_grams[i][1])
    return input_data, output_data

input_data, output_data = make_data(skip_grams)
input_data = torch.tensor(np.array(input_data), dtype=torch.float32)
output_data = torch.tensor(output_data, dtype=torch.long)
dataset = TensorDataset(input_data, output_data)
train_loader = DataLoader(dataset, batch_size, shuffle=True)
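A quick look at one batch shows the shapes the model will receive: one-hot centers of width `vocab_size` and integer context targets (a minimal sketch that just draws the first batch from `train_loader`):

# x is (batch_size, vocab_size) one-hot vectors, y is (batch_size,) word indices
x_batch, y_batch = next(iter(train_loader))
print(x_batch.shape, y_batch.shape)  # e.g. torch.Size([4, 12]) torch.Size([4])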
Build the model
class word2vec_(nn.Module):
    def __init__(self):
        super(word2vec_, self).__init__()
        # w: input (center-word) embedding matrix, v: output (context-word) weight matrix
        self.w = nn.Parameter(torch.randn(vocab_size, embedding_size).type(torch.float32))
        self.v = nn.Parameter(torch.randn(embedding_size, vocab_size).type(torch.float32))

    def forward(self, x):
        hidden = torch.matmul(x, self.w)       # (batch, embedding_size)
        output = torch.matmul(hidden, self.v)  # (batch, vocab_size) logits over context words
        return output
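Since each row of `x` is one-hot, `x @ w` simply selects the corresponding row of `w`, i.e. an embedding lookup for the center word. A minimal illustrative check (the names here are throwaway, not part of the model):

# One-hot multiplication is just row selection: one_hot @ W_demo == W_demo[idx]
W_demo = torch.randn(vocab_size, embedding_size)
idx = 3
one_hot = torch.eye(vocab_size)[idx]
print(torch.allclose(one_hot @ W_demo, W_demo[idx]))  # True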
model = word2vec_().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

num_epochs = 100
loss_all = []
for epoch in range(num_epochs):
    train_loss = 0
    train_num = 0
    for step, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()      # clear gradients before the backward pass
        z_hat = model(x)
        loss = criterion(z_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * len(y)
        train_num += len(y)
    loss_all.append(train_loss / train_num)
    print(f"Epoch:{epoch + 1} Loss:{loss_all[-1]:0.8f}")
    del x, y, loss, train_loss, train_num
    gc.collect()
    torch.cuda.empty_cache()
Epoch:1 Loss:3.78907597
Epoch:2 Loss:3.78907595
Epoch:3 Loss:3.78907597
...
Epoch:99 Loss:3.78907597
Epoch:100 Loss:3.78907597
Visualization
# Plot each word's 2-D embedding taken from the input weight matrix W
W, WT = model.parameters()
for i, label in enumerate(vocab):
    x, y = float(W[i][0]), float(W[i][1])
    plt.scatter(x, y)
    plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')
plt.show()
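The trained rows of `w` can also be compared directly. A minimal sketch that ranks the vocabulary by cosine similarity to one query word (here 'cat', chosen only for illustration):

import torch.nn.functional as F

# Cosine similarity between one query word and every word, using the trained input embeddings
emb = model.w.detach().cpu()                    # (vocab_size, embedding_size)
query = word2idx['cat']
sims = F.cosine_similarity(emb[query].unsqueeze(0), emb, dim=1)
for i in sims.argsort(descending=True).tolist():
    print(vocab[i], round(sims[i].item(), 3))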