PhyCRNet
代码解析
1、完整代码
'''PhyCRNet for solving spatiotemporal PDEs'''
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as scio
import time
import os
from torch.nn.utils import weight_norm
# Restrict this process to the first visible GPU. CUDA reads this variable
# when it initialises, so it has to be set before any GPU work happens.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

# Default floating-point dtype for newly created tensors.
torch.set_default_dtype(torch.float32)

# One fixed seed (66) for both NumPy's and PyTorch's global RNGs so that
# random draws start from the same state on every run.
np.random.seed(66)
torch.manual_seed(66)
# Finite-difference stencils stored as nested lists of shape [1][1][5][5].
# The two leading singleton levels presumably line up with an
# (out_channels, in_channels, kH, kW) conv-weight layout -- confirm against
# the layer that consumes them. All of them use fourth-order central
# differences on offsets -2..+2. No grid-spacing factor is included, so the
# caller presumably divides by dx (first derivatives) or dx**2 (Laplacian)
# -- confirm.

# 2-D Laplacian: the 1-D second-derivative weights
# (-1/12, 4/3, -5/2, 4/3, -1/12) lie along the centre row and the centre
# column. The two -5/2 centre weights add up to the -5 in the middle.
# (The 0 paired with row 2 below is unused; that row is written out in full.)
lapl_op = [[[
    [-1/12, 4/3, -5, 4/3, -1/12] if row == 2
    else [weight if col == 2 else 0 for col in range(5)]
    for row, weight in enumerate((-1/12, 4/3, 0, 4/3, -1/12))
]]]

# First derivative along the last (width / column) axis: the 1-D weights
# (1/12, -8/12, 0, 8/12, -1/12) occupy the centre row. The sign assumes
# cross-correlation, as nn.Conv2d applies it.
partial_y = [[[
    [1/12, -8/12, 0, 8/12, -1/12] if row == 2 else [0] * 5
    for row in range(5)
]]]

# First derivative along the height / row axis: the same 1-D weights
# occupy the centre column.
partial_x = [[[
    [weight if col == 2 else 0 for col in range(5)]
    for weight in (1/12, -8/12, 0, 8/12, -1/12)
]]]
def initialize_weights(module, scale=0.1):
    """Re-initialise ``module``'s parameters in place with small values.

    Only the module passed in is touched; it does not recurse into children.
    It is presumably applied through ``model.apply(initialize_weights)``
    (TODO confirm against the caller).

    * ``nn.Conv2d``: the weight is drawn from U(-b, b) with
      ``b = scale * sqrt(1 / prod(weight.shape[:-1]))``. The bias is left
      as-is.
    * ``nn.Linear``: the bias is zeroed when the layer has one. The weight
      is left as-is.
    * Any other module is left unchanged.

    Args:
        module: the module to initialise.
        scale: multiplier on the uniform bound. The default of 0.1
            reproduces the original behaviour.
    """
    if isinstance(module, nn.Conv2d):
        # NOTE(review): Conv2d weights are (out, in/groups, kH, kW), so
        # shape[:-1] multiplies out*in*kH. The conventional fan-in would be
        # prod(shape[1:]). Kept as-is so the existing init scale is unchanged.
        bound = scale * np.sqrt(1 / np.prod(module.weight.shape[:-1]))
        # NOTE(review): for convs wrapped with torch.nn.utils.weight_norm,
        # ``weight`` is recomputed from weight_g / weight_v before every
        # forward. This in-place write therefore probably does not persist
        # for those layers -- confirm.
        module.weight.data.uniform_(-bound, bound)
    elif isinstance(module, nn.Linear) and module.bias is not None:
        # A Linear built with bias=False has ``bias is None``. The old code
        # raised AttributeError in that case.
        module.bias.data.zero_()
