广西民族大学高级人工智能课程—头歌实践教学实践平台-神经网络作业

代码文件

# -*- coding: utf-8 -*-
import tensorflow as tf

def X_W(x):
    with tf.variable_scope("shared_weights", reuse=tf.AUTO_REUSE):
        W = tf.get_variable("W", [392, 10])  # Assuming output is 10 classes
    out = tf.matmul(x, W)
    return out


def mnist_predict(mnist, X1, X2, batch_size, n_iters):
    Y = tf.placeholder(tf.float32, [None, 10])  # Assuming 10 classes for MNIST
    out1 = X_W(X1)
    out2 = X_W(X2)
    combined_out = tf.add(out1, out2)
    b = tf.Variable(tf.zeros([10]))
    logits = tf.nn.softmax(combined_out + b)

    # Define loss and optimizer
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=logits))
    optimizer = tf.train.AdamOptimizer().minimize(loss)

    # Calculate accuracy
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Start a session
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Training
        for _ in range(n_iters):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={X1: batch_xs[:, :392], X2: batch_xs[:, 392:], Y: batch_ys})

        # Testing
        test_accuracy = sess.run(accuracy, feed_dict={X1: mnist.test.images[:, :392], X2: mnist.test.images[:, 392:], Y: mnist.test.labels})

    return test_accuracy

题目描述

任务描述

本关任务:设计变量共享网络进行MNIST分类。

相关知识

为了完成本关任务,你需要掌握:1.网络结构。

网络结构

其将图片样本分为上下两半X1​,X2​;分别送入input1​,input2​。后续的两个路径的线性加权模块XW​=X∗W共享一个变量name='w'。整个分类模型可描述为:

softmax(XW​(X1​)+XW​(X2​)+b)

编程要求

1.线性加权模块XW​需定义为一个函数,在此函数中创建并共享变量W name='w'

函数X_W(X)只有一个输入参数XW必须在X_W(X)中用get_variabel定义:

 
  1. def X_W(X):
  2. ...
  3. return tf.matmul(X,W)

2.训练后精度大于0.7

测试说明

程序会调用你实现的方法对MNIST数据进行训练并预测,正确率大于0.7则视为通过。


开始你的任务吧,祝你成功!

  • 8
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值