python-torch:实现拟合任意函数

python-torch:实现拟合任意函数

前言

本文以下面的项目需求为例,进行讲解。

1. 需求

在这里插入图片描述

现有1w多组数据对(P,T,V),需要去拟合b0-b1。
1

2. 问题解决

2.1 数据读入并处理

import pandas as pd
import numpy as np

def read_data(file_path):
          df = pd.read_csv(file_path)
          T = np.array(df['Ta'], dtype=float)
          V = np.array(df['Vw'], dtype=float)
          P = np.array(df['PV'], dtype=float)
          return T, V, P

2.2 定义模型

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super(Demo, self).__init__()
        self.b0 = nn.Parameter(torch.Tensor([0.0]), requires_grad = True)
        self.b1 = nn.Parameter(torch.Tensor([0.1]), requires_grad = True)
        self.b2 = nn.Parameter(torch.Tensor([0.1]), requires_grad = True)
        self.b3 = nn.Parameter(torch.Tensor([1.0]), requires_grad = True)
        self.b4 = nn.Parameter(torch.Tensor([0.1]), requires_grad = True)
        
    
    def forward(self, T, V):
        P_hat = self.b0 + self.b1 * T + self.b2 / (self.b3 + self.b4 * V)
        return P_hat

2.3 主要过程、前向后向

import os
import sys
import math
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from utils import *
from model import Demo
import random

# read_data
file_path = './data/data.csv'
T, V, P = read_data(file_path)
T, V, P = torch.FloatTensor(T), torch.FloatTensor(V), torch.FloatTensor(P)

# train : valid = 4 : 1
n = T.shape[0]
s = [[i] for i in range(n)]
random.shuffle(s)
spl = int(n * 0.8)
T_train, V_train, P_train = T[s[:spl]], V[s[:spl]], P[s[:spl]]
T_valid, V_valid, P_valid = T[s[spl:]], V[s[spl:]], P[s[spl:]]

# Model Defination
model = Demo()

# Loss_Fuction Defination
loss_fn = nn.MSELoss(reduction="mean")

# Base Parameters
N_epoch  = 500
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
min_loss = 1e5
B = (0.0, 0.0, 0.0, 0.0, 0.0)

model.train()
for epoch in range(N_epoch):
    # train
    P_hat = model(T_train, V_train)
    loss = loss_fn(P_hat, P_train)
    print(f"epooch {epoch}...train_loss: {loss}, min_valid_loss: {min_loss}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # valid
    if epoch % 20 == 0 or epoch == N_epoch - 1:
          model.eval()
          P_hat = model(T_valid, V_valid)
          loss = loss_fn(P_hat, P_valid)
          if loss < min_loss:
                    min_loss = loss
                    B = (model.b0.data[0], model.b1.data[0], model.b2.data[0], model.b3.data[0], model.b4.data[0])
          model.train()
# 获得拟合参数
print(B)

3. 结果展示

2

  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

秋冬无暖阳°

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值