(openCV 九)ML_ANN

import cv2
import numpy as np
from random import randint

# 通过调用OpenCV函数创建ANN
animals_net = cv2.ml.ANN_MLP_create()

# ANN_MLP_RPROP和ANN_MLP_BACKPROP都是反向传播算法,此处设置相应的拓扑结构
animals_net.setLayerSizes(np.array([3, 6, 4]))
animals_net.setTrainMethod(cv2.ml.ANN_MLP_RPROP | cv2.ml.ANN_MLP_UPDATE_WEIGHTS)
animals_net.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)

# 指定ANN的终止条件
animals_net.setTermCriteria((cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 1))

"""Input arrays
weight, length, teeth
"""

"""Output arrays
dog, eagle, dolphin and dragon
"""


def dog_sample():
    return [randint(10, 20), 1, randint(38, 42)]


def dog_class():
    return [1, 0, 0, 0]


def condor_sample():
    return [randint(3, 10), randint(3, 5), 0]


def condor_class():
    return [0, 1, 0, 0]


def dolphin_sample():
    return [randint(30, 190), randint(5, 15), randint(80, 100)]


def dolphin_class():
    return [0, 0, 1, 0]


def dragon_sample():
    return [randint(1200, 1800), randint(30, 40), randint(160, 180)]


def dragon_class():
    return [0, 0, 0, 1]


def record(sample, classification):
    return (np.array([sample], dtype=np.float32), np.array([classification], dtype=np.float32))


'''
为了提高精度,大多数ANN会迭代多个周期;一些常见的ANN示例,会对数据进行数百次迭代。
'''
records = []

RECORDS = 5000
for x in range(0, RECORDS):
    records.append(record(dog_sample(), dog_class()))
    records.append(record(condor_sample(), condor_class()))
    records.append(record(dolphin_sample(), dolphin_class()))
    records.append(record(dragon_sample(), dragon_class()))

EPOCHS = 2
for e in range(0, EPOCHS):
    print("Epoch %d:" % e)
    for t, c in records:
        animals_net.train(t, cv2.ml.ROW_SAMPLE, c)

TESTS = 100
dog_results = 0
for x in range(0, TESTS):
    clas = int(animals_net.predict(np.array([dog_sample()], dtype=np.float32))[0])
    print("class: %d" % clas)
    if (clas) == 0:
        dog_results += 1

condor_results = 0
for x in range(0, TESTS):
    clas = int(animals_net.predict(np.array([condor_sample()], dtype=np.float32))[0])
    print("class: %d" % clas)
    if (clas) == 1:
        condor_results += 1

dolphin_results = 0
for x in range(0, TESTS):
    clas = int(animals_net.predict(np.array([dolphin_sample()], dtype=np.float32))[0])
    print("class: %d" % clas)
    if (clas) == 2:
        dolphin_results += 1

dragon_results = 0
for x in range(0, TESTS):
    clas = int(animals_net.predict(np.array([dragon_sample()], dtype=np.float32))[0])
    print("class: %d" % clas)
    if (clas) == 3:
        dragon_results += 1

print("Dog accuracy: %f%%" % (dog_results))
print("condor accuracy: %f%%" % (condor_results))
print("dolphin accuracy: %f%%" % (dolphin_results))
print("dragon accuracy: %f%%" % (dragon_results))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值