#定义fizzbuzz游戏
def fizzbuzz_encode(i):
    """Classify ``i`` for FizzBuzz as an integer label.

    Returns:
        3 if ``i`` is divisible by 15 ("fizzbuzz"),
        2 if divisible by 5 ("buzz"),
        1 if divisible by 3 ("fizz"),
        0 otherwise (the number itself is printed).
    """
    # Checked in priority order: 15 must win over 5 and 3.
    for divisor, label in ((15, 3), (5, 2), (3, 1)):
        if i % divisor == 0:
            return label
    return 0
def fizzbuzz_decode(i, prediction):
    """Turn a FizzBuzz class label back into the text to print for ``i``.

    ``prediction`` indexes the list
    ``[str(i), 'fizz', 'buzz', 'fizzbuzz']`` (labels 0..3, matching
    ``fizzbuzz_encode``).
    """
    choices = [str(i), 'fizz', 'buzz', 'fizzbuzz']
    return choices[prediction]
def start(i):
    """Return the rule-based FizzBuzz output for ``i``.

    Encodes ``i`` with ``fizzbuzz_encode`` and immediately decodes the
    label with ``fizzbuzz_decode``.
    """
    label = fizzbuzz_encode(i)
    return fizzbuzz_decode(i, label)
# --- PyTorch section ---
#定义pyTorch网络
#定义训练数据
import numpy as np
import torch
def binary_encode(i, num_digits):
    """Return the ``num_digits``-bit binary representation of ``i``.

    Bits are ordered most-significant first, e.g. ``binary_encode(5, 4)``
    gives ``array([0, 1, 0, 1])``. Bits of ``i`` above ``num_digits`` are
    dropped.

    Args:
        i: non-negative integer to encode.
        num_digits: number of bits in the output.

    Returns:
        1-D ``np.ndarray`` of 0/1 integers with length ``num_digits``.
    """
    # Iterate bit positions high -> low so the result is built in MSB-first
    # order directly. This avoids a reversed ([::-1]) view with a negative
    # stride, which torch tensor conversion does not accept.
    return np.array([(i >> d) & 1 for d in reversed(range(num_digits))])
# Training data: the numbers 101..1023; 1..100 are held out for testing.
# Each number becomes a NUM_DIGITS-bit binary feature vector, and its label
# is the FizzBuzz class index from fizzbuzz_encode (0..3).
NUM_DIGITS = 10
# Stack into one ndarray first: building a tensor from a list of ndarrays is
# slow, and stacking also yields a contiguous copy.
trX = torch.tensor(
    np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)]),
    dtype=torch.float32,
)
trY = torch.tensor(
    [fizzbuzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)],
    dtype=torch.long,
)

# Two-layer network: NUM_DIGITS inputs -> 100 hidden ReLU units -> 4 logits.
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 4),
)
# Loss: CrossEntropyLoss applied to the raw logits and the integer class labels.
loss_fn = torch.nn.CrossEntropyLoss()
# NOTE(review): lr=1e-4 with plain SGD for 100 epochs may be too small to
# learn FizzBuzz well -- confirm the intended hyperparameters.
learning_rate = 1e-4
# Optimizer: plain stochastic gradient descent.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

BATCH_SIZE = 128
# Log every 3rd batch. This is the same cadence the original
# `start % 6 == 0` check produced with 128-sample batches.
LOG_EVERY = 3

for epoch in range(100):
    # The loop variable is `batch_start` rather than `start`, so the
    # module-level start() function is not overwritten by an int.
    for batch_idx, batch_start in enumerate(range(0, len(trX), BATCH_SIZE)):
        batch_end = batch_start + BATCH_SIZE
        batchX = trX[batch_start:batch_end]
        batchY = trY[batch_start:batch_end]
        preY = model(batchX)
        loss = loss_fn(preY, batchY)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_EVERY == 0:
            print('第%d轮:%d批次:损失函数为:%.3f' % (epoch, batch_idx + 1, loss.item()))

# Training finished: evaluate on the held-out numbers 1..100.
testX = torch.tensor(
    np.array([binary_encode(i, NUM_DIGITS) for i in range(1, 101)]),
    dtype=torch.float32,
)
testY = torch.tensor([fizzbuzz_encode(i) for i in range(1, 101)], dtype=torch.long)
# Inference only: no autograd graph is needed.
with torch.no_grad():
    pre_Y = model(testX).argmax(dim=1)
print(pre_Y)
print('accuracy: %.3f' % (pre_Y == testY).float().mean().item())
# PyTorch 学习 FizzBuzz 训练
# （最新推荐文章于 2023-03-18 10:20:27 发布）