TensorFlow中文社区](%28http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html%29)介绍很详细的介绍了识别手写数字的案例。
自己也想敲出来看看,奈何数据导入网址打不开。想到了seaborn中的数据。因此,本次用seaborn中鸢尾花分类的数据代替,基本的流程都是一样的。
导入需要的库文件
import numpy as np
import tensorflow as tf
import pandas as pd
import seaborn as sns
导入数据并处理
one-hot编码
iris = sns.load_dataset("iris")
data = pd.get_dummies(iris)
其中,pd,get_dummies()函数是对数据进行one-hot编码。简而言之,加入有3个分类,那么这三个分类的编码分别:
[1,0,0] 对应第一类
[0,1,0] 对应第二类
[0,0,1] 对应第san类
在本例中就是: