PyTorch Learning 4: The FizzBuzz Game


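FizzBuzz is a simple counting game. The rules: count upward from 1; say "fizz" for a multiple of 3, "buzz" for a multiple of 5, and "fizzbuzz" for a multiple of 15 (that is, of both 3 and 5); otherwise just say the number.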

# Encode the desired output as a class index: 0 = the number itself, 1 = "fizz", 2 = "buzz", 3 = "fizzbuzz"
def fizz_buzz_encode(i):
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0
    
def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

print(fizz_buzz_decode(1, fizz_buzz_encode(1)))
print(fizz_buzz_decode(2, fizz_buzz_encode(2)))
print(fizz_buzz_decode(5, fizz_buzz_encode(5)))
print(fizz_buzz_decode(12, fizz_buzz_encode(12)))
print(fizz_buzz_decode(15, fizz_buzz_encode(15)))


1. Prepare the training data

# First define the model's inputs and outputs (the training data)
import numpy as np
import torch

NUM_DIGITS = 10

# Represent each input by an array of its binary digits.
def binary_encode(i,num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)])

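# Train on 101..1023 only; 1..100 is held out as the test set.
# Stacking into one ndarray first avoids the slow list-of-arrays conversion.
tr_X = torch.tensor(np.array([binary_encode(i,NUM_DIGITS) for i in range(101,2**NUM_DIGITS)]), dtype=torch.float32)
tr_Y = torch.tensor([fizz_buzz_encode(i) for i in range(101,2**NUM_DIGITS)], dtype=torch.long)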
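To make the input encoding concrete, here is a small worked example. It reuses `binary_encode` as defined above; `binary_decode` is a hypothetical helper added only to show that the encoding can be inverted. Note that the bits come out least significant first.

```python
import numpy as np

def binary_encode(i, num_digits):
    # Bit d of i for d = 0 .. num_digits - 1 (least significant bit first).
    return np.array([i >> d & 1 for d in range(num_digits)])

def binary_decode(bits):
    # Hypothetical helper (not in the original post): invert binary_encode.
    return int(sum(int(b) << d for d, b in enumerate(bits)))

encoded = binary_encode(6, 10)
print(encoded)                 # [0 1 1 0 0 0 0 0 0 0]
print(binary_decode(encoded))  # 6
```

With NUM_DIGITS = 10 the largest representable number is 2**10 - 1 = 1023, which is why the training range stops there.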

2. Define the model

NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS,NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN,4)
)
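As a quick sanity check, the sketch below rebuilds the same network and pushes a dummy batch through it to confirm the shapes. The all-zero input and the batch size of 5 are arbitrary choices for illustration.

```python
import torch

NUM_DIGITS = 10
NUM_HIDDEN = 100

model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4),
)

# A dummy batch of 5 inputs, used only to check shapes.
dummy = torch.zeros(5, NUM_DIGITS)
logits = model(dummy)
print(tuple(logits.shape))  # (5, 4): one score per class for each input

# Weights and biases: (10*100 + 100) + (100*4 + 4)
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 1504
```

The four outputs are raw scores (logits), one per class index returned by fizz_buzz_encode; there is no softmax layer in the model itself.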

3. Define a loss function and an optimizer

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.05)
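It may help to see what CrossEntropyLoss actually computes. The sketch below uses made-up logits (not model output) and checks that the loss equals the mean negative log-softmax probability of the true class. This is why the model outputs raw scores and the targets are plain class indices.

```python
import torch
import torch.nn.functional as F

# Made-up logits for two samples and four classes, for illustration only.
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0],
                       [0.1, 0.2, 3.0, -0.5]])
targets = torch.tensor([0, 2])  # class indices (int64), not one-hot vectors

loss = torch.nn.CrossEntropyLoss()(logits, targets)

# The same quantity by hand: mean negative log-probability of the true class.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(2), targets].mean()

print(torch.allclose(loss, manual))  # True
```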

4. Train the model

# Start training it
BATCH_SIZE = 128
for epoch in range(10000):
    for start in range(0,len(tr_X),BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = tr_X[start:end]
        batchY = tr_Y[start:end]
        
        y_pred = model(batchX)
        loss = loss_fn(y_pred,batchY)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Find loss on training data
    loss = loss_fn(model(tr_X),tr_Y).item()
    print('Epoch:',epoch,'loss:',loss)
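The loop above visits the batches in the same order every epoch. A common variation, shown here only as a sketch and not as what this post does, is to reshuffle the training set each epoch with torch.randperm. NUM_EPOCHS is a hypothetical name and is kept tiny so the example runs quickly.

```python
import torch

NUM_DIGITS = 10

def binary_encode(i, num_digits):
    return [i >> d & 1 for d in range(num_digits)]

def fizz_buzz_encode(i):
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0

nums = range(101, 2 ** NUM_DIGITS)
tr_X = torch.tensor([binary_encode(i, NUM_DIGITS) for i in nums], dtype=torch.float32)
tr_Y = torch.tensor([fizz_buzz_encode(i) for i in nums], dtype=torch.long)

model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 4),
)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

BATCH_SIZE = 128
NUM_EPOCHS = 5  # assumption: tiny value for a quick run; the post uses 10000
for epoch in range(NUM_EPOCHS):
    perm = torch.randperm(len(tr_X))  # fresh random order each epoch
    for start in range(0, len(tr_X), BATCH_SIZE):
        idx = perm[start:start + BATCH_SIZE]
        loss = loss_fn(model(tr_X[idx]), tr_Y[idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate without tracking gradients.
with torch.no_grad():
    final_loss = loss_fn(model(tr_X), tr_Y).item()
print("loss after", NUM_EPOCHS, "epochs:", final_loss)
```

Whether shuffling helps on this task is not something the post measures; the sketch only shows how the batching would change.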


5. Predict on the test data

testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
with torch.no_grad():
    testY = model(testX)
predictions = zip(range(1, 101), testY.argmax(dim=1).tolist())

print([fizz_buzz_decode(i, x) for (i, x) in predictions])
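The prediction step picks, for each input, the class with the highest score. The sketch below uses made-up scores for the numbers 1 to 3 (not real model output) to show that `max(1)[1]` and `argmax(dim=1)` return the same indices, and how those indices turn back into FizzBuzz output.

```python
import torch

def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

# Made-up scores for the numbers 1..3, for illustration only.
scores = torch.tensor([[2.0, 0.1, 0.3, -1.0],
                       [1.2, 0.4, 0.0, 0.1],
                       [0.2, 1.7, 0.5, 0.3]])

# max over dim 1 returns (largest value, index of largest value) per row.
values, indices = scores.max(1)
same = torch.equal(indices, scores.argmax(dim=1))
print(same)  # True

decoded = [fizz_buzz_decode(i, c) for i, c in zip(range(1, 4), indices.tolist())]
print(decoded)  # ['1', '2', 'fizz']
```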

6. Accuracy on the test data

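# Number of correct predictions among the 100 test inputs (so the count is also the accuracy in percent)
print(np.sum(testY.max(1)[1].numpy() == np.array([fizz_buzz_encode(i) for i in range(1,101)])))
# Element-wise correctness mask: True where the prediction matches the expected class
testY.max(1)[1].numpy() == np.array([fizz_buzz_encode(i) for i in range(1,101)])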
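To put the accuracy count in context, it can be compared against a trivial baseline. The sketch below assumes a hypothetical `accuracy` helper and a baseline that always predicts class 0 ("just say the number"). Of the numbers 1 to 100, 53 are multiples of neither 3 nor 5, so this baseline already scores about 0.53.

```python
import torch

def fizz_buzz_encode(i):
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0

def accuracy(predicted, expected):
    # Hypothetical helper: fraction of positions where the classes match.
    return (predicted == expected).float().mean().item()

expected = torch.tensor([fizz_buzz_encode(i) for i in range(1, 101)])

# Baseline: always predict class 0, i.e. never say fizz, buzz or fizzbuzz.
baseline = torch.zeros_like(expected)
baseline_acc = accuracy(baseline, expected)
print(round(baseline_acc, 2))  # 0.53
```

A trained model is only doing something interesting if its count of correct predictions is clearly above 53 out of 100.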
