python-torch:实现拟合任意函数
前言
本文以下面的项目需求为例,进行讲解。
1. 需求
现有 1 万多组数据 (P, T, V),需要用模型 P = b0 + b1·T + b2 / (b3 + b4·V) 去拟合参数 b0–b4。
2. 问题解决
2.1 数据读入并处理
import pandas as pd
import numpy as np
def read_data(file_path: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load the fitting samples from a CSV file.

    The file must contain the columns ``Ta``, ``Vw`` and ``PV``.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A ``(T, V, P)`` tuple of float NumPy arrays taken from the ``Ta``,
        ``Vw`` and ``PV`` columns respectively.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    df = pd.read_csv(file_path)
    T = df['Ta'].to_numpy(dtype=float)
    V = df['Vw'].to_numpy(dtype=float)
    P = df['PV'].to_numpy(dtype=float)
    return T, V, P
2.2 定义模型
import torch
import torch.nn as nn
class Demo(nn.Module):
    """Five-parameter model ``P = b0 + b1 * T + b2 / (b3 + b4 * V)``.

    ``b0`` .. ``b4`` are learnable scalars stored as one-element float
    tensors; their initial values are 0.0, 0.1, 0.1, 1.0 and 0.1.
    """

    def __init__(self):
        super().__init__()
        # nn.Parameter tracks gradients by default, so requires_grad is
        # not spelled out.
        self.b0 = nn.Parameter(torch.tensor([0.0]))
        self.b1 = nn.Parameter(torch.tensor([0.1]))
        self.b2 = nn.Parameter(torch.tensor([0.1]))
        self.b3 = nn.Parameter(torch.tensor([1.0]))
        self.b4 = nn.Parameter(torch.tensor([0.1]))

    def forward(self, T, V):
        """Evaluate the model element-wise.

        Args:
            T: Tensor of the first input variable.
            V: Tensor of the second input variable, broadcastable with ``T``.

        Returns:
            Tensor of predicted ``P`` values.
        """
        P_hat = self.b0 + self.b1 * T + self.b2 / (self.b3 + self.b4 * V)
        return P_hat
2.3 主要过程、前向后向
import os
import sys
import math
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from utils import *
from model import Demo
import random
# Load the (T, V, P) samples and convert them to float32 tensors.
file_path = './data/data.csv'
T, V, P = read_data(file_path)
T, V, P = torch.FloatTensor(T), torch.FloatTensor(V), torch.FloatTensor(P)

# Random split, train : valid = 4 : 1.
# A flat index list keeps the selected tensors one-dimensional.
n = T.shape[0]
s = list(range(n))
random.shuffle(s)
spl = int(n * 0.8)
T_train, V_train, P_train = T[s[:spl]], V[s[:spl]], P[s[:spl]]
T_valid, V_valid, P_valid = T[s[spl:]], V[s[spl:]], P[s[spl:]]

# Model definition
model = Demo()

# Loss function definition
loss_fn = nn.MSELoss(reduction="mean")

# Base parameters
N_epoch = 500
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Start from +inf so the first validation pass always records a result,
# whatever the scale of the data.
min_loss = float("inf")
B = (0.0, 0.0, 0.0, 0.0, 0.0)

model.train()
for epoch in range(N_epoch):
    # Train: one full-batch gradient step per epoch.
    P_hat = model(T_train, V_train)
    loss = loss_fn(P_hat, P_train)
    print(f"epooch {epoch}...train_loss: {loss.item()}, min_valid_loss: {min_loss}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validate every 20 epochs and on the final epoch.
    if epoch % 20 == 0 or epoch == N_epoch - 1:
        model.eval()
        # No autograd graph is needed for evaluation.
        with torch.no_grad():
            valid_loss = loss_fn(model(T_valid, V_valid), P_valid).item()
        if valid_loss < min_loss:
            min_loss = valid_loss
            # .item() copies the values out as Python floats. Indexing
            # .data[0] would return views that the optimizer's in-place
            # updates keep overwriting, so the "best" snapshot would
            # silently track the latest parameters instead.
            B = tuple(
                param.item()
                for param in (model.b0, model.b1, model.b2, model.b3, model.b4)
            )
        model.train()

# Report the fitted parameters (b0, b1, b2, b3, b4) with the lowest
# validation loss.
print(B)