FizzBuzz FizzBuzz是一个简单的小游戏。游戏规则如下:从1开始往上数数,当遇到3的倍数的时候,说fizz,当遇到5的倍数,说buzz,当遇到15的倍数,就说fizzbuzz,其他情况下则正常数数。
我们可以写一个简单的小程序来决定要返回正常数值还是fizz, buzz 或者 fizzbuzz。 def fizz_buzz_encode(i): if i % 15 == 0: return 3 elif i % 5 == 0: return 2 elif i % 3 == 0: return 1 else: return 0
def fizz_buzz_decode(i, prediction): return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]
print(fizz_buzz_decode(1, fizz_buzz_encode(1))) print(fizz_buzz_decode(2, fizz_buzz_encode(2))) print(fizz_buzz_decode(5, fizz_buzz_encode(5))) print(fizz_buzz_decode(12, fizz_buzz_encode(12))) print(fizz_buzz_decode(15, fizz_buzz_encode(15)))import numpy as np import torch
NUM_DIGITS = 10
# Represent each input by an array of its binary digits. def binary_encode(i, num_digits): return np.array([i >> d & 1 for d in range(num_digits)])
trX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)]) trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)]) # Define the model NUM_HIDDEN = 100 model = torch.nn.Sequential( torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN), torch.nn.ReLU(), torch.nn.Linear(NUM_HIDDEN, 4) )为了让我们的模型学会FizzBuzz这个游戏,我们需要定义一个损失函数,和一个优化算法。 这个优化算法会不断优化(降低)损失函数,使得模型的在该任务上取得尽可能低的损失值。 损失值低往往表示我们的模型表现好,损失值高表示我们的模型表现差。 由于FizzBuzz游戏本质上是一个分类问题,我们选用Cross Entropyy Loss函数。 优化函数我们选用Stochastic Gradient Descent。 optimizer = torch.optim.SGD(model.parameters(), lr = 0.05) # Start training it BATCH_SIZE = 128 for epoch in range(10000): for start in range(0, len(trX), BATCH_SIZE): end = start + BATCH_SIZE batchX = trX[start:end] batchY = trY[start:end]
y_pred = model(batchX) loss = loss_fn(y_pred, batchY)
optimizer.zero_grad() loss.backward() optimizer.step()
# Find loss on training data loss = loss_fn(model(trX), trY).item() print('Epoch:', epoch, 'Loss:', loss) testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)]) with torch.no_grad(): testY = model(testX) predictions = zip(range(1, 101), list(testY.max(1)[1].data.tolist()))
print([fizz_buzz_decode(i, x) for (i, x) in predictions]) testY.max(1)[1].numpy() == np.array([fizz_buzz_encode(i) for i in range(1,101)]) |
[Pytorch][转载]用pytorch解决简单的FizzBuzz问题
最新推荐文章于 2023-03-18 10:20:27 发布