一个例子让你入门PINN

PINN 入门笔记

前言

这段代码最重要的难点在于理解损失函数和物理方程残差的定义。看懂这里,你就入门了 PINN。

重点剖析

片段 1

def loss_func(self):
    u_pred = self.net_u(self.x_u, self.t_u)
    f_pred = self.net_f(self.x_f, self.t_f)
    loss_u = torch.mean((self.u - u_pred) ** 2)
    loss_f = torch.mean(f_pred ** 2)
    return loss_u + loss_f
问题提出

为什么这样定义损失函数呢?

解答
  • x_f 和 t_f:它们是 X_f_train 的两列,而 X_f_train 是在 x,t 范围内随机生成的数据点。
  • 关键点 1
    1. 为什么随机点也能拿来做预测?
    2. 怎么判别预测的 u 结果准确呢?毕竟没有对应的 u。

这里依靠 u_pred = self.net_u(self.x_u, self.t_u)loss_u,x,t 都有对应的 u,通过让模型学习把 loss 降到最小,预测更加准确,也就有了可以和随机生成点对应的 u。

  • 关键点 2:为了让模型能够学习这个物理方程,引入了 loss_f,让模型学习把 loss_f 降到最小,使参数能够学习到物理方程的规律,两者相加(loss_u + loss_f)也就构成了 PINN。

片段 2

def net_f(self, x, t):
    u = self.net_u(x, t)
    u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
    f = u_t + u * u_x - self.nu * u_xx
    return f  # 方程的残差
解释
  • net_f:计算方程的残差。
  • 步骤
    1. 计算 u_t:对 u 关于 t 的梯度。
    2. 计算 u_x:对 u 关于 x 的梯度。
    3. 计算 u_xx:对 u_x 关于 x 的梯度。
    4. 计算残差 fu_t + u * u_x - self.nu * u_xx
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值