import gc
import torch
import numpy as np
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, TensorDataset

dtypes = torch.FloatTensor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
Construct the data
sentences = [ "i like you" , "i love coffee" , "i hate milk" , "i think you" ]
word_list = " " . join( sentences) . split( )
print ( "*" * 80 )
print ( "word_list:" , word_list)
vocab = list ( set ( word_list) )
print ( "*" * 80 )
print ( "vocab:" , vocab)
word2idx = { n: i for i, n in enumerate ( vocab) }
print ( "*" * 80 )
print ( "word2idx:" , word2idx)
idx2word = { i: n for i, n in enumerate ( vocab) }
print ( "*" * 80 )
print ( "idx2word:" , idx2word)
n_class = len ( vocab)
print ( "*" * 80 )
print ( "n_class:" , n_class)
print ( "*" * 80 )
********************************************************************************
word_list: ['i', 'like', 'you', 'i', 'love', 'coffee', 'i', 'hate', 'milk', 'i', 'think', 'you']
********************************************************************************
vocab: ['you', 'like', 'think', 'milk', 'i', 'love', 'coffee', 'hate']
********************************************************************************
word2idx: {'you': 0, 'like': 1, 'think': 2, 'milk': 3, 'i': 4, 'love': 5, 'coffee': 6, 'hate': 7}
********************************************************************************
idx2word: {0: 'you', 1: 'like', 2: 'think', 3: 'milk', 4: 'i', 5: 'love', 6: 'coffee', 7: 'hate'}
********************************************************************************
n_class: 8
********************************************************************************
Build the Dataset
batch_size = 2
time_step = 2   # number of input words per sample (the first two words of each sentence)
n_hidden = 5    # RNN hidden state size
def make_data(sentences):
    inputs_ = []
    targets_ = []
    for sen in sentences:
        word = sen.split()
        inputs = [word2idx[n] for n in word[:-1]]   # all words except the last are the input
        target = word2idx[word[-1]]                  # the last word is the label
        inputs_.append(np.eye(n_class)[inputs])      # one-hot encode the input words
        targets_.append(target)
    return inputs_, targets_
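The np.eye(n_class)[inputs] line above is the one-hot trick: indexing an identity matrix with a list of word indices returns one one-hot row per word. A minimal standalone sketch, using a hypothetical 4-word vocabulary:
import numpy as np

demo = np.eye(4)[[2, 0]]   # rows 2 and 0 of the 4x4 identity matrix
print(demo)
# [[0. 0. 1. 0.]
#  [1. 0. 0. 0.]]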
inputs, targets = make_data(sentences)
# Stack the list of arrays first: building a tensor directly from a list of
# numpy arrays is slow, and PyTorch emits a UserWarning suggesting np.array().
inputs = torch.Tensor(np.array(inputs))
targets = torch.LongTensor(targets)
dataset = TensorDataset(inputs, targets)
train_loader = DataLoader(dataset, batch_size, shuffle=True)
for x, y in train_loader:
    print("*" * 40)
    print(x.shape, y.shape)
    print("*" * 40)
    print(x, y)
    print("=" * 40)
****************************************
torch.Size([2, 2, 8]) torch.Size([2])
****************************************
tensor([[[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.]]]) tensor([0, 3])
========================================
****************************************
torch.Size([2, 2, 8]) torch.Size([2])
****************************************
tensor([[[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.]]]) tensor([0, 6])
========================================
Define the network
class TextRNN(nn.Module):
    def __init__(self):
        super(TextRNN, self).__init__()
        self.rnn = nn.RNN(n_class, n_hidden, batch_first=True)
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, x):
        # out: (batch, time_step, n_hidden) because batch_first=True
        out, hidden = self.rnn(x)
        # Take the hidden state at the last time step for every sample;
        # out[-1] would wrongly index the batch dimension here.
        output = self.fc(out[:, -1, :])
        return output
model = TextRNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
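Before training, a quick shape check helps confirm the forward pass: a batch of shape (batch_size, time_step, n_class) should come out as (batch_size, n_class). A minimal sketch, reusing the model just built:
with torch.no_grad():
    dummy = torch.zeros(batch_size, time_step, n_class).to(device)
    print(model(dummy).shape)   # expected: torch.Size([2, 8])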
loss_all = []
num_epochs = 500
for epoch in range(num_epochs):
    train_loss = 0
    train_num = 0
    for step, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()   # clear gradients before the backward pass
        z_hat = model(x)
        loss = criterion(z_hat, y)
        loss.backward()
        optimizer.step()        # update weights using the fresh gradients
        train_loss += loss.item() * len(y)
        train_num += len(y)
    loss_all.append(train_loss / train_num)
    print(f"Epoch:{epoch+1} Loss:{loss_all[-1]:0.8f}")
    del x, y, loss, train_loss, train_num
    gc.collect()
    torch.cuda.empty_cache()
Epoch:1 Loss:1.91730571
Epoch:2 Loss:2.00270271
Epoch:3 Loss:1.91730571
...
Epoch:499 Loss:1.88459384
Epoch:500 Loss:1.84664440
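Since each epoch's mean loss is collected in loss_all, the training curve can be inspected directly. A minimal sketch, assuming matplotlib is installed:
import matplotlib.pyplot as plt

plt.plot(loss_all)
plt.xlabel("epoch")
plt.ylabel("mean training loss")
plt.show()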
test_text = [sen.split()[:2] for sen in sentences]
test_text
[['i', 'like'], ['i', 'love'], ['i', 'hate'], ['i', 'think']]
tests = []
for word in test_text:
    inputs = [word2idx[n] for n in word]
    tests.append(np.eye(n_class)[inputs])
tests = torch.tensor(np.array(tests), dtype=torch.float32).to(device)
test_dataset = TensorDataset(tests)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)  # keep sentence order at inference
print("done!")
done!
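Note that a TensorDataset built from a single tensor yields one-element tuples, which is why the prediction loop below indexes x[0]. A quick check:
sample = test_dataset[0]
print(type(sample), len(sample))   # <class 'tuple'> 1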
for step, x in enumerate(test_loader):
    print(step, x)
0 [tensor([[[0., 0., 0., 0., 1., 0., 0., 0.],
          [0., 1., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 1., 0., 0., 0.],
          [0., 0., 0., 0., 0., 1., 0., 0.]]], device='cuda:0')]
1 [tensor([[[0., 0., 0., 0., 1., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 1.]],

         [[0., 0., 0., 0., 1., 0., 0., 0.],
          [0., 0., 1., 0., 0., 0., 0., 0.]]], device='cuda:0')]
model.eval()
for step, x in enumerate(test_loader):
    # each batch is a 1-tuple because TensorDataset was built from one tensor
    predict = model(x[0]).data.max(1, keepdim=True)[1]
    print("->", [idx2word[n.item()] for n in predict.cpu().data.squeeze()])
-> ['you', 'like']
-> ['you', 'milk']
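For ad-hoc inference, the one-hot encoding and argmax can be wrapped in a small helper. A minimal sketch (predict_next is a hypothetical name, reusing word2idx, idx2word, n_class, model, and device from above):
def predict_next(prefix_words):
    """Greedy next-word prediction for a list of prefix words."""
    idxs = [word2idx[w] for w in prefix_words]
    x = torch.tensor(np.eye(n_class)[idxs], dtype=torch.float32)
    x = x.unsqueeze(0).to(device)      # (1, time_step, n_class)
    with torch.no_grad():
        pred = model(x).argmax(dim=1)  # (1,)
    return idx2word[pred.item()]

print(predict_next(["i", "like"]))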