TensorFlow学习之二:kaggle上minist手写识别CNN优化

在上次提交的代码后,基于上次的结果采用CNN神经网络,将准确率提高了一个等级,下面附上训练代码和对应的预测代码,这次在训练中保存模型后,在预测过程中读取模型的时候采用了加载整个神经网络模型的方法。期间也遇到了很多报错 。经过慢慢分析,一一解决了。

训练代码,我这里训练只迭代了10次,实际中大家可以根据实际来训练,因为我觉得用个人的笔记本训练太慢了,就没训练太久,最后训练结果如下:

实现代码如下:

# !/usr/bin/env python3

import tensorflow as tf
import pandas as pd
import numpy as np

def read_data(filename):#读取数据
    data=pd.read_csv(filename)
    return data

def handle_data(data):
    y_data=data['label'].values.ravel() #获取标签数据
    data.drop(labels='label',axis=1,inplace=True) #Image 数据
    return data,y_data
def train_val_split(x_data,y_data):
    large=x_data.shape[0]
    print(large)
    x_train=x_data.iloc[:large-200,].div(255.0)#由于数据值范围在0-255,部分值差异太大,故进行0-1标准化
    y_train=y_data[:large-200,].astype(np.float32) #需要保证数据类型一致性
    x_val=x_data.iloc[large-200:,].div(255.0)#由于数据值范围在0-255,部分值差异太大,故进行0-1标准化,此为验证Images数据,用来验证后面的模型的准确率
    y_val=y_data[large-200:,].astype(np.float32)#此为Label数据,用来验证后面模型的准确率
    return x_train,y_train,x_val,y_val
#one_hot编码
def one_hot(data):
    num_class=len(np.unique(data))#获取label的个数,这里我们的手写识别数字范围是0~9,所以num_class=10
    print(num_class)
    num_lables=data.shape[0]
    index_offset=np.arange(num_lables)*num_class
    lables_one_hot=np.zeros((num_lables,num_class))
    print(data.ravel())
    lables_one_hot.flat[index_offset+data.ravel()]=1
    return lables_one_hot

def train_model(x_train,y_train,x_val,y_val,n): #训练模型并保存模型 此处模型用的softmax回归模型训练y=w*x+b
    x=tf.placeholder("float",[None,784])
    w=tf.Variable(tf.zeros([784,10]),name='w')
    b=tf.Variable(tf.zeros([1,10]),name='b')  #在这里的时候需要保证矩阵的维度在进行 y=x*w+b后直接都是一致的,否则会报错 这里维度为[none,10]=[none,784]*[784,10]+[1,10]
    y=tf.nn.softmax(tf.matmul(x,w)+b) #定义模softmax 函数 这里需要注意我们在模型训练的时候y值存储的是0,1值,比如如果label为5,则在实际中的标识为[0,0,0,0,1,0,0,0,0,0],softmax 激活函数通常用在分类问题中
    y_=tf.placeholder("float",[None,10])
    cross_entropy=-tf.red
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值