机器学习与深度学习对鸢尾花数据集分类对比

1.鸢尾花数据集介绍

Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据样本,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。

探索性分析结果:

2.完整程序

import seaborn as sns
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV
from lightgbm.sklearn import LGBMClassifier 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Activation
from tensorflow.keras import utils
import os

# 获取数据集
iris =sns.load_dataset('iris')
iris.head()

# seaborn中的pairplot探索数据的关系
sns.pairplot(iris,hue='species')

# 划分数据集,读取特征值和目标值
X=iris.values[:,:4]
y=iris.values[:,4]
# 划分数据集,训练和测试数据集
train_X,test_X,train_y,test_y=train_test_split(X,y,train_size=0.5,random_state=0)
# 查看数据集规模
train_X.shape,X.shape

# sklearn机器学习模型实例化,这里是逻辑回归和lightGBM
# 实例化模型
lr=LogisticRegressionCV()
gbm=LGBMClassifier() 
# 模型训练
lr.fit(train_X,train_y)
gbm.fit(train_X,train_y)



#############################深度学习
# 是否采用GPU看个人,此处用不用也无所谓
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  ##表示使用GPU编号为0的GPU进行计算
os.environ["CUDA_VISIBLE_DEVICES"]="0"  ##表示使用GPU编号为0的GPU进行计算

# 目标的one-hot编码
def one_hot_encode(arr):
    # 获取目标中的所有列别并进行编码
    uniques,ids=np.unique(arr,return_inverse=True)
    return utils.to_categorical(ids,len(uniques))

# 对目标进行编码
train_y_ohe=one_hot_encode(train_y)
test_y_ohe=one_hot_encode(test_y)

# 模型构建
model=Sequential([
    # 隐藏层
    Dense(10,activation='relu',input_shape=(4,)),
    # 影藏层
    Dense(10,activation='relu'),
    # 输出层
    Dense(3,activation='softmax')
])

# 模型编译
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
# 类型转换
train_X=np.array(train_X,dtype=np.float32)
test_X=np.array(test_X,dtype=np.float32)

# 模型训练
model.fit(train_X,train_y_ohe,epochs=10,batch_size=1,verbose=1)
# 模型评估
loss,accuracy=model.evaluate(test_X,test_y_ohe,verbose=1)
print('深度学习准确率',accuracy)
print('逻辑回归准确率',lr.score(test_X,test_y),gbm.score(test_X,test_y))
print('lightGBM准确率',gbm.score(test_X,test_y),gbm.score(test_X,test_y))

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值