【theano-windows】学习笔记九——softmax手写数字分类

前言

上一篇博客折腾了数据集的预备知识, 接下来按照官方的Deep learning 0.1 documentation一步步走, 先折腾softmax, 关于softmaxlogistic回归分类的联系, 我在之前写过一个小博客

国际惯例, 参考博客走一波:

Classifying MNIST digits using Logistic Regression

softmax理论及代码解读——UFLDL

softmax简介

直接码公式了,理论就不说了:

各标签概率:

P(y(i)=1|x(i);W,b)=ewix(i)+biki=1ewix(i)+bi

其实每个单元值得计算就是 ewix+bi ,其中 wi 就是连接输入神经元和第 i 个神经元的权重,bi就是第 i 个神经元的偏置, 最后为了保证概率和为1, 进行了归一化而已.

获取当前预测值,也就是概率最大的那个标签, 可以用它来计算准确率

ypred=argmaxiP(Y=i|xi;W,b)

代价函数为负对数似然:

J(θ)=1mi=1mj=1k1{y(i)=j}logP(y(i)=1|x(i);W,b)

梯度就不求了,直接用 theano的自动梯度函数 grad()

算法实现

【PS】官方代码加上注释竟然能达到五百多行,这是有多么生猛。

导入包

导入包就不用说了,需要用到处理文件路径, 和解压以及读取数据的模块

import theano
import theano.tensor as T
import numpy as np
import cPickle,gzip#读取数据解压用的
import os#路径相关操作需要
import timeit#时间

读取数据

#读取数据集
def load_data(dataset):
    data_dir,data_file=os.path.split(dataset)
    if os.path.isfile(dataset):
        with gzip.open(dataset,'rb') as f:
            train_set,valid_set,test_set=cPickle.load(f)
    #共享数据集
    def shared_dataset(data_xy,borrow=True):
        data_x,data_y=data_xy
        shared_x=theano.shared(np.asarray(data_x,dtype=theano.config.floatX),borrow=borrow)
        shared_y=theano.shared(np.asarray(data_y,dtype=theano.config.floatX),borrow=borrow)
        return shared_x,T.cast(shared_y,'int32')

    #定义三个元组分别返回训练集,验证集,测试集
    train_set_x,train_set_y=shared_dataset(train_set)
    valid_set_x,valid_set_y=shared_dataset(valid_set)
    test_set_x,test_set_y=shared_dataset(test_set)
    rval=[(train_set_x,train_set_y),(valid_set_x,valid_set_y),(test_set_x,test_set_y)]
    return rval

分类器函数

定义一个分类器类, 用于存储模型参数以及实现负对数似然和模型当前预测误差的计算, 以及负对数似然就是上面那个带有log的式子, 将当前样本所应属的类别的概率相加求均值即可. 误差就是当前样本有多少个识别错误了, 看看当前样本所属的最大概率类别是不是它所应属的类别, 然后求均值就得到了误差

#定义分类器相关操作
class LogisticRegression(object):
    def __init__(self,input,n_in,n_out):
        #共享权重
        self.W=theano.shared(value=np.zeros((n_in,n_out),dtype=theano.config.floatX),
                            name='W',
                            borrow=True)
        #共享偏置
        self.b=theano.shared(value=np.zeros((n_out,),dtype=theano.config.floatX),
                            name='b',
                            borrow=True)
        #softmax函数
        self.p_y_given_x=T.nnet.softmax(T.dot(input,self.W)+self.b)
        #预测值
        self.y_pred=T.argmax(self.p_y_given_x,axis=1)
        self.params=[self.W,self.b]#模型参数
        self.input=input#模型输入

    #定义负对数似然
    def negative_log_likelihood(self,y):
        return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]),y])

    #定义误差
    def errors(self, y):
        if y.ndim != self.y_pred.ndim:#查看维度是不是一样
            raise TypeError(
                'y should have the same shape as self.y_pred',
                ('y', y.type, 'y_pred', self.y_pred.type)
            )
        if y.dtype.startswith('int'):#查看y的类型是不是正确的
            return T.mean(T.neq(self.y_pred, y))#neq是判断相等(0)和不等(1)
        else:
            raise NotImplementedError()

