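# A hand-written neural network that can learn the XOR logic function. After copying it locally it can be trained on a CPU.
# The randomly initialised weights affect whether training succeeds, so run it several times when testing.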

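# Learning XOR through a two-layer network using gradient descent.
# The initial values of w strongly affect both whether training succeeds and how fast it runs.
# The original note attributes this to the "infinite model" (presumably: the weights are unconstrained).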
import math
import random
import numpy as np
from datetime import datetime

# random.seed(3)  # uncomment for a reproducible run
# Seed from the wall clock so every run starts from different weights.
# random.seed() only accepts None, int, float, str, bytes or bytearray
# (a datetime object raises TypeError on Python 3.11+), so pass the
# timestamp as a float.
random.seed(datetime.now().timestamp())

# Weights already handed out by random_weight(); used to keep them distinct.
weightSet = set()


def random_weight():
    """Return a random non-zero weight that has not been returned before.

    Candidates are the tenths -1.0, -0.9, ..., 1.0 with 0 excluded, so at
    most 20 distinct values exist. Every returned value is recorded in the
    module-level ``weightSet``.

    Raises:
        RuntimeError: if every candidate is already in ``weightSet``
            (without this check the retry loop would never terminate).
    """
    if all(k / 10 in weightSet for k in range(-10, 11) if k != 0):
        raise RuntimeError(
            "random_weight: all distinct non-zero weights are already in use")
    while True:
        w = random.randint(-10, 10) / 10
        # Retry on zero or on a value that was already handed out.
        if w == 0 or w in weightSet:
            continue
        weightSet.add(w)
        return w


class Weights:
    """Flat container for the nine network parameters.

    The parameters live in one list; ``indexMap`` translates a parameter
    name into its position in that list.
    """

    # Parameter name -> position in the weights list.
    indexMap = {
        name: position
        for position, name in enumerate(
            ("b1", "w11", "w12", "b2", "w21", "w22", "c", "v1", "v2")
        )
    }
    # Class-level placeholder; every instance replaces it in __init__.
    weights = []

    def __init__(self):
        # Same order as indexMap: b1, w11, w12, b2, w21, w22, c, v1, v2.
        self.weights = [-1.5, 1, 1, 0.5, -1, -1, 0.5, -1, -1]

    def getIndex(self, name):
        """Return the list position of the parameter called ``name``."""
        position = self.indexMap[name]
        return position

    def get(self, index):
        """Return the value stored at position ``index``."""
        values = self.weights
        return values[index]

    def getWeights(self):
        """Return the underlying list itself (not a copy)."""
        return self.weights

    def update(self, index, w):
        """Overwrite the value at position ``index`` with ``w``."""
        values = self.weights
        values[index] = w


# Step size applied to every gradient during a weight update.
η = 0.05

# Create the container, then overwrite each slot with a distinct random value.
weights = Weights()
for i in range(len(weights.weights)):
    weights.update(i, random_weight())

# Order of the entries: "b1", "w11", "w12", "b2", "w21", "w22", "c", "v1", "v2"
# Alternative starting points kept for experiments:
# weights.weights = [-1.5, 1, 1, 0.5, -1, -1, 0.5, -1, -1]  # target
# target:
# weights.weights = [-3.6256294452213225, 2.472015878775767, 2.472015878775767, 1.4097549175948436, -3.0465380385550866,
#                    -3.0465380385550866, -3.7807463009027975, -4.221045218825591, -4.194984747863033]

# weights.weights = [0.5,-1,1,0.5,-1,-1,1.5,-1,-1]
# weights.weights = [-1.0, 0, 1, 0.5, -1, -1, 0.5, -1, -1]
# weights.weights = [0.0, 0.0, 0.0,  0.0, 0.0,  0.0, 0.0, 0.0, 0.0] # must not be initialised to 0

# The four XOR input pairs and, position by position, their expected outputs.
inputs = [[1, 1], [1, 0], [0, 1], [0, 0]]
ts = [0, 1, 1, 0]


def cost():
    """Placeholder with no implementation; always returns None."""
    return None


λ = 0.001  # weight decay λ


def pW(weights, λ):
    """Return the weight-decay penalty: λ times the sum of squared weights, halved.

    ``weights`` must expose the parameter list as ``weights.weights``.
    """
    squared_total = 0
    for value in weights.weights:
        squared_total += value * value
    return (squared_total * λ) / 2


def sigmoid(num):
    """Return the logistic sigmoid 1 / (1 + e**-num).

    For very negative ``num`` the term ``math.exp(-num)`` exceeds the float
    range and raises OverflowError; the mathematical limit there is 0, so
    0.0 is returned instead of letting the exception escape.
    """
    try:
        return 1 / (1 + math.exp(-num))
    except OverflowError:
        return 0.0


def tanh(num) -> float:
    """Return the hyperbolic tangent of ``num``.

    Delegates to ``math.tanh``, which saturates at -1.0 / 1.0 for large
    magnitudes. The hand-written form ``2 / (1 + exp(-2 * num)) - 1``
    raised OverflowError for very negative inputs.
    """
    return math.tanh(num)


def step(num):
    """Threshold function: 1 for strictly positive ``num``, otherwise 0."""
    return 1 if num > 0 else 0


def activation(num):
    """Output-layer activation; currently the logistic sigmoid.

    A hard threshold (1 if num > 0 else 0) was tried here earlier and
    replaced by the sigmoid.
    """
    activated = sigmoid(num)
    return activated


def calculate_s(inputs, weights):
    """Return the weighted sum ``weights[0] + sum(inputs[i] * weights[i + 1])``.

    ``weights[0]`` acts as the bias; the remaining entries pair up with
    ``inputs`` position by position.
    """
    total = weights[0]
    for position, value in enumerate(inputs):
        total += value * weights[position + 1]
    return total


def c_z(x1, x2, weights):
    """Forward pass: return the network output for the input pair (x1, x2).

    Two tanh hidden units feed one output unit that goes through
    ``activation``. ``weights`` must provide ``getWeights()`` returning the
    nine parameters in the order b1, w11, w12, b2, w21, w22, c, v1, v2.
    """
    b1, w11, w12, b2, w21, w22, c, v1, v2 = weights.getWeights()

    hidden1 = tanh(x1 * w11 + x2 * w12 + b1)
    hidden2 = tanh(x1 * w21 + x2 * w22 + b2)
    return activation(hidden1 * v1 + hidden2 * v2 + c)


def c_z_step(x1, x2, weights):
    """Forward pass with hard thresholds instead of smooth activations.

    Every unit outputs 1 when its weighted sum is strictly positive and 0
    otherwise. ``weights`` must provide ``getWeights()`` returning the nine
    parameters in the order b1, w11, w12, b2, w21, w22, c, v1, v2.
    """
    b1, w11, w12, b2, w21, w22, c, v1, v2 = weights.getWeights()

    hidden1 = 1 if (x1 * w11 + x2 * w12 + b1) > 0 else 0
    hidden2 = 1 if (x1 * w21 + x2 * w22 + b2) > 0 else 0
    output_sum = hidden1 * v1 + hidden2 * v2 + c
    return 1 if output_sum > 0 else 0


def calculate_dE_dz(t, z, weights):
    """Return the error derivative with respect to the output: ``z - t``.

    ``weights`` is accepted but not used; an earlier experiment added a
    weight-decay term (λ * sum of weights) here.
    """
    return z - t


def _output_delta(t, x1, x2, weights):
    """Run the forward pass for one sample and return the output-layer delta.

    Returns a tuple ``(dE_ds, y1, y2, v1, v2)`` where ``dE_ds`` is the error
    derivative with respect to the output unit's pre-activation sum,
    ``y1``/``y2`` are the hidden activations and ``v1``/``v2`` the
    hidden-to-output weights.
    """
    b1, w11, w12, b2, w21, w22, c, v1, v2 = weights.getWeights()

    y1 = tanh(x1 * w11 + x2 * w12 + b1)
    y2 = tanh(x1 * w21 + x2 * w22 + b2)
    z = activation(y1 * v1 + y2 * v2 + c)
    dE_dz = calculate_dE_dz(t, z, weights)
    # z * (1 - z) is the derivative of the logistic sigmoid, so this is only
    # valid while activation() is the sigmoid.
    dE_ds = dE_dz * (z * (1 - z))
    return dE_ds, y1, y2, v1, v2


def _hidden_delta(dE_ds, v, y):
    """Back-propagate the output delta through one tanh hidden unit.

    ``v`` is the unit's outgoing weight and ``y`` its activation;
    ``1 - y ** 2`` is the derivative of tanh expressed through its output.
    """
    return (dE_ds * v) * (1 - y ** 2)


def calculatePartialDerivatives_c(t, x1, x2, weights):
    """Return dE/dc (output bias) for one sample with target ``t``."""
    dE_ds, _, _, _, _ = _output_delta(t, x1, x2, weights)
    return dE_ds


def calculatePartialDerivatives_v1(t, x1, x2, weights):
    """Return dE/dv1 (weight from hidden unit 1 to the output) for one sample."""
    dE_ds, y1, _, _, _ = _output_delta(t, x1, x2, weights)
    return dE_ds * y1


def calculatePartialDerivatives_v2(t, x1, x2, weights):
    """Return dE/dv2 (weight from hidden unit 2 to the output) for one sample."""
    dE_ds, _, y2, _, _ = _output_delta(t, x1, x2, weights)
    return dE_ds * y2


def calculatePartialDerivatives_b1(t, x1, x2, weights):
    """Return dE/db1 (bias of hidden unit 1) for one sample."""
    dE_ds, y1, _, v1, _ = _output_delta(t, x1, x2, weights)
    # du1/db1 is 1, so the hidden delta is already the gradient.
    return _hidden_delta(dE_ds, v1, y1)


def calculatePartialDerivatives_w11(t, x1, x2, weights):
    """Return dE/dw11 (weight from x1 to hidden unit 1) for one sample."""
    dE_ds, y1, _, v1, _ = _output_delta(t, x1, x2, weights)
    return _hidden_delta(dE_ds, v1, y1) * x1


def calculatePartialDerivatives_w12(t, x1, x2, weights):
    """Return dE/dw12 (weight from x2 to hidden unit 1) for one sample."""
    dE_ds, y1, _, v1, _ = _output_delta(t, x1, x2, weights)
    return _hidden_delta(dE_ds, v1, y1) * x2


def calculatePartialDerivatives_b2(t, x1, x2, weights):
    """Return dE/db2 (bias of hidden unit 2) for one sample."""
    dE_ds, _, y2, _, v2 = _output_delta(t, x1, x2, weights)
    # du2/db2 is 1, so the hidden delta is already the gradient.
    return _hidden_delta(dE_ds, v2, y2)


def calculatePartialDerivatives_w21(t, x1, x2, weights):
    """Return dE/dw21 (weight from x1 to hidden unit 2) for one sample."""
    dE_ds, _, y2, _, v2 = _output_delta(t, x1, x2, weights)
    return _hidden_delta(dE_ds, v2, y2) * x1


def calculatePartialDerivatives_w22(t, x1, x2, weights):
    """Return dE/dw22 (weight from x2 to hidden unit 2) for one sample."""
    dE_ds, _, y2, _, v2 = _output_delta(t, x1, x2, weights)
    return _hidden_delta(dE_ds, v2, y2) * x2


def calculatePartialDerivatives_sum(inputs, weights, calculatePartialDerivatives_single, targets=None):
    """Sum a per-sample partial derivative over the whole training set.

    Args:
        inputs: sequence of samples; positions 0 and 1 of each sample are
            passed on as ``x1`` and ``x2``.
        weights: passed through unchanged to the per-sample function.
        calculatePartialDerivatives_single: callable ``(t, x1, x2, weights)``
            returning the derivative for one sample.
        targets: expected output for each sample, aligned with ``inputs``.
            Defaults to the module-level ``ts`` so existing three-argument
            calls behave as before.

    Returns:
        The sum of the per-sample derivatives (0 for empty ``inputs``).
    """
    if targets is None:
        targets = ts
    total = 0
    for position, sample in enumerate(inputs):
        t = targets[position]
        total = total + calculatePartialDerivatives_single(t, sample[0], sample[1], weights)
    return total


# Per-sample gradient functions, stored at the same position the matching
# parameter occupies in Weights.indexMap.
calulatePartialDerivativesFunctions = [
    calculatePartialDerivatives_b1,   # 0: b1
    calculatePartialDerivatives_w11,  # 1: w11
    calculatePartialDerivatives_w12,  # 2: w12
    calculatePartialDerivatives_b2,   # 3: b2
    calculatePartialDerivatives_w21,  # 4: w21
    calculatePartialDerivatives_w22,  # 5: w22
    calculatePartialDerivatives_c,    # 6: c
    calculatePartialDerivatives_v1,   # 7: v1
    calculatePartialDerivatives_v2,   # 8: v2
]


def getCalulatePartialDerivativesFunction(index):
    """Return the gradient function stored at position ``index`` of the table."""
    gradient_function = calulatePartialDerivativesFunctions[index]
    return gradient_function


def calulateError(inputs, ts, weightIndex, w, weights):
    """Store ``w`` at ``weightIndex`` and return the resulting squared error.

    Note that ``weights`` is modified in place and stays modified after the
    call. The returned value is half the sum of squared differences between
    the network output and the target over all samples.
    """
    weights.update(weightIndex, w)
    squared_error = 0
    for position, sample in enumerate(inputs):
        t = ts[position]
        q = c_z(sample[0], sample[1], weights)
        squared_error += (q - t) ** 2
    # A KL-divergence loss and an added weight-decay term (pW) were tried
    # here as alternatives; plain squared error is what is in use.
    return squared_error / 2


epochs = 10000
# epochs = 10


def adjustWeights():
    """Train the module-level ``weights`` with batch gradient descent.

    Runs ``epochs`` passes. Each pass first computes the gradient of every
    parameter summed over all samples and only then applies the updates,
    so every gradient in a pass is taken at the same weights. The current
    weights are printed after every pass.
    """
    names = ("b1", "w11", "w12", "b2", "w21", "w22", "c", "v1", "v2")
    indices = [weights.getIndex(name) for name in names]
    gradients = [0] * 9
    for _ in range(epochs):
        # Updating after all derivatives are known, or updating each weight
        # right away, both work. Neither guarantees escaping plateaus or
        # saddle points; small initial weights probably avoid most of them.
        for index in indices:
            gradient_function = getCalulatePartialDerivativesFunction(index)
            gradients[index] = calculatePartialDerivatives_sum(inputs, weights, gradient_function)

        for index in indices:
            weights.update(index, weights.get(index) - η * gradients[index])
        print("middle:", weights.__dict__)


def check():
    """Print the network output for every training sample and the final error.

    One line per sample shows the inputs, the target, the actual output and
    the running sum of squared errors; the last line prints half of that sum.
    """
    running_error = 0
    for position, sample in enumerate(inputs):
        t = ts[position]
        x1, x2 = sample[0], sample[1]
        z = c_z(x1, x2, weights)
        running_error += (z - t) ** 2
        print("[{0},{1}] target: {2}, actual:{3}, E:{4}".format(x1, x2, t, z, running_error))
    print("E", running_error / 2)


if __name__ == "__main__":
    # Show the starting weights, train, then show the result and evaluate it.
    print("start:", weights.__dict__)
    adjustWeights()
    print("final:", weights.__dict__)
    check()

# NOTE: a bulleted list of line numbers (1-420) followed here; it was page-layout residue, not part of the program.