多分类问题以IRIS数据集为例

数据集

import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf

# 使用tensorflow1.0
tf = tf.compat.v1
tf.disable_v2_behavior()

# 读取数据并以此命名为花萼长度e_cd、花萼宽度e_kd、花瓣长度b_cd、花瓣宽度b_kd、分类结果cat
data = pd.read_csv('../../dataset/iris.data', names=['e_cd', 'e_kd', 'b_cd', 'b_kd', 'cat'])
sns.pairplot(data)
# plt.show()
print(data.cat.unique())
# 将分类进行独热编码
data['c1'] = np.array(data['cat'] == 'Iris-setosa').astype(np.float32)
data['c2'] = np.array(data['cat'] == 'Iris-versicolor').astype(np.float32)
data['c3'] = np.array(data['cat'] == 'Iris-virginica').astype(np.float32)
print(data)
target = np.stack([data.c1.values, data.c2.values, data.c3.values])
shuju = np.stack([data.e_cd.values, data.e_kd.values, data.b_cd.values, data.b_kd.values]).T
target = np.stack([data.c1.values, data.c2.values, data.c3.values]).T
print(np.shape(shuju), np.shape(target))

# 定义网络
x = tf.placeholder('float', shape=[None, 4])
y = tf.placeholder('float', shape=[None, 3])
weight = tf.Variable(tf.truncated_normal([4, 3]))
bias = tf.Variable(tf.truncated_normal([3]))
combine_input = tf.matmul(x, weight) + bias
pred = tf.nn.softmax(combine_input)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=combine_input))
# 比较最大值的索引是否相同
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
train_step = tf.train.AdamOptimizer(0.0005).minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(10000):
    index = np.random.permutation(len(target))
    shuju = shuju[index]
    target = target[index]
    sess.run(train_step, feed_dict={x: shuju, y: target})
    if i % 1000 == 0:
        print(sess.run((loss, accuracy), feed_dict={x: shuju, y: target}))


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值