TM手写字识别实验分析

手写字识别实验具体操作可以参考这个 Tsetlin Machine复现手写字识别MNIST Demo
首先看下面的代码

from pyTsetlinMachine.tm import MultiClassTsetlinMachine
import numpy as np
from time import time

from keras.datasets import mnist

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

X_train = np.where(X_train.reshape((X_train.shape[0], 28*28)) > 75, 1, 0) 
X_test = np.where(X_test.reshape((X_test.shape[0], 28*28)) > 75, 1, 0) 

tm = MultiClassTsetlinMachine(2000, 50, 10.0)

print("\nAccuracy over 250 epochs:\n")
for i in range(250):
	start_training = time()
	tm.fit(X_train, Y_train, epochs=1, incremental=True)
	stop_training = time()

	start_testing = time()
	result = 100*(tm.predict(X_test) == Y_test).mean()
	stop_testing = time()

	print("#%d Accuracy: %.2f%% Training: %.2fs Testing: %.2fs" % (i+1, result, stop_training-start_training, stop_testing-start_testing))

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

从这里读取的是X_train训练和Y_train也就是对应的标签,X_test,Y_test同理。
为了理解数据格式,我们读取它具体内容。

print("训练样本的维度为:",X_train.ndim)
print("训练样本的形状为:",X_train.shape)
print("训练样本的元素数量为:",X_train.size)
print("训练样本的数据类型为:",X_train.dtype)

可以得到,X_train的维度是(60000,28,28),Y_train也同理是(,60000)与之一一对应。
为了方便理解我们可以查看它的内容。

for i in range(0,28):
    for j in range(0,28):
        print("%.1f" % x_train[0][i][j] , end=" ")
    print()

如下,可以看出它是5。Y_train的0的位置也对应的是5。
在这里插入图片描述
接着的是,

X_train = np.where(X_train.reshape((X_train.shape[0], 28*28)) > 75, 1, 0) 
X_test = np.where(X_test.reshape((X_test.shape[0], 28*28)) > 75, 1, 0) 

注意这里X_train变成了二维数组,每行存储的一个28*28的图像信息,也就是和上面的数据对应。TM的训练都是需要bool,所以将像素值75(应该是一个经验值)设置为阈值,01化。
具体的值如下。

这样一来,数据和标签都制作完成了,然后创造tm机,输入训练预测。按照这个思路,我们可制作其它数据的数据集,比如汽车,花朵之类,分别做成txt读取或者元组,之后再去做实验吧。
之后的TM训练过程,可以在学习详细的原理之后再进行分析。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值