class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell.

    Each gate (input i, forget f, candidate c, output o) is the sum of two
    convolutions. One is input-to-state, applied to ``x`` with a
    configurable kernel, stride and padding, and has a bias. The other is
    state-to-state, applied to ``h`` with a fixed 3x3 kernel, stride 1,
    size-preserving padding, and no bias. All convolutions use circular
    padding.

    Args:
        input_channels: channel count of ``x``.
        hidden_channels: channel count of the hidden state ``h`` and the
            cell state ``c``.
        input_kernel_size, input_stride, input_padding: geometry of the
            input-to-state convolutions.
    """

    def __init__(self, input_channels, hidden_channels, input_kernel_size,
                 input_stride, input_padding):
        super(ConvLSTMCell, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.hidden_kernel_size = 3
        self.input_kernel_size = input_kernel_size
        self.input_stride = input_stride
        self.input_padding = input_padding
        self.num_features = 4
        # Padding that keeps h's spatial size for the odd hidden kernel
        # (1 for a 3x3 kernel).
        self.padding = int((self.hidden_kernel_size - 1) / 2)

        # The creation order (Wxi, Whi, Wxf, Whf, Wxc, Whc, Wxo, Who) is kept
        # so that parameter registration order and the RNG draws of the
        # default init match the original layout.
        self.Wxi = self._input_conv()
        self.Whi = self._hidden_conv()
        self.Wxf = self._input_conv()
        self.Whf = self._hidden_conv()
        self.Wxc = self._input_conv()
        self.Whc = self._hidden_conv()
        self.Wxo = self._input_conv()
        self.Who = self._hidden_conv()

        nn.init.zeros_(self.Wxi.bias)
        nn.init.zeros_(self.Wxf.bias)
        nn.init.zeros_(self.Wxc.bias)
        # The output-gate bias starts at 1 (sigmoid(1) ~= 0.73), which
        # biases the output gate towards open at initialisation.
        self.Wxo.bias.data.fill_(1.0)

    def _input_conv(self):
        """Build one input-to-state conv (input -> hidden channels, biased)."""
        return nn.Conv2d(self.input_channels, self.hidden_channels,
                         self.input_kernel_size, self.input_stride,
                         self.input_padding, bias=True,
                         padding_mode='circular')

    def _hidden_conv(self):
        """Build one state-to-state conv (hidden -> hidden, size-preserving, no bias)."""
        return nn.Conv2d(self.hidden_channels, self.hidden_channels,
                         self.hidden_kernel_size, 1, padding=self.padding,
                         bias=False, padding_mode='circular')

    def forward(self, x, h, c):
        """Advance the cell by one step.

        Args:
            x: input in the layout nn.Conv2d expects, with
                ``input_channels`` channels.
            h: previous hidden state. It must have ``hidden_channels``
                channels and the same spatial size as the input-to-state
                conv output, because the two conv outputs are summed.
            c: previous cell state, with the same shape as ``h``.

        Returns:
            ``(h_next, c_next)``.
        """
        ci = torch.sigmoid(self.Wxi(x) + self.Whi(h))
        cf = torch.sigmoid(self.Wxf(x) + self.Whf(h))
        cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h))
        co = torch.sigmoid(self.Wxo(x) + self.Who(h))
        ch = co * torch.tanh(cc)
        return ch, cc

    def init_hidden_tensor(self, prev_state):
        """Return ``(h, c)`` from ``prev_state``, detached and on this cell's device.

        The original code called ``.cuda()`` unconditionally. That fails on
        CPU-only hosts, and it picks the current CUDA device rather than the
        one that holds the weights. When the model lives on the current GPU,
        the result is the same as before.

        Args:
            prev_state: a pair ``(h, c)`` of tensors.
        """
        device = self.Wxi.weight.device
        # The legacy ``Variable(t)`` wrapper returned a detached alias of
        # ``t`` (requires_grad=False). ``.detach()`` keeps that graph cut.
        return (prev_state[0].detach().to(device),
                prev_state[1].detach().to(device))
class encoder_block(nn.Module):
    """Convolutional encoder stage: a weight-normalised ``nn.Conv2d`` followed by ReLU.

    The convolution uses circular padding, and its bias starts at zero.

    Args:
        input_channels: channel count of the incoming feature map.
        hidden_channels: channel count produced by the convolution.
        input_kernel_size, input_stride, input_padding: passed straight to
            ``nn.Conv2d`` as kernel_size, stride and padding.
    """

    def __init__(self, input_channels, hidden_channels, input_kernel_size,
                 input_stride, input_padding):
        super().__init__()
        # The constructor arguments are kept as plain attributes.
        (self.input_channels, self.hidden_channels, self.input_kernel_size,
         self.input_stride, self.input_padding) = (
            input_channels, hidden_channels, input_kernel_size,
            input_stride, input_padding)

        # NOTE(review): torch.nn.utils.weight_norm is deprecated in recent
        # PyTorch in favour of torch.nn.utils.parametrizations.weight_norm.
        # Switching would rename the state_dict entries (weight_g / weight_v),
        # so it is left as-is here.
        base_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=hidden_channels,
            kernel_size=input_kernel_size,
            stride=input_stride,
            padding=input_padding,
            bias=True,
            padding_mode='circular',
        )
        self.conv = weight_norm(base_conv)
        self.act = nn.ReLU()
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        pre_activation = self.conv(x)
        return self.act(pre_activation)
class PhyCRNet(nn.Module):
''' physics-informed convolutional-recurrent neural networks '''
def __init__(self, input_channels, hidden_channels,
input_kernel_size, input_stride, input_padding, dt,
num_layers, upscale_factor, step=1, effective_step=[1]):
super(PhyCRNet, self).__init__()
self.input_channels = [input_channels] + hidden_channels
self.hidden_channels = hidden_channels
self.input_kernel_size = input_kernel_size
self.input_stride = input_stride
self.input_padding = input_padding
self.step = step
self.effective_step = effective_step
self._all_layers = []
self.dt = dt
self.upscale_factor = upscale_factor
self.num_encoder = num_layers[0]
self.num_convlstm = num_layers[1]
for i in range(self.num_encoder):
name = 'encoder{}'.format(i)
cell = encoder_block(
input_channels = self.input_channels[i],
hidden_channels = self.hidden_channels[i],
input_kernel_size = self.input_kernel_size[i],
input_stride = self.input_stride[i],
input_padding = self.input_padding[i])
setattr(self, name, cell)
self._all_layers.append(cell)
for i in range(self.num_encoder, self.num_encoder + self.num_convlstm):
name = 'convlstm{}'.format(i)
cell = ConvLSTMCell(
input_channels = self.input_channels[i],
hidden_channels = self.hidden_channels[i],
input_kernel_size = self.input_kernel_size[i],
input_s