PINN 入门笔记
前言
这段代码最重要的难点在于理解损失函数和物理方程残差的定义。看懂这里,你就入门了 PINN。
重点剖析
片段 1
def loss_func(self):
u_pred = self.net_u(self.x_u, self.t_u)
f_pred = self.net_f(self.x_f, self.t_f)
loss_u = torch.mean((self.u - u_pred) ** 2)
loss_f = torch.mean(f_pred ** 2)
return loss_u + loss_f
问题提出
为什么这样定义损失函数呢?
解答
- x_f 和 t_f:它们是
X_f_train
的两列,而X_f_train
是在 x,t 范围内随机生成的数据点。 - 关键点 1:
- 为什么随机点也能拿来做预测?
- 怎么判别预测的 u 结果准确呢?毕竟没有对应的 u。
这里依靠 u_pred = self.net_u(self.x_u, self.t_u)
和 loss_u
,x,t 都有对应的 u,通过让模型学习把 loss 降到最小,预测更加准确,也就有了可以和随机生成点对应的 u。
- 关键点 2:为了让模型能够学习这个物理方程,引入了
loss_f
,让模型学习把loss_f
降到最小,使参数能够学习到物理方程的规律,两者相加(loss_u + loss_f
)也就构成了 PINN。
片段 2
def net_f(self, x, t):
u = self.net_u(x, t)
u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True)[0]
u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0]
u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
f = u_t + u * u_x - self.nu * u_xx
return f # 方程的残差
解释
- net_f:计算方程的残差。
- 步骤:
- 计算
u_t
:对u
关于t
的梯度。 - 计算
u_x
:对u
关于x
的梯度。 - 计算
u_xx
:对u_x
关于x
的梯度。 - 计算残差
f
:u_t + u * u_x - self.nu * u_xx
。
- 计算