训练

以下操作都放在一个函数sgd_mnist

def sgd_mnist(learning_rate=0.13,
              n_epochs=1000,
              dataset='mnist.pkl.gz',
              batch_size=600):

首先通过读取函数load_data读取数据, 并且建立好分批索引

#处理数据集
    datasets=load_data(dataset)#读取数据集
    train_set_x,train_set_y=datasets[0]#训练集
    valid_set_x,valid_set_y=datasets[1]#验证集
    test_set_x,test_set_y=datasets[2]#测试集
    print '验证集大小',valid_set_x.get_value().shape#看看是否读取成功

    #总共多少批数据,注意共享变量数据用get_value获取
    n_train_batches=train_set_x.get_value(borrow=True).shape[0]//batch_size
    n_valid_batches=valid_set_x.get_value(borrow=True).shape[0]//batch_size
    n_test_batches=test_set_x.get_value(borrow=True).shape[0]//batch_size

然后建立两个容器去存储每批输入的数据, 顺便想象一下, 我们后面每次在function中迭代计算就用givens去不断替换这个容器里面的数据就能实现每次迭代都是不同批次的数据, 注意数据是包含图片和标签的, 所以需要矩阵和向量存储

 #建立模型
    print '建立模型'
    index=T.lscalar()#索引数据所属批次
    x=T.matrix('x')#栅格化图片数据
    y=T.ivector('y')#数据标签,不要直接使用vector

然后把分类器初始化一下,其实就是权重和偏置初始化一下

 #初始化分类器:输入,输入大小,输出大小
    classifier=LogisticRegression(input=x,n_in=28*28,n_out=10)
    cost=classifier.negative_log_likelihood(y)

权重和偏置的更新梯度, 以及放到function去编译

#权重和偏置更新
    g_W=T.grad(cost=cost,wrt=classifier.W)#损失对权重的导数
    g_b=T.grad(cost=cost,wrt=classifier.b)#损失对偏置的导数

    #更新偏置和梯度,注意每次迭代要在givens中替换批数据
    updates=[(classifier.W,classifier.W-learning_rate*g_W),
            (classifier.b,classifier.b-learning_rate*g_b)]
    train_mode=theano.function(inputs=[index],
                              outputs=cost,
                              updates=updates,
                              givens={
                                  x:train_set_x[index*batch_size:(index+1)*batch_size],
                                  y:train_set_y[index*batch_size:(index+1)*batch_size]
                              })

