Tensorflow入门(一)mnist手写数字识别

环境搭建:linux+anaconda/jupyter+tensorflow

方法:安装anaconda后直接
pip install tensorflow 即可

踩过的坑:
不要用虚拟环境。原因:使用虚拟环境后,anaconda和tensorflow不在一个环境中,jupyter无法import。(虽然网上有在虚拟环境中重新安装ipython和jupyter的解决办法,但是我尝试多次都由于速度非常慢而安装失败,更换镜像源仍然没有成功,有资料说可能是防火墙的问题,但自己缺乏linux命令基础没有尝试,如果有朋友解决了欢迎交流。)

我的第一个tensorflow模型:mnist手写数字识别,准确率:95%

参考:简单粗暴Tensorflow https://tf.wiki/zh/basic.html

class DataLoader():
    def __init__(self):
        mnist=tf.contrib.learn.datasets.load_dataset("mnist")
        self.train_data=mnist.train.images
        self.train_labels=np.asarray(mnist.train.labels,dtype=np.int32)
        self.eval_data=mnist.test.images
        self.eval_labels=np.asarray(mnist.test.labels,dtype=np.int32)
    
    def get_batch(self,batch_size):
        index=np.random.randint(0,np.shape(self.train_data)[0],batch_size)
        return self.train_data[index,:],self.train_labels[index]

    
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1=tf.keras.layers.Dense(units=100,activation=tf.nn.relu)
        self.dense2=tf.keras.layers.Dense(units=10)
        
    def call(self,inputs):
        x1=self.dense1(inputs)
        out=self.dense2(x1)
        return out
    
    def predict(self,inputs):#???????
        logits=self(inputs)
        return tf.argmax(logits,axis=-1)

    
learning_rate=0.001
num_batches=1000
batch_size=50

model=MLP()
data_loader=DataLoader()
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate)

for batch_index in range(num_batches):
    X,y=data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_logit_pred=model(tf.convert_to_tensor(X))
        loss=tf.losses.sparse_softmax_cross_entropy(labels=y,logits=y_logit_pred)
        print("batch%d:loss%f"%(batch_index,loss.numpy()))
    grads=tape.gradient(loss,model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads,model.variables))

num_eval_samples=np.shape(data_loader.eval_labels)[0]
y_pred=model.predict(data_loader.eval_data).numpy()
print("test accuracy:%f"%(sum(y_pred==data_loader.eval_labels)/num_eval_samples))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值