多分类介绍
多分类由二分类问题推广而来,我们可以把N分类问题分解为N个2分类问题。下面我们用代码实现一个简单三分类问题,其中y为n行3列的矩阵,其中0表示不属于该类1表示属于该类。代码中用到的矩阵乘法,不会的同学自行补课。
代码实现
# coding=utf-8
import random
import matplotlib.pyplot as plt
import numpy as np
x, y = [], []
x_test1, x_test2, x_test3 = [], [], []
# 随机生成3种不同分类的点,分别打上标签存在y中
for i in range(0, 20):
x1 = random.random()
x2 = random.random()
if x1 + x2 < 1:
x.append([x1, x2, 1])
x_test1.append([x1, x2])
y.append([1, 0, 0])
x.append([x1 * 2, x2 + 1,