44.一维卷积-叶子分类

我们用的数据集为train.csv,我们打开看一下

其中id是编号,species是种类,后面的margin1这些是他某一部分的特征,一共有三种大的特征类型(margin,shape,texture),三种特征都是从1-64(如margin1-margin64),特征以浮点数的形式记录了下来,我们使用后面的特征取预测前面的species

目录

1  导入库

2  整理数据

2.1  读取数据

2.2  数据的基本情况

2.3  提取训练输入与输出

2.4  分割训练集与测试集

2.5  标准化

2.6  增加维度

3  搭建网络

4  编译模型

5  训练模型

6  预测模型


1  导入库

2  整理数据

2.1  读取数据

首先我们读取数据

2.2  数据的基本情况

读取数据后我们看一下它的基本情况

一共990行,194列,格式有float64(浮点数)192列,int64(整数)1列(这个是id的那一列),object(字符串)1列(这个是species这一列),文件大小大致为1.5MB

2.3  提取训练输入与输出

我们一共有99种叶子,我们使用pd.factorize给每一个种类一个代表数字,它的运行结果是下面这样的,由于factorize是按照固定的顺序排列的,所以每次的运行结果都是一样的

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20,  4, 21, 22, 23, 24, 25, 26, 25, 15, 27, 28,  3, 29,
       30, 31, 32, 33, 34, 35,  4,  2, 36, 14, 37, 38, 39, 40,  6, 41, 42,
       43, 44, 45, 20, 27, 14, 32, 45, 34, 46, 29, 47, 48, 49, 50, 34, 51,
       17,  4, 27, 52, 53, 35, 54, 53, 40, 13,  6, 55, 56, 39, 43, 57, 58,
       59, 60, 34, 61, 62, 63, 62, 64, 62, 41, 65, 34, 49, 66, 11, 51,  6,
       67, 68, 69, 63,  4, 38, 23, 68, 23, 70, 36, 17, 56, 56,  3, 27, 71,
       46, 72, 73, 74, 52, 33, 75, 70, 76, 47, 77, 62, 78, 79, 67, 64, 77,
       32, 56, 18, 45, 12, 61, 34, 59, 41,  5, 35, 15, 20, 13, 54, 80, 81,
       22, 62, 16, 82, 37,  2,  4, 75, 69, 65, 11, 72, 31, 30, 83, 19, 38,
       66, 18, 84, 75, 80, 82, 81, 57, 30, 46,  0, 20, 45, 24, 85, 17, 59,
       65,  4, 71, 54, 82, 24, 38, 86, 15, 10,  5, 87, 56, 11, 88, 89, 90,
       59, 36, 90, 54, 91,  7, 86, 52, 33, 57, 71, 58, 41, 74, 19, 35, 21,
       86, 92, 21, 91, 75, 39, 35,  7, 48, 56, 78,  7, 84, 56, 81, 86, 52,
       69, 32, 58, 24, 48, 77,  5, 23, 51, 87, 67, 54, 48, 28,  7, 23, 93,
       35, 32, 77, 27,  1, 12, 80, 94, 15, 50, 78, 11, 20, 89, 91, 88, 26,
       12,  8, 23, 19, 43, 65,  3, 38, 87, 51, 73, 23, 67, 73, 91, 63, 69,
       30, 56, 92, 49, 95, 36, 11, 72, 89, 20, 76, 30,  8, 62, 80, 70, 72,
       27, 88, 68, 41, 40, 63, 24, 26, 90, 15, 56, 59, 16, 18, 96, 73, 90,
       83, 78, 51, 10, 27, 22, 42, 25, 25, 49, 64, 84, 32, 33, 28, 60, 39,
       19, 78,  5, 91, 89, 34, 52, 70, 58, 28, 33, 52, 63, 82, 32, 59, 66,
       24, 93,  7, 13,  6, 59, 67, 61, 69, 49, 63, 66, 16, 70, 37, 87, 77,
       22, 77,  9, 87, 33, 15,  2, 68, 74, 90, 14, 88, 69, 73, 94, 67, 58,
        8, 51, 86, 68,  3, 58, 58, 44, 19, 12, 81, 60, 26, 55, 97, 64, 96,
       83, 11, 16, 65, 98, 73, 31, 50, 96, 44, 82, 48, 31,  4, 95, 78, 85,
       47, 79, 29,  9, 82, 74, 29, 84, 69, 58, 98, 84, 86, 10, 31, 74, 27,
       95, 12, 79, 40, 95, 88, 22, 26, 21, 19, 98,  8, 26, 11,  5,  5, 31,
       95, 24, 35, 94, 37, 11, 13, 66,  1, 41, 39, 76, 89, 35, 93, 31, 81,
       15, 85, 22, 29, 89, 20, 44, 83, 21, 47, 28, 94, 27, 53, 66, 98, 26,
        2, 54, 22, 55, 51, 72, 18, 71, 57, 59, 80, 45, 76, 82, 65, 45, 50,
       89, 76,  6, 52, 90, 33, 73, 83, 68, 14, 63, 31,  9,  6, 40, 85, 97,
       36, 58, 71, 94, 71, 50, 42, 12, 40, 12, 76, 41, 43,  4, 98, 30, 83,
       50, 44, 79,  5, 25, 92, 55, 27, 97, 37, 31, 69, 18, 16, 72, 55, 43,
        8, 64, 46, 46, 91, 64, 81,  0, 29, 59, 61, 25, 38, 32, 57, 38, 75,
       97, 20, 46, 65, 43, 93, 29, 83, 88, 37, 19, 75, 50,  1, 89, 73,  3,
       36,  6, 57, 49, 16, 30, 13, 57, 40, 48, 17,  4, 62, 61,  1, 74, 39,
       19, 21, 73, 25,  9,  7, 80, 84, 55, 69, 43, 36, 70, 39, 10, 19, 85,
       66,  8, 62, 78,  0, 41, 21, 72, 22, 63, 49, 24, 87, 16, 20, 74, 62,
       30,  1, 54,  9, 14, 67, 71, 88, 23, 77, 76, 23, 67, 17, 96, 29, 76,
        5, 76, 87, 47,  3, 65, 48, 45, 74, 35,  9,  1,  9, 81,  2, 14, 17,
       93, 17, 38, 12, 54, 32,  6, 21, 11, 53, 97, 46, 82, 33, 53, 79, 30,
       39, 92, 96, 46, 86,  0,  7,  2, 59, 28, 13, 94, 10, 35, 53, 14, 77,
       75, 67, 93, 79, 20, 13,  0, 98, 92, 79, 44, 82, 36, 37, 38,  2, 42,
       84, 48, 34, 98, 39, 95, 41, 41, 66, 85, 80, 70, 52, 10, 98, 96, 15,
        8, 71, 73, 85, 83, 87, 79, 55, 83, 71, 64, 96, 49, 78, 44, 60, 84,
       93, 12, 55, 68, 46, 94, 31, 83, 68, 79, 97, 79, 67, 47, 88, 97, 84,
       86, 26, 14, 30, 54, 28, 10, 53, 98,  0, 16, 93, 82, 18, 34, 45, 89,
       80, 64, 74, 43, 55, 15, 65, 46, 22, 93, 18, 60,  3, 21, 18, 94, 51,
       13, 60, 40, 75, 51,  0, 43, 92, 42, 61, 25, 44,  1, 97, 16,  3, 92,
       97, 43, 50, 72, 81, 37, 61, 47, 22, 61, 52, 29, 97, 28, 75,  6, 49,
       47, 80,  2, 39, 17, 74, 81, 40, 90, 70, 66, 91, 80, 94, 71, 33, 95,
       14, 52, 28, 87, 77, 98, 21, 37, 92, 49, 92, 26,  1, 10, 42, 69, 42,
       57, 89, 60, 91,  8, 34, 33,  7, 68,  2, 61, 84, 85, 96, 10, 60, 54,
       94, 40,  0, 66, 36, 56, 95, 28, 26, 47, 24, 42, 53, 32,  9, 87, 60,
       53, 91, 68, 65, 44, 48, 77, 23, 25, 75, 25, 42, 88, 50, 61, 64,  0,
       88, 62,  8, 76, 96, 93, 24, 81, 90, 44,  1, 72, 58, 78, 38, 48, 70,
       37, 36, 63, 91, 96, 51, 57, 63, 85,  9,  3, 78, 42, 90, 45, 86, 17,
       92, 64, 50, 72, 85, 95, 29, 70, 18, 86, 90, 45, 60, 95, 53, 47,  5,
       57, 55,  7, 13], dtype=int64), Index(['Acer_Opalus', 'Pterocarya_Stenoptera', 'Quercus_Hartwissiana',
       'Tilia_Tomentosa', 'Quercus_Variabilis', 'Magnolia_Salicifolia',
       'Quercus_Canariensis', 'Quercus_Rubra', 'Quercus_Brantii',
       'Salix_Fragilis', 'Zelkova_Serrata', 'Betula_Austrosinensis',
       'Quercus_Pontica', 'Quercus_Afares', 'Quercus_Coccifera',
       'Fagus_Sylvatica', 'Phildelphus', 'Acer_Palmatum', 'Quercus_Pubescens',
       'Populus_Adenopoda', 'Quercus_Trojana', 'Alnus_Sieboldiana',
       'Quercus_Ilex', 'Arundinaria_Simonii', 'Acer_Platanoids',
       'Quercus_Phillyraeoides', 'Cornus_Chinensis', 'Liriodendron_Tulipifera',
       'Cytisus_Battandieri', 'Rhododendron_x_Russellianum', 'Alnus_Rubra',
       'Eucalyptus_Glaucescens', 'Cercis_Siliquastrum', 'Cotinus_Coggygria',
       'Celtis_Koraiensis', 'Quercus_Crassifolia', 'Quercus_Kewensis',
       'Cornus_Controversa', 'Quercus_Pyrenaica', 'Callicarpa_Bodinieri',
       'Quercus_Alnifolia', 'Acer_Saccharinum', 'Prunus_X_Shmittii',
       'Prunus_Avium', 'Quercus_Greggii', 'Quercus_Suber',
       'Quercus_Dolicholepis', 'Ilex_Cornuta', 'Tilia_Oliveri',
       'Quercus_Semecarpifolia', 'Quercus_Texana', 'Ginkgo_Biloba',
       'Liquidambar_Styraciflua', 'Quercus_Phellos', 'Quercus_Palustris',
       'Alnus_Maximowiczii', 'Quercus_Agrifolia', 'Acer_Pictum',
       'Acer_Rufinerve', 'Lithocarpus_Cleistocarpus',
       'Viburnum_x_Rhytidophylloides', 'Ilex_Aquifolium', 'Acer_Circinatum',
       'Quercus_Coccinea', 'Quercus_Cerris', 'Quercus_Chrysolepis',
       'Eucalyptus_Neglecta', 'Tilia_Platyphyllos', 'Alnus_Cordata',
       'Populus_Nigra', 'Acer_Capillipes', 'Magnolia_Heptapeta', 'Acer_Mono',
       'Cornus_Macrophylla', 'Crataegus_Monogyna', 'Quercus_x_Turneri',
       'Quercus_Castaneifolia', 'Lithocarpus_Edulis', 'Populus_Grandidentata',
       'Acer_Rubrum', 'Quercus_Imbricaria', 'Eucalyptus_Urnigera',
       'Quercus_Crassipes', 'Viburnum_Tinus', 'Morus_Nigra',
       'Quercus_Vulcanica', 'Alnus_Viridis', 'Betula_Pendula', 'Olea_Europaea',
       'Quercus_Ellipsoidalis', 'Quercus_x_Hispanica', 'Quercus_Shumardii',
       'Quercus_Rhysophylla', 'Castanea_Sativa', 'Ulmus_Bergmanniana',
       'Quercus_Nigra', 'Salix_Intergra', 'Quercus_Infectoria_sub',
       'Sorbus_Aria'],
      dtype='object'))

