tenflow数据集_TensorFlow基本模型之最近邻

最近邻算法简介

k近邻模型的核心就是使用一种距离度量,获得距离目标点最近的k个点,根据分类决策规则,决定目标点的分类。[2]

距离度量(L1范数):

4047edcc6946

image

K值选择:这里k为10。

分类决策规则:k近邻的分类决策规则是最为常见的简单多数规则,也就是在最近的K个点中,哪个标签数目最多,就把目标点的标签归于哪一类。

Tensorflow 最近邻

import numpy as np

import tensorflow as tf

导入 mnist数据集

# Import MINST data

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./data/", one_hot=True)

Extracting ./data/train-images-idx3-ubyte.gz

Extracting ./data/train-labels-idx1-ubyte.gz

Extracting ./data/t10k-images-idx3-ubyte.gz

Extracting ./data/t10k-labels-idx1-ubyte.gz

构建模型

# In this example, we limit mnist data

Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)

Xte, Yte = mnist.test.next_batch(10) #10 for testing

# tf Graph Input

xtr = tf.placeholder("float", [None, 784])

xte = tf.placeholder("float", [784])

# Nearest Neighbor calculation using L1 Distance

# Calculate L1 Distance

distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)

# Prediction: Get min distance index (Nearest neighbor)

pred = tf.argmin(distance, 0)

补充:Tenosrflow中基本算术运算函数:[1]

tf.add(x,y,name=None) # 求和运算

tf.subtract(x,y,name=None) # 减法运算

tf.multiply(x,y,name=None) #乘法运算

tf.div(x,y,name=None) #除法运算

tf.mod(x,y,name=None) # 取模运算

tf.abs(x,name=None) #求绝对值

tf.negative(x,name=None) #取负运算(y=-x)

tf.sign(x,name=None) #返回符合x大于0,则返回1,小于0,则返回-1

tf.reciprocal(x,name=None) #取反运算

tf.square(x,name=None) #计算平方

tf.round(x,name=None) #舍入最接近的整数

tf.pow(x,y,name=None) #幂次方

训练

accuracy = 0.

# Initialize the variables (i.e. assign their default value)

init = tf.global_variables_initializer()

# Start training

with tf.Session() as sess:

sess.run(init)

# loop over test data

for i in range(len(Xte)):

# Get nearest neighbor

# 5000个样本点分别和10个测试点计算距离

nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})

print(nn_index)

# Get nearest neighbor class label and compare it to its true label

print ("Test", i, "Prediction:", np.argmax(Ytr[nn_index]), \

"True Class:", np.argmax(Yte[i]))

# Calculate accuracy

if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):

accuracy += 1./len(Xte)

print ("Done!")

print ("Accuracy:", accuracy)

190

Test 0 Prediction: 9 True Class: 9

475

Test 1 Prediction: 5 True Class: 5

3152

Test 2 Prediction: 7 True Class: 7

2413

Test 3 Prediction: 2 True Class: 2

1088

Test 4 Prediction: 2 True Class: 2

1427

Test 5 Prediction: 2 True Class: 2

4743

Test 6 Prediction: 7 True Class: 7

4826

Test 7 Prediction: 6 True Class: 6

4099

Test 8 Prediction: 5 True Class: 5

2421

Test 9 Prediction: 5 True Class: 5

Done!

Accuracy: 0.9999999999999999

参考

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值