原始的PINNs(Physics-Informed Neural Networks, PINNs)方法不具备求解一类方程的能力。当方程中的特征参数(如介电系数等)发生变化时需要重新训练,增加了求解时间。
本教程重点介绍基于MindElec套件的物理信息自解码器(Physics-Informed Auto-Decoder)增量训练方法,该方法可以快速求解同一类方程,极大减少重新训练的时间。
基于隐向量和神经网络的结合对一系列方程组进行预训练。与求解单个问题不同,预训练步骤中,神经网络的输入为采样点(X)与隐向量(Z)的融合,具体如下图所示:
针对新的方程组,对隐向量和神经网络进行增量训练,快速求解新问题。这里我们提供了两种增量训练模式:
- finetune_latent_with_model: 该方式同时更新隐向量和网络结构,只需要加载预训练的模型进行增量训练即可。
- finetune_latent_only: 如下图所示,该方式固定网络结构,在增量训练中只更新隐向量。
如图:
导入依赖
导入本教程依赖的模块与接口:
代码如下:
from mindelec.data import Dataset
from mindelec.geometry import Disk, Rectangle, TimeDomain, GeometryWithTime
from mindelec.loss import Constraints
from mindelec.solver import Solver, LossAndTimeMonitor
from mindelec.common import L2
from mindelec.architecture import MultiScaleFCCell, MTLWeightedLossCell
from src import get_test_data, create_random_dataset
from src import MultiStepLR
from src import Maxwell2DMur
from src import PredictCallback
from src import visual_result
创建数据集
与点源麦克斯韦方程的方式一致,我们在矩形计算域进行5次均匀采样,即由控制方程所约束的矩形域和源区附近的内部点采样;由初始条件所约束的矩形域和源区附近的内部点采样;以及由边界条件所控制的矩形域边界采样。空间采样与时间采样数据组合构成了训练样本。
代码如下:
# src region
disk = Disk("src", disk_origin, disk_radius)
# no src region
rectangle = Rectangle("rect", coord_min, coord_max)
diff = rectangle - disk
# time info
time_interval = TimeDomain("time", 0.0, config["range_t"])
# geometry merge with time
no_src_region = GeometryWithTime(diff, time_interval)
no_src_region.set_name("no_src")
no_src_region.set_sampling_config(create_config_from_edict(no_src_sampling_config))
src_region = GeometryWithTime(disk, time_interval)
src_region.set_name("src")
src_region.set_sampling_config(create_config_from_edict(src_sampling_config))
boundary = GeometryWithTime(rectangle, time_interval)
boundary.set_name("bc")
boundary.set_sampling_config(create_config_from_edict(bc_sampling_config))
# final sampling fields
geom_dict = {src_region : ["domain", "IC"],
no_src_region : ["domain", "IC"],
boundary : ["BC"]}
MindElec提供了将不同的采样数据合并为统一训练数据集的Dataset接口。】
代码如下:
、
# create dataset for train
elec_train_dataset = create_random_dataset(config)
train_dataset = elec_