当然我们机器学习中还有两个数据集:验证集和测试集, 都有各自的作用, 我们定义一下模型的测试函数

 #验证集测试模型
    valid_mode=theano.function(inputs=[index],
                              outputs=classifier.errors(y),
                              givens={
                                  x:valid_set_x[index*batch_size:(index+1)*batch_size],
                                  y:valid_set_y[index*batch_size:(index+1)*batch_size]
                              })
    #测试集误差
    test_mode=theano.function(inputs=[index],
                             outputs=classifier.errors(y),
                             givens={
                                 x:test_set_x[index*batch_size:(index+1)*batch_size],
                                 y:test_set_y[index*batch_size:(index+1)*batch_size]

接下来就是利用上一篇博客中所说的提前终止算法进行训练了

#训练开始, 使用提前终止法训练
    print '训练开始'
    patience=5000#初始patience
    patience_increase=2#增量
    improvement_threshold=0.995#性能提升阈值
    validation_frequency=min(n_train_batches,patience//2)#至少每个patience预测两次
    best_validation_loss=np.inf#最好的预测值
    test_score=0
    start_time=timeit.default_timer()
    done_looping=False#是否停止循环
    epoch=0#初始迭代次数
    while (epoch<n_epochs) and (not done_looping):
        epoch=epoch+1
        for minibatch_index in range(n_train_batches):
            minibatch_avg_cost=train_mode(minibatch_index)#对数似然目标函数值
            iter=(epoch-1)*n_train_batches+minibatch_index#当前迭代批次数
            #验证集误差
            if (iter+1)%validation_frequency==0:
                validation_loss=[valid_mode(i) for i in range(n_valid_batches)]
                this_validation_loss=np.mean(validation_loss)
                print 'epoch %i, minibatch %i/%i, validation error: %f %%' %\
                (epoch,minibatch_index+1,n_train_batches,this_validation_loss*100.)
                #阈值判断性能提升
                if this_validation_loss<best_validation_loss:
                    if this_validation_loss<best_validation_loss*improvement_threshold:
                        patience=max(patience,iter*patience_increase)
                    best_validation_loss=this_validation_loss#如果性能提升,就重新记录最优值
                    test_loss=[test_mode(i) for i in range(n_test_batches)]
                    test_score=np.mean(test_loss)
                    print 'epoch %i minibatch %i/%i,test error of best model %f%%' %\
                    (epoch,minibatch_index+1,n_train_batches,test_score*100.)
                    #存储最好的模型参数
                    with open('best_model.pkl','wb') as f:
                        cPickle.dump(classifier,f)
            if patience<=iter:
                done_looping=True
                break
    end_time=timeit.default_timer()
    print 'Optimization complete with best validation score of %f %%,'\
    'with test performance %f %%'\
    % (best_validation_loss * 100., test_score * 100.)

    print 'The code run for %d epochs,with %f epochs/sec'\
    %(epoch,1.*epoch/(end_time-start_time))

执行主函数进行训练

#开始训练
sgd_mnist()

结果

验证集大小 (10000L, 784L)
建立模型
训练开始
epoch 1, minibatch 83/83, validation error: 12.458333 %
epoch 1 minibatch 83/83,test error of best model 12.375000%
epoch 2, minibatch 83/83, validation error: 11.010417 %
epoch 2 minibatch 83/83,test error of best model 10.958333%
epoch 3, minibatch 83/83, validation error: 10.312500 %
epoch 3 minibatch 83/83,test error of best model 10.312500%
epoch 4, minibatch 83/83, validation error: 9.875000 %
epoch 4 minibatch 83/83,test error of best model 9.833333%
epoch 5, minibatch 83/83, validation error: 9.562500 %
epoch 5 minibatch 83/83,test error of best model 9.479167%
epoch 6, minibatch 83/83, validation error: 9.322917 %
epoch 6 minibatch 83/83,test error of best model 9.291667%
epoch 7, minibatch 83/83, validation error: 9.187500 %
epoch 7 minibatch 83/83,test error of best model 9.000000%
epoch 8, minibatch 83/83, validation error: 8.989583 %
epoch 8 minibatch 83/83,test error of best model 8.958333%
epoch 9, minibatch 83/83, validation error: 8.937500 %
epoch 9 minibatch 83/83,test error of best model 8.812500%
epoch 10, minibatch 83/83, validation error: 8.750000 %
epoch 10 minibatch 83/83,test error of best model 8.666667%
epoch 11, minibatch 83/83, validation error: 8.666667 %
epoch 11 minibatch 83/83,test error of best model 8.520833%
epoch 12, minibatch 83/83, validation error: 8.583333 %
epoch 12 minibatch 83/83,test error of best model 8.416667%
epoch 13, minibatch 83/83, validation error: 8.489583 %
epoch 13 minibatch 83/83,test error of best model 8.291667%
epoch 14, minibatch 83/83, validation error: 8.427083 %
epoch 14 minibatch 83/83,test error of best model 8.281250%
epoch 15, minibatch 83/83, validation error: 8.354167 %
epoch 15 minibatch 83/83,test error of best model 8.270833%
epoch 16, minibatch 83/83, validation error: 8.302083 %
epoch 16 minibatch 83/83,test error of best model 8.239583%
epoch 17, minibatch 83/83, validation error: 8.250000 %
epoch 17 minibatch 83/83,test error of best model 8.177083%
epoch 18, minibatch 83/83, validation error: 8.229167 %
epoch 18 minibatch 83/83,test error of best model 8.062500%
epoch 19, minibatch 83/83, validation error: 8.260417 %
epoch 20, minibatch 83/83, validation error: 8.260417 %
epoch 21, minibatch 83/83, validation error: 8.208333 %
epoch 21 minibatch 83/83,test error of best model 7.947917%
epoch 22, minibatch 83/83, validation error: 8.187500 %
epoch 22 minibatch 83/83,test error of best model 7.927083%
epoch 23, minibatch 83/83, validation error: 8.156250 %
epoch 23 minibatch 83/83,test error of best model 7.958333%
epoch 24, minibatch 83/83, validation error: 8.114583 %
epoch 24 minibatch 83/83,test error of best model 7.947917%
epoch 25, minibatch 83/83, validation error: 8.093750 %
epoch 25 minibatch 83/83,test error of best model 7.947917%
epoch 26, minibatch 83/83, validation error: 8.104167 %
epoch 27, minibatch 83/83, validation error: 8.104167 %
epoch 28, minibatch 83/83, validation error: 8.052083 %
epoch 28 minibatch 83/83,test error of best model 7.843750%
epoch 29, minibatch 83/83, validation error: 8.052083 %
epoch 30, minibatch 83/83, validation error: 8.031250 %
epoch 30 minibatch 83/83,test error of best model 7.843750%
epoch 31, minibatch 83/83, validation error: 8.010417 %
epoch 31 minibatch 83/83,test error of best model 7.833333%
epoch 32, minibatch 83/83, validation error: 7.979167 %
epoch 32 minibatch 83/83,test error of best model 7.812500%
epoch 33, minibatch 83/83, validation error: 7.947917 %
epoch 33 minibatch 83/83,test error of best model 7.739583%
epoch 34, minibatch 83/83, validation error: 7.875000 %
epoch 34 minibatch 83/83,test error of best model 7.729167%
epoch 35, minibatch 83/83, validation error: 7.885417 %
epoch 36, minibatch 83/83, validation error: 7.843750 %
epoch 36 minibatch 83/83,test error of best model 7.697917%
epoch 37, minibatch 83/83, validation error: 7.802083 %
epoch 37 minibatch 83/83,test error of best model 7.635417%
epoch 38, minibatch 83/83, validation error: 7.812500 %
epoch 39, minibatch 83/83, validation error: 7.812500 %
epoch 40, minibatch 83/83, validation error: 7.822917 %
epoch 41, minibatch 83/83, validation error: 7.791667 %
epoch 41 minibatch 83/83,test error of best model 7.625000%
epoch 42, minibatch 83/83, validation error: 7.770833 %
epoch 42 minibatch 83/83,test error of best model 7.614583%
epoch 43, minibatch 83/83, validation error: 7.750000 %
epoch 43 minibatch 83/83,test error of best model 7.593750%
epoch 44, minibatch 83/83, validation error: 7.739583 %
epoch 44 minibatch 83/83,test error of best model 7.593750%
epoch 45, minibatch 83/83, validation error: 7.739583 %
epoch 46, minibatch 83/83, validation error: 7.739583 %
epoch 47, minibatch 83/83, validation error: 7.739583 %
epoch 48, minibatch 83/83, validation error: 7.708333 %
epoch 48 minibatch 83/83,test error of best model 7.583333%
epoch 49, minibatch 83/83, validation error: 7.677083 %
epoch 49 minibatch 83/83,test error of best model 7.572917%
epoch 50, minibatch 83/83, validation error: 7.677083 %
epoch 51, minibatch 83/83, validation error: 7.677083 %
epoch 52, minibatch 83/83, validation error: 7.656250 %
epoch 52 minibatch 83/83,test error of best model 7.541667%
epoch 53, minibatch 83/83, validation error: 7.656250 %
epoch 54, minibatch 83/83, validation error: 7.635417 %
epoch 54 minibatch 83/83,test error of best model 7.520833%
epoch 55, minibatch 83/83, validation error: 7.635417 %
epoch 56, minibatch 83/83, validation error: 7.635417 %
epoch 57, minibatch 83/83, validation error: 7.604167 %
epoch 57 minibatch 83/83,test error of best model 7.489583%
epoch 58, minibatch 83/83, validation error: 7.583333 %
epoch 58 minibatch 83/83,test error of best model 7.458333%
epoch 59, minibatch 83/83, validation error: 7.572917 %
epoch 59 minibatch 83/83,test error of best model 7.468750%
epoch 60, minibatch 83/83, validation error: 7.572917 %
epoch 61, minibatch 83/83, validation error: 7.583333 %
epoch 62, minibatch 83/83, validation error: 7.572917 %
epoch 62 minibatch 83/83,test error of best model 7.520833%
epoch 63, minibatch 83/83, validation error: 7.562500 %
epoch 63 minibatch 83/83,test error of best model 7.510417%
epoch 64, minibatch 83/83, validation error: 7.572917 %
epoch 65, minibatch 83/83, validation error: 7.562500 %
epoch 66, minibatch 83/83, validation error: 7.552083 %
epoch 66 minibatch 83/83,test error of best model 7.520833%
epoch 67, minibatch 83/83, validation error: 7.552083 %
epoch 68, minibatch 83/83, validation error: 7.531250 %
epoch 68 minibatch 83/83,test error of best model 7.520833%
epoch 69, minibatch 83/83, validation error: 7.531250 %
epoch 70, minibatch 83/83, validation error: 7.510417 %
epoch 70 minibatch 83/83,test error of best model 7.500000%
epoch 71, minibatch 83/83, validation error: 7.520833 %
epoch 72, minibatch 83/83, validation error: 7.510417 %
epoch 73, minibatch 83/83, validation error: 7.500000 %
epoch 73 minibatch 83/83,test error of best model 7.489583%
Optimization complete with best validation score of 7.500000 %,with test performance 7.489583 %
The code run for 74 epochs,with 6.870845 epochs/sec

测试

一般测试分为两种:

  • 批量测试: 一次性丢一堆数据进去, 测试这一堆数据的准确率
  • 单样本测试: 自己手写一个数字, 然后调用模型预测一下

批样本测试

想一下, 我们训练的时候是分批丢进去的, 那么批量测试是否可以复制相同部分的代码进而减少代码量呢?(答案是挺麻烦的, 继续往下看)

  • 因为存储的就是训练好的LogisticRegression的值, 那么我们在测试的时候就可以直接调用它里面的函数, 而且跳过初始化__init__方法, 但是发现train_mode(),test_mode(),valid_mode()方法中的input都只是单纯的索引, 而非索引的批数据, 而LogisticRegression中的输入是真实数据而非索引, 那么是怎么将索引变成索引的数据而输入到LogisticRegression中的呢?答案就是通过初始化classifier=LogisticRegression(input=x,n_in=28*28,n_out=10), 这就实现了theano.fuction中的givens中得到的数据x被传递给分类器了.

  • 然而调用训练好的模型去做分类并不需要初始化, 那么我们也就不能使用批索引作为初始化inputs, 因为我们没法从givens中将数据丢给x, 进而丢入到分类器中. 说白了也就是不能进行如下操作

    
    #由于缺少分类器初始化函数, 而无法将index得到的x丢入到分类器
    
    
    #无法通过此代码实现批量准确率测试
    
    test_mode=theano.function(inputs=[index],
                               outputs=classifier.errors(y),
                               givens={
                                   x:test_set_x[index*batch_size:(index+1)*batch_size],
                                   y:test_set_y[index*batch_size:(index+1)*batch_size]

解决方法:

  • 第一个方法是对于python学得好的同学, 如果知道如何对classifier进行初始化以后, 再将权重替换给classifier即可, 操作类似下面几句话

    classifier=LogisticRegression(input=x,n_in=28*28,n_out=10)
    classifier.W=训练好的模型.W
    classifier.b=训练好的模型.b

    当然我尝试过上面这个代码, 暂时无法赋值成功, 可能我python功底不到家o(╯□╰)o(更新日志:之前忘记要使用set_value()才能设置共享参数的值了, 这里就不试了,在下一章节多层感知器可能会用到这个方法)

  • 第二个方法就是不使用LogisticRegression中的errors函数, 直接使用y_pred预测一下, 然后再与真实标签作对比, 类似于这样

    
    ## 批量预测准确率
    
    
    #读取数据集
    
    dataset='mnist.pkl.gz'
    datasets=load_data(dataset)
    test_set_x,test_set_y=datasets[2]
    test_set_x=test_set_x.get_value()
    
    #定义存储图片和标签的容器
    
    x=T.matrix('x')
    y=T.ivector('y')
    
    #定义批量测试参数
    
    batch_size=1000
    n_test_batches=test_set_x.shape[0]//batch_size
    
    #度量准确率的函数
    
    a=T.ivector('a')
    b=T.ivector('b')
    z=T.mean(T.eq(a,b))
    true_per=theano.function([a,b],z)
    
    #读取模型
    
    classifier=cPickle.load(open('best_model.pkl'))
    
    #编译函数
    
    predict_model=theano.function(inputs=[classifier.input],outputs=classifier.y_pred)
    
    #预测值
    
    for index in range(n_test_batches):
      x=test_set_x[index*batch_size:(index+1)*batch_size]
      y=test_set_y[index*batch_size:(index+1)*batch_size]
      predicted_values=predict_model(x)
      predicted_values=predicted_values.astype(np.int32)
      correct=true_per(predicted_values,y.eval())
      print 'Batch %d\'s correct ratio is %f %%' %(index,correct)

    结果

    Batch 0's correct ratio is 0.911000 %
    Batch 1's correct ratio is 0.886000 %
    Batch 2's correct ratio is 0.902000 %
    Batch 3's correct ratio is 0.905000 %
    Batch 4's correct ratio is 0.902000 %
    Batch 5's correct ratio is 0.941000 %
    Batch 6's correct ratio is 0.937000 %
    Batch 7's correct ratio is 0.955000 %
    Batch 8's correct ratio is 0.972000 %
    Batch 9's correct ratio is 0.914000 %

    我后来又想, 上述代码如果使用for循环去调用LogisticRegressionerrors()函数去获取每批数据的准确率, 那就在每个for循环中使用一次

    predict_model=theano.function(inputs=[classifier.input],
                               outputs=classifier.errors(test_set_y))

    每次的inputs是不同的数据, 后来发现这个test_set_y的标签并不是作为errors函数的输入, 也就是说即使你丢入一批数据, 这个标签也是一股脑丢进去的, 那么就会出现预测标签维度(批数据集)小于真实输入的标签维度(整个数据集), 那么解决方法就是改errors这个函数, 但是改完也得折腾训练之类的, 所以我就偷懒不改了, 此想法作废

    但是呢, 如果你的测试集不大, 那就一股脑全丢进去, 代码更简单,直接使用error函数

    
    #读取数据集
    
    dataset='mnist.pkl.gz'
    datasets=load_data(dataset)
    test_set_x,test_set_y=datasets[2]
    test_set_x=test_set_x.get_value()
    
    #读取模型
    
    classifier=cPickle.load(open('best_model.pkl'))
    
    #编译函数
    
    predict_model=theano.function(inputs=[classifier.input],
                               outputs=classifier.errors(test_set_y))
    
    #预测值
    
    predicted_values=predict_model(test_set_x)
    print "Predicted values for the first 10 examples in test set:",1-predicted_values
    
    #Predicted values for the first 10 examples in test set: 0.9225
    

单样本测试

  • 如果是测试集的某个样本想取出来看看, 比如想看看第九个样本的预测值和真实值, 那么

    
    #读取数据集
    
    dataset='mnist.pkl.gz'
    datasets=load_data(dataset)
    test_set_x,test_set_y=datasets[2]
    test_set_x=test_set_x.get_value()
    
    #读取模型
    
    classifier=cPickle.load(open('best_model.pkl'))
    
    #编译函数
    
    predict_model=theano.function(inputs=[classifier.input],
                               outputs=classifier.y_pred)
    
    #预测值
    
    predicted_values=predict_model(test_set_x[9:10])
    print "Predicted values for the first 10 examples in test set:",predicted_values
    print "Real values for the first 10 examples in test set:",test_set_y[9:10].eval()
    '''
    Predicted values for the first 10 examples in test set: [9]
    Real values for the first 10 examples in test set: [9]
    '''

    当然上述代码如果改成predicted_values=predict_model(test_set_x[:10])就是测试前十个样本0-9的标签啦.

  • 如果是自己的手写数字图片, 比如我之前写的博客【caffe-Windows】mnist实例编译之model的使用-classification中的手写数字样本(密码bead).
    吸取当时在caffe中识别自己手写数字的教训, 我们需要核对的有:

    1. 读取通道顺序(高*宽*通道?宽*高*通道?)
    2. 数据是否需要被归一化(一般都是一定的)
    3. 因为是pythontheano, 数据类型和维度一定要统一

    第一点就不说了(因为我不想测试, 咳咳); 第二点归一化(除以255), 第三点要注意了:

    dataset='mnist.pkl.gz'
    datasets=load_data(dataset)
    test_set_x,test_set_y=datasets[2]
    test_set_x=test_set_x.get_value()
    print test_set_x[9:10].dtype#float32
    print type(test_set_x[9:10])#<type 'numpy.ndarray'>
    print test_set_x[9:10].shape#(1L, 784L)

    好了,接下来我们把自己的图片转换一波, 在丢到模型中

    from skimage import io
    import matplotlib.pyplot as plt
    mnist_data=io.imread('./binarybmp/3.bmp')
    io.imshow(mnist_data)
    plt.show()
    
    img_mnist=np.array([mnist_data.reshape(1,28*28)],dtype=np.float32)
    img_mnist=img_mnist/255.0
    
    
    #读取模型
    
    classifier=cPickle.load(open('best_model.pkl'))
    
    #编译函数
    
    predict_model=theano.function(inputs=[classifier.input],
                               outputs=classifier.y_pred)
    print img_mnist[0].dtype
    
    #预测值
    
    predicted_values=predict_model(img_mnist[0])
    print "Predicted values for our example:",predicted_values

    这里写图片描述

结语

本次学习遇到的主要有有很多, 前面的坑忘记了,但是最后一个坑是如何取出theano.tensor.cast后的数据, 答案是eval()函数, 还有一个坑是经常出现这个错误:

theano.compile.function_module.UnusedInputError: theano.function was asked to create a function computing outputs given certain inputs, but the provided input variable at index 0 is not part of the computational graph needed to compute the outputs: X. To make this error into a warning, you can pass the parameter on_unused_input='warn' to theano.function. To disable it completely, use on_unused_input='ignore'.

虽然我忘记怎么解决的了,但是一定是代码出现了错误, 不要尝试什么warn或者ignore之类的, 仔细核对代码, 认真查看每一个变量类型即可, 说的轻松, 这个softmax手写数字折腾了 我两天, 好怀念matlab咳咳咳咳咳咳

因为是初学python, 而且theano也是现学现卖, 所以整个历程很可能出现各种错误, 希望看到这篇学习记录的新手多多一起讨论,也希望大佬们可以多多给出建议, 非常感谢.

code地址:

官方代码:链接: https://pan.baidu.com/s/1jIL0M0m 密码: h2dt

本文代码:链接: https://pan.baidu.com/s/1catYii 密码: 693s

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值