Github
目标方程
y = 3 * x + 2
- 3 为 Weights
- 2 为 Biases
代码
/**
 * Fit the parameters of the linear equation y = 3x + 2 with gradient descent,
 * where 3 is the Weights and 2 is the Biases.
 *
 * Relies on the global `dl` namespace (deeplearn.js) being loaded.
 */

// Training data: five points sampled from the target line y = 3x + 2.
const x_data = dl.tensor1d([0, 1, 2, 3, 4]);
const a = dl.scalar(3);
const b = dl.scalar(2);
const y_data = x_data.mul(a).add(b);

// Trainable parameters: random initial weight, bias starting at zero.
const Weights = dl.variable(dl.randomUniform([1]));
const Biases = dl.variable(dl.zeros([1]));

// Model prediction: Weights * x + Biases (the [1]-shaped params broadcast over x).
const f = (x) => Weights.mul(x).add(Biases);

// Mean squared error between predictions and labels.
const loss = (pred, label) => pred.sub(label).square().mean();

// For x = [0..4] the MSE surface has a slow direction whose error shrinks by
// roughly (1 - 0.6 * learningRate) per step. At 0.01 that leaves a few percent
// of the initial error after 500 steps (results near 3.02 / 1.95 rather than
// 3 / 2). 0.05 converges to float precision within the same step budget and
// stays well below the divergence threshold (~0.15) set by the steep direction.
const learningRate = 0.05;
const iterations = 500;
const optimizer = dl.train.sgd(learningRate);

for (let i = 0; i < iterations; i++) {
  // Each call performs one gradient-descent update of the trainable variables.
  optimizer.minimize(() => loss(f(x_data), y_data));
}

console.log(
  `Weights: ${Weights.dataSync()}, Biases: ${Biases.dataSync()}`
);
效果