上面这个结果是一个元组,我们取元组的第0号元素作为我们的训练的label

除了第一列的id和第二列的species,剩下都是我们的输入,所以我们提取除前两列的所有列

2.4  分割训练集与测试集

之后我们分割出训练数据与测试数据,我们使用sklearn中的model_selecction的train_test_split方法,参数是上面定义的x与labels

  • 上面我没有加入其他的参数,默认train数据0.75,test数据0.25,数据个数取整数

2.5  标准化

2.6  增加维度

由于我们之前都使用了Embedding层,所以没有在通过网络前对训练数据的维度进行更改,一维卷积或者LSTM层都需要输出形状为(samples,feature,step)这种形状的数据

  • samples 训练数据个数
  • feature 训练数据特征数
  • step 一次看多少个

我们看一下此时train_x的shape

742是行数,192是列数,现在差一维step,所以我们要增加最后一维

我们现在再看一下

3  搭建网络

4  编译模型

5  训练模型

训练的时间不会很长,所以我们不加检查点了,但是有很多的epoch,所以我们加上tensorboard看一下情况,由于数据量不大,所以我们不设置batch了,当我们不设置batch时,所有的数据就是一个batch

训练之后保存模型

预测前我们看一下tensorboard,我们直接看acc

训练集90%以上,测试集90%左右,发现效果还可以

6  预测模型

这个是kaggle上的数据集,配套的还有一个test.csv,这个上面有除了名称之外的其他信息

我们就用这个来做预测,首先处理一下数据,然后做一个多分类预测

我们看一下结果,结果返回了一个列表,列表里都是数值,这个数值就是在训练过程中使用pd.fatorize的代表数值,我们可以将其映射一下就可以得到叶子的名字了

在test.csv中一共有594行数据

  • 第1行是id这种标题

我们看一下我们预测的结果是不是594个

发现是594个,这表明我们的预测方法没有问题

  • 2
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Suyuoa

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值