感知机学习:鸢尾花二分类

感知机二分类模型:𝑓(𝑥)=sign(𝑤⋅𝑥+𝑏)
最小化损失函数:
在这里插入图片描述

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

加载数据集:

iris = load_iris()

iris:
‘target_names’: array([‘setosa’, ‘versicolor’, ‘virginica’]
‘feature_names’: [‘sepal length (cm)’, ‘sepal width (cm)’, ‘petal length (cm)’, ‘petal width (cm)’]

将数据以表格的形式展示:

df = pd.DataFrame(iris.data, columns=['sepal length', 'sepal width', 'petal length', 'petal width'])

在这里插入图片描述

加上标签:

df['label'] = iris.target

在这里插入图片描述
坐标图展示两类鸢尾花:

plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'],label='1')
plt.xlabel("sepal length")
plt.ylabel("sepal width")
plt.legend()#显示图例

在这里插入图片描述

提取数据:

data = np.array(df.iloc[:100, [0,1,-1]])#提取前100行,第0,1 ,最后一列的数据
X, y = data[:, :-1], data[:, -1]#x取第0,1列的数据,y取最后一列的数据
y = np.array([1 if i == 1 else -1 for i in y])#将y的标签设置为1或者-1

感知机模型训练:

class Model:
    def __init__(self):
        self.w = np.ones(len(data[0])-1, dtype=np.float32)
        self.b = 0
        self.rate = 0.1
    
    def sign(self, x, w, b):
        y = np.dot(x, w) + b
        return y
    
    def fit(self, X_train, y_train):
        is_wrong = False
        while not is_wrong:
            wrong_count = 0
            for d in range(len(X_train)):
                x = X_train[d]
                y = y_train[d]
                if y * self.sign(x, self.w, self.b) <=0 :
                    self.w = self.w + self.rate * np.dot(y,x)
                    self.b = self.b + self.rate * y
                    wrong_count +=1
            if wrong_count == 0 :
                is_wrong = True
        return "success"

结果:

perceptron = Model()
perceptron.fit(X, y)

x = np.linspace(4,7,10)
y = -(perceptron.w[0] * x + perceptron.b) / perceptron.w[1]
plt.plot(x,y)

plt.plot(data[:50,0],data[:50,1], 'bo', color="blue", label="0")
plt.plot(data[50:100,0],data[50:100,1], 'bo',color="red", label="1")
plt.xlabel("sepal length")
plt.ylabel("sepal width")
plt.legend()

在这里插入图片描述

DataFrame

iloc

代码参考自

  • 5
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值