代码文件
# -*- coding: utf-8 -*-
import tensorflow as tf
def X_W(x):
    """Multiply ``x`` by the weight matrix shared across all calls.

    The weight is obtained with ``tf.get_variable`` under the
    ``"shared_weights"`` scope opened with ``reuse=tf.AUTO_REUSE``, so the
    first call creates the ``[392, 10]`` variable ``"W"`` and later calls
    get the same variable back instead of creating a new one.

    Args:
        x: Tensor to be right-multiplied by the ``[392, 10]`` weight; its
            last dimension must therefore be 392.

    Returns:
        The result of ``tf.matmul(x, W)``.
    """
    with tf.variable_scope("shared_weights", reuse=tf.AUTO_REUSE):
        # 392 input columns (one image half) mapped onto 10 output classes.
        shared_weight = tf.get_variable("W", [392, 10])
        return tf.matmul(x, shared_weight)
def mnist_predict(mnist, X1, X2, batch_size, n_iters):
    """Train the shared-weight classifier and return its test accuracy.

    The model is ``softmax(X_W(X1) + X_W(X2) + b)``: each image row is split
    at column 392, both parts go through ``X_W`` (which reuses one weight
    variable), the two results are summed and a bias is added.

    Args:
        mnist: Dataset object exposing ``train.next_batch(batch_size)``,
            ``test.images`` and ``test.labels``. Labels are fed into a
            ``[None, 10]`` float placeholder and compared via ``argmax``
            (presumably one-hot encoded; confirm against the caller).
        X1: Placeholder fed with the first 392 columns of every image row.
        X2: Placeholder fed with the remaining columns of every image row.
        batch_size: Number of training samples requested per step.
        n_iters: Number of optimisation steps to run.

    Returns:
        The scalar accuracy evaluated on ``mnist.test`` after training.
    """
    split_at = 392  # column where each flattened image row is cut in two

    Y = tf.placeholder(tf.float32, [None, 10])

    out1 = X_W(X1)
    out2 = X_W(X2)
    b = tf.Variable(tf.zeros([10]))

    # Keep the raw (unnormalised) scores separate from the probabilities:
    # softmax_cross_entropy_with_logits_v2 applies softmax internally, so
    # handing it already-softmaxed values would normalise twice and yield a
    # wrong loss with badly flattened gradients.
    logits = tf.add(out1, out2) + b
    probabilities = tf.nn.softmax(logits)

    # Loss and optimizer.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=logits)
    )
    optimizer = tf.train.AdamOptimizer().minimize(loss)

    # Accuracy: fraction of rows whose predicted class matches the label.
    correct_prediction = tf.equal(
        tf.argmax(probabilities, 1), tf.argmax(Y, 1)
    )
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Training.
        for _ in range(n_iters):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(
                optimizer,
                feed_dict={
                    X1: batch_xs[:, :split_at],
                    X2: batch_xs[:, split_at:],
                    Y: batch_ys,
                },
            )

        # Testing.
        test_images = mnist.test.images
        test_accuracy = sess.run(
            accuracy,
            feed_dict={
                X1: test_images[:, :split_at],
                X2: test_images[:, split_at:],
                Y: mnist.test.labels,
            },
        )
    return test_accuracy
题目描述
任务描述
本关任务:设计变量共享网络进行 MNIST 分类。
相关知识
为了完成本关任务,你需要掌握:1.网络结构。
网络结构
该网络将图片样本分为上下两半 X1、X2,分别送入 input1、input2。后续两个路径的线性加权模块 XW = X∗W 共享同一个变量(name='w')。整个分类模型可描述为:
softmax(XW(X1) + XW(X2) + b)
编程要求
1. 线性加权模块 XW 需定义为一个函数,并在此函数中创建并共享变量 W(name='w')。函数 X_W(X) 只有一个输入参数 X,W 必须在 X_W(X) 中用 get_variable 定义:
def X_W(X):
...
return tf.matmul(X,W)
2. 训练后精度大于 0.7。
测试说明
程序会调用你实现的方法对 MNIST 数据进行训练并预测,正确率大于 0.7 则视为通过。
开始你的任务吧,祝你成功!