tensorflow中tf.one_hot()函数的作用是将一个值化为一个概率分布的向量,一般用于分类问题。
具体用法以及作用见以下代码:
- import numpy as np
- import tensorflow as tf
- SIZE=6
- CLASS=8
- label1=tf.constant([0,1,2,3,4,5,6,7])
- sess1=tf.Session()
- print('label1:',sess1.run(label1))
- b = tf.one_hot(label1,CLASS,1,0)
- with tf.Session() as sess:
- sess.run(tf.global_variables_initializer())
- sess.run(b)
- print('after one_hot',sess.run(b))
label1: [0 1 2 3 4 5 6 7]
after one_hot:
[[1 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0]
[0 0 0 1 0 0 0 0]
[0 0 0 0 1 0 0 0]
[0 0 0 0 0 1 0 0]
[0 0 0 0 0 0 1 0]
[0 0 0 0 0 0 0 1]]
tf.one_hot在看conditionGAN的时候注意到label的输入要把它转换成one-hot形式,再与噪声z进行tf.concat输入,之前看的时候忽略了,现在再看才算明白为什么。
tf.one_hot(
indices,#输入,这里是一维的
depth,# one hot dimension.
on_value=None,#output 默认1
off_value=None,#output 默认0
axis=None,#根据我的实验,默认为1
dtype=None,
name=None
)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
代码
import tensorflow as tf
import numpy as np
z=np.random.randint(0,10,size=[10])
y=tf.one_hot(z,10,on_value=1,off_value=None,axis=0)
with tf.Session()as sess:
print(z)
print(sess.run(y))
[5 7 7 0 5 5 2 0 0 0]
[[0 0 0 1 0 0 0 1 1 1]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 1 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[1 0 0 0 1 1 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 1 1 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]]
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
z=np.random.randint(0,10,size=[10])
y=tf.one_hot(z,10,on_value=1,off_value=None)
y1=tf.one_hot(z,10,on_value=1,off_value=None,axis=1)
with tf.Session()as sess:
print(z)
print(sess.run(y))
print("axis=1按行排", sess.run(y1))
[6 3 4 9 6 5 5 1 2 1]
[[0 0 0 0 0 0 1 0 0 0]
[0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 1 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0]]
axis=1按行排 [[0 0 0 0 0 0 1 0 0 0]
[0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 1 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0]]
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
感觉实际用的时候可以不传入axis值。可以看到经过one-hot的处理,输入的维度变成了10×depth,值也变成了0和1.
下面说在condition GAN中要输入标签信息y,怎样处理的。
y是mnist的标签值,0和10之间的整数,尺寸为[BATCH],经过one-hot处理后维度变成了[BATCH,10]值也是0和1,此时再与噪声z按列(axis=1)连接,变成条件GAN的输入。因此one-hot操作是必须的,这个处理在infoGAN中将z,categorical latent code、continuous latent code连接在一起输入也要用到。
y = tf.one_hot(y, 10, name='label_onehot')
z = tf.random_uniform([BATCH, 100], -1, 1, name='z_train')
tf.concat([z, y], 1)
参考文献:https://blog.csdn.net/m0_37561765/article/details/78207508
https://blog.csdn.net/wenqiwenqi123/article/details/78055740