在MNIST手写字数据集中,我们导入的数据和标签都是预先处理好的,但是在实际的训练中,数据和标签往往需要自己进行处理。
以手写数字识别为例,我们需要将0-9共十个数字标签转化成onehot标签。例如:数字标签“6”转化为onehot标签就是[0,0,0,0,0,0,1,0,0,0].
首先获取需要处理的标签的个数:
batch_size = tf.size(labels)
1
假设输入了6张手写字图片,那么对应的标签数为6,batch_size=6.
tf.size(input) 函数为获取输入tensor的元素数量。
例:设 t = [[1,2],[3,4]],则 tf.size(t) 的结果为4。
然后我们知道onehot标签的shape为[batch_size,10],采用稀疏编码的方法,在onehot标签中选择能够代表标签的位置,将其置为1,其余位置置为0。因此要选择onehot中需要置为1的位置的坐标。
labels = tf.expand_dims(labels, 1)
indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
concated = tf.concat([indices, labels],1)
这里得到的concated就是所要