原文地址:http://joelgrus.com/2016/05/23/fizz-buzz-in-tensorflow(作者:Joel Grus)
话说Fizz Buzz是什么鬼?
Fizz Buzz是洋人小朋友在学除法时常玩的游戏,玩法是:从1数到100,如果遇见了3的倍数要说Fizz,5的倍数就说Buzz,如果即是3的倍数又是5的倍数就说FizzBuzz。
最后演变为一个编程面试题:写一个程序输出1到100,但是如果遇到数字为3的倍数时输出Fizz,5的倍数输出Buzz,既是3的倍数又是5的倍数输出FizzBuzz。
面试中
面试官:你好,在开始面试之前要不要来杯水或来杯咖啡提提神。
我:不用,咖啡啥的我已经喝的够多了,三鹿也喝了不少。
面试官:很好,很好,你不介意在小白板上写代码吧。
我:It’s the only way I code!
面试官:….
我:那只是个笑话。
面试官:好吧,你是否熟悉”fizz buzz”。
我:….
面试官:你到底知不知道”fizz buzz”?
我:我知道”fizz buzz”,我只是不敢相信这么牛叉的IT巨头竟然问这个问题。
面试官:OK,我要你现在写一个程序输出1到100,但是遇到数字为3的倍数时输出Fizz,5的倍数输出Buzz,既是3的倍数又是5的倍数输出FizzBuzz。
我:额,这个,我会!
面试官:很好,我们发现不会解这个问题的人不能胜任我们这里的工作。
我:….
面试官:这是板擦和马克笔。
我:[想了几分钟]
面试官:需不需要帮忙。
我:不,不用。首先先容我导入一些标准库:
1
2
|
import
numpy
as
np
import
tensorflow
as
tf
|
面试官:你知道我们的问题是”fizz buzz”吧?
我:当然,现在让我们来讨论一下模型,我正在想一个简单的只有一个隐藏层的感知器。
面试官:感知器?
我:或神经网络,不管你怎么叫它。给它输入数字,然后它能给我们输出数字对应的”fizz buzz”。但是,首先我们需要把数字转为向量,最简单的方法是把数字转换为二进制表示。
面试官:二进制?
我:你懂的,就是一堆0和1,像这样:
1
2
|
def
binary_encode
(
i
,
num_digits
)
:
return
np
.
array
(
[
i
>>
d
&
1
for
d
in
range
(
num_digits
)
]
)
|
面试官:[盯着小白板看了一分钟]
我:输出应该用one-hot编码表示”fizz buzz”:
1
2
3
4
5
|
def
fizz_buzz_encode
(
i
)
:
if
i
%
15
==
0
:
return
np
.
array
(
[
0
,
0
,
0
,
1
]
)
# FizzBuzz
elif
i
%
5
==
0
:
return
np
.
array
(
[
0
,
0
,
1
,
0
]
)
# Buzz
elif
i
%
3
==
0
:
return
np
.
array
(
[
0
,
1
,
0
,
0
]
)
# Fizz
else
:
return
np
.
array
(
[
1
,
0
,
0
,
0
]
)
|
面试官:等一等,够了!
我:没错,基本的准备工作已经完成了。现在我们需要生成一个训练数据,我们不用1到100训练,为了增加难度,我们使用100-1024训练:
1
2
3
|
NUM_DIGITS
=
10
trX
=
np
.
array
(
[
binary_encode
(
i
,
NUM_DIGITS
)
for
i
in
range
(
101
,
2
*
*
NUM_DIGITS
)
]
)
trY
=
np
.
array
(
[
fizz_buzz_encode
(
i
)
for
i
in
range
(
101
,
2
*
*
NUM_DIGITS
)
]
)
|
面试官:….
我:现在就可以使用TensorFlow搭模型了,我还不太确定隐藏层要使用多少”神经元”,10,够不?
面试官:….
我:100也许要好点,以后还可以再改:
1
|
NUM_HIDDEN
=
100
|
定义输入和输出:
1
2
|
X
=
tf
.
placeholder
(
"float"
,
[
None
,
NUM_DIGITS
]
)
Y
=
tf
.
placeholder
(
"float"
,
[
None
,
4
]
)
|
面试官:你到底要搞哪样。
我:哦,这个网络只有两层深,一个隐藏层和一个输出层。下面,让我们使用随机数初始化“神经元”的权重:
1
2
3
4
5
|
def
init_weights
(
shape
)
:
return
tf
.
Variable
(
tf
.
random_normal
(
shape
,
stddev
=
0.01
)
)
w_h
=
init_weights
(
[
NUM_DIGITS
,
NUM_HIDDEN
]
)
w_o
=
init_weights
(
[
NUM_HIDDEN
,
4
]
)
|
现在我们可以定义模型了,就像我前面说的,一个隐藏层。激活函数用什么呢,我不知道,就用ReLU吧:
1
2
3
|
def
model
(
X
,
w_h
,
w_o
)
:
h
=
tf
.
nn
.
relu
(
tf
.
matmul
(
X
,
w_h
)
)
return
tf
.
matmul
(
h
,
w_o
)
|
我们可以使用softmax cross-entrop做为coss函数,并且试图最小化它。
1
2
3
4
|
py_x
=
model
(
X
,
w_h
,
w_o
)
cost
=
tf
.
reduce_mean
(
tf
.
nn
.
softmax_cross_entropy_with_logits
(
py_x
,
Y
)
)
train_op
=
tf
.
train
.
GradientDescentOptimizer
(
0.05
)
.
minimize
(
cost
)
|
面试官:….
我:当然,最后还要取概率最大的预测做为结果:
1
|
predict_op
=
tf
.
argmax
(
py_x
,
1
)
|
面试官:在你偏离轨道过远之前,我要提醒你,我们的问题是生成1到100的”fizz buzz”。
我:哦,没错,现在predict_op
输出的值是0-3,还要转换为”fizz buzz”输出:
1
2
|
def
fizz_buzz
(
i
,
prediction
)
:
return
[
str
(
i
)
,
"fizz"
,
"buzz"
,
"fizzbuzz"
]
[
prediction
]
|
面试官:….
我:现在我们可以训练模型了,首先创建一个session并初始化变量:
1
2
|
with
tf
.
Session
(
)
as
sess
:
tf
.
global_variables_initializer
(
)
.
run
(
)
|
就训练1000个大周天吧。
面试官:….
我:也许不够,为了保险就训练10000个大周天。我们的训练数据是生成的序列,最好在每个大周天随机打乱一下:
1
2
3
|
for
epoch
in
range
(
10000
)
:
p
=
np
.
random
.
permutation
(
range
(
len
(
trX
)
)
)
trX
,
trY
=
trX
[
p
]
,
trY
[
p
]
|
每次取多少个样本进行训练,我不知道,128怎么样?
1
|
BATCH_SIZE
=
128
|
训练:
1
2
3
|
for
start
in
range
(
0
,
len
(
trX
)
,
BATCH_SIZE
)
:
end
=
start
+
BATCH_SIZE
sess
.
run
(
train_op
,
feed_dict
=
{
X
:
trX
[
start
:
end
]
,
Y
:
trY
[
start
:
end
]
}
)
|
我们还能看准确率:
1
|
print
(
epoch
,
np
.
mean
(
np
.
argmax
(
trY
,
axis
=
1
)
==
sess
.
run
(
predict_op
,
feed_dict
=
{
X
:
trX
,
Y
:
trY
}
)
)
)
|
面试官:你是认真的吗?
我:是,看准确率提升曲线非常有帮助。
面试官:….
我:模型训练完了,现在是fizz buzz时间。给模型输入1-100的二进制表示:
1
2
|
numbers
=
np
.
arange
(
1
,
101
)
teX
=
np
.
transpose
(
binary_encode
(
numbers
,
NUM_DIGITS
)
)
|
预测fizz buzz,大功告成:
1
2
3
4
|
teY
=
sess
.
run
(
predict_op
,
feed_dict
=
{
X
:
teX
}
)
output
=
np
.
vectorize
(
fizz_buzz
)
(
numbers
,
teY
)
print
(
output
)
|
面试官:….
我:这就是你要的”fizz buzz”。
面试官:够了,我们会在联系你。
我:联系我!这可真喜人。
面试官:….
后记
我没有得到offer,于是我运行了一下这个代码,事实证明有一些输出是错的。感谢机器学习十八代!!
1
2
3
4
5
6
7
8
9
|
[
'buzz'
'2'
'fizz'
'buzz'
'buzz'
'fizz'
'7'
'8'
'fizz'
'buzz'
'11'
'fizz'
'13'
'14'
'fizzbuzz'
'fizz'
'17'
'fizz'
'19'
'buzz'
'fizz'
'22'
'23'
'fizz'
'buzz'
'26'
'fizz'
'28'
'29'
'fizzbuzz'
'31'
'32'
'fizz'
'34'
'buzz'
'fizz'
'37'
'38'
'fizz'
'buzz'
'41'
'fizz'
'43'
'44'
'fizzbuzz'
'46'
'47'
'fizz'
'fizz'
'buzz'
'fizz'
'52'
'fizz'
'fizz'
'buzz'
'56'
'fizz'
'58'
'59'
'fizzbuzz'
'61'
'62'
'fizz'
'64'
'buzz'
'fizz'
'67'
'68'
'fizz'
'buzz'
'71'
'fizz'
'73'
'74'
'fizzbuzz'
'76'
'77'
'fizz'
'79'
'buzz'
'fizz'
'82'
'83'
'fizz'
'buzz'
'86'
'fizz'
'88'
'89'
'fizzbuzz'
'91'
'92'
'fizz'
'94'
'buzz'
'fizz'
'97'
'98'
'fizz'
'buzz'
]
|
也许我应该使用更深的网络。
作者@斗大的熊猫