使用Mindspore拟合f(x)=a*sin(x)+b这类函数

根据mindspore线性拟合官方案例改编而成

f(x)=w * sin(x) + b,下面脚本展示将以 f(x)=2 * sin(x) +3为实例。

  1. #导入所需的工具包

    import numpy as np

    from mindspore import dataset as ds

    from mindspore.common.initializer import Normal

    from mindspore import nn

    from mindspore.train import Model

    from mindspore.train.callback import LossMonitor

    from mindspore import context

    #确定运行的硬件平台

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

  2. #设定要拟合的函数

    def get_data(num, w=2.0, b=3.0):

        # f(x)=w * sin(x) + b

        # f(x)=2 * sin(x) +3

        for i in range(num):

            x = np.random.uniform(-np.pi, np.pi)

            noise = np.random.normal(0, 1)

            y = w * np.sin(x) + b + noise

            yield np.array([np.sin(x)]).astype(np.float32), np.array([y]).astype(np.float32)

  3. #初始化开始的参数

    class LinearNet(nn.Cell):

        def __init__(self):

            super(LinearNet, self).__init__()

            self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

        def construct(self, x):

            x = self.f

  4. #开始拟合,设定初始化超参数。

    if __name__ == "__main__":

        num_data = 1600

        batch_size = 16

        repeat_size = 1

        lr = 0.005

        momentum = 0.9

         

        net = LinearNet()

        net_loss = nn.loss.MSELoss()

        opt = nn.Momentum(net.trainable_params(), lr, momentum)

        model = Model(net, net_loss, opt)

         

        ds_train = create_dataset(num_data, batch_size=batch_size, repeat_size=repeat_size) 

        model.train(1, ds_train, callbacks=LossMonitor(), dataset_sink_mode=False)

         

        print(net.trainable_params()[0], "\n%s" % net.trainable_params()[1])

以下是一个使用 Qt 绘制二次函数曲线并计算拟合值 R2 的示例程序: ``` #include <QtWidgets> class QuadraticFunctionWidget : public QWidget { public: QuadraticFunctionWidget(QWidget *parent = nullptr) : QWidget(parent) { setWindowTitle(tr("Quadratic Function")); // 设置初始参数 a = 1.0; b = 0.0; c = 0.0; // 创建绘图窗口 chartView = new QChartView(this); chartView->setRenderHint(QPainter::Antialiasing); // 创建数据序列 series = new QScatterSeries(); for (int i = 0; i < 8; i++) { QPointF point(i, 0.0); // x 坐标从 0 到 7 series->append(point); } // 添加数据序列到图表 chart = new QChart(); chart->addSeries(series); chart->createDefaultAxes(); chart->setAnimationOptions(QChart::SeriesAnimations); // 显示图表 chartView->setChart(chart); chartView->setMinimumSize(640, 480); // 创建参数输入框和计算按钮 aLineEdit = new QLineEdit(QString::number(a)); bLineEdit = new QLineEdit(QString::number(b)); cLineEdit = new QLineEdit(QString::number(c)); calculateButton = new QPushButton(tr("Calculate")); // 连接计算按钮的槽函数 connect(calculateButton, &QPushButton::clicked, this, &QuadraticFunctionWidget::calculate); // 创建布局 QVBoxLayout *layout = new QVBoxLayout(); layout->addWidget(chartView); layout->addWidget(new QLabel(tr("y = ax^2 + bx + c"))); QHBoxLayout *paramLayout = new QHBoxLayout(); paramLayout->addWidget(new QLabel(tr("a:"))); paramLayout->addWidget(aLineEdit); paramLayout->addWidget(new QLabel(tr("b:"))); paramLayout->addWidget(bLineEdit); paramLayout->addWidget(new QLabel(tr("c:"))); paramLayout->addWidget(cLineEdit); paramLayout->addWidget(calculateButton); layout->addLayout(paramLayout); setLayout(layout); } private: QChartView *chartView; QChart *chart; QScatterSeries *series; QLineEdit *aLineEdit; QLineEdit *bLineEdit; QLineEdit *cLineEdit; QPushButton *calculateButton; double a; double b; double c; void calculate() { // 读取参数 a = aLineEdit->text().toDouble(); b = bLineEdit->text().toDouble(); c = cLineEdit->text().toDouble(); // 计算函数曲线 QVector<QPointF> points; for (int i = 0; i < 8; i++) { double x = i; double y = a * x * x + b * x + c; points.append(QPointF(x, y)); } // 更新数据序列 series->replace(points); // 计算拟合值 R2 double sumX = 0.0; double sumY = 0.0; double sumXY = 0.0; double sumX2 = 0.0; double sumY2 = 0.0; int n = points.size(); for (int i = 0; i < n; i++) { double x = points[i].x(); double y = points[i].y(); sumX += x; sumY += y; sumXY += x * y; sumX2 += x * x; sumY2 += y * y; } double meanX = sumX / n; double meanY = sumY / n; double SSE = sumY2 - (sumY * sumY) / n; double SSR = (sumXY - sumX * sumY / n) * (sumXY - sumX * sumY / n) / (sumX2 - sumX * sumX / n); double SST = SSE + SSR; double R2 = SSR / SST; // 显示拟合值 R2 QMessageBox::information(this, tr("R2"), tr("R2 = %1").arg(R2)); } }; int main(int argc, char *argv[]) { QApplication app(argc, argv); QuadraticFunctionWidget widget; widget.show(); return app.exec(); } ``` 该程序创建了一个窗口,包含一个图表和三个参数输入框和一个计算按钮。用户可以输入二次函数的三个参数,并点击计算按钮来绘制函数曲线和计算拟合值 R2。在计算拟合值 R2 时,程序使用了最小二乘法的公式来计算。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值