案例:鸢尾花(iris)分类
操作步骤
- 加载IRIS数据集(训练集与验证集)
- 定义模型结构:带有softmax的多层神经网络
- 初始化一个神经网络模型
- 为神经网络模型添加两个层
- 设计层的神经元个数,inputShape,激活函数
- 训练模型并预测
- 交叉熵损失函数与准确度度量
主要示例代码:
<!-- index.html -->
<form action="" onsubmit="predict(this); return false;">
花萼长度:<input type="text" name="a"><br>
花萼宽度:<input type="text" name="b"><br>
花瓣长度:<input type="text" name="c"><br>
花瓣宽度:<input type="text" name="d"><br>
<button type="submit">预测</button>
</form>
// index.js
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getIrisData, IRIS_CLASSES } from './data';
window.onload = async() => {
//分别代表训练集和验证集的特征和标签
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15); // 15%的数据用于验证集
// xTrain.print();
// yTrain.print();
// xTest.print();
// yTest.print();
// console.log(IRIS_CLASSES);
// 定义模型结构
const model = tf.sequential();
model.add(tf.layers.dense({
units: 10,
inputShape:[xTrain.shape[1]], // 特征长度:4
activation: 'sigmoid'
}));
model.add(tf.layers.dense({
units: 3,
activation:'softmax'
}));
model.compile({
loss:'categoricalCrossentropy',
optimizer: tf.train.adam(0.1),
metrics: ['accuracy']
});
await model.fit(xTrain, yTrain, {
epochs: 100,
validationData: [xTest, yTest],
callbacks: tfvis.show.fitCallbacks(
{name:'训练效果'},
['loss','val_loss','acc','val_acc'],
{callbacks:['onEpochEnd']}
)
});
window.predict = (form) => {
const input = tf.tensor([[
form.a.value *1,
form.b.value *1,
form.c.value *1,
form.d.value *1,
]]);
const pred = model.predict(input);
alert(`预测结果:${IRIS_CLASSES[pred.argMax(1).dataSync(0)]}`)
}
};