接上篇,本篇将继续结合算法结构进行代码实现。
首先,补充一下trajGRU继承的BaseConvRNN类。这个类用来计算状态转移过程中特征的宽、高、通道数。这些计算方法虽然在trajGRU中没有直接用到,但是可以辅助理解整个状态转移过程。
class BaseConvRNN(nn.Module):
    """Base class for convolutional RNN cells.

    Stores the input-to-hidden (i2h) and hidden-to-hidden (h2h) convolution
    hyper-parameters and pre-computes the spatial size of the hidden state
    produced by the i2h convolution.

    Args:
        num_filter: number of channels of the hidden state.
        b_h_w: tuple (batch_size, height, width) of the input feature map.
        h2h_kernel: kernel size of the state-to-state convolution; both sides
            must be odd so that "same" padding preserves the spatial size.
        h2h_dilate: dilation of the state-to-state convolution.
        i2h_kernel / i2h_stride / i2h_pad / i2h_dilate: parameters of the
            input-to-state convolution; they determine the hidden-state size.
        act_type: activation used by subclasses.
        prefix: name prefix for this cell.
    """
    def __init__(self, num_filter, b_h_w, h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                 i2h_dilate=(1, 1), act_type=torch.tanh, prefix='BaseConvRNN'):
        super(BaseConvRNN, self).__init__()
        self._prefix = prefix
        self._num_filter = num_filter
        self._h2h_kernel = h2h_kernel
        # FIX: the original used `assert cond, print(...)`, which prints the
        # message even when the assertion passes and attaches None as the
        # assertion message; use a plain message string instead.
        assert (self._h2h_kernel[0] % 2 == 1) and (self._h2h_kernel[1] % 2 == 1), \
            "Only support odd number, get h2h_kernel= %s" % str(h2h_kernel)
        # "same" padding for the (possibly dilated) h2h convolution
        self._h2h_pad = (h2h_dilate[0] * (h2h_kernel[0] - 1) // 2,
                         h2h_dilate[1] * (h2h_kernel[1] - 1) // 2)
        self._h2h_dilate = h2h_dilate
        self._i2h_kernel = i2h_kernel
        self._i2h_stride = i2h_stride
        self._i2h_pad = i2h_pad
        self._i2h_dilate = i2h_dilate
        self._act_type = act_type
        assert len(b_h_w) == 3
        # effective (dilated) kernel size: 1 + (k - 1) * dilation
        i2h_dilate_ksize_h = 1 + (self._i2h_kernel[0] - 1) * self._i2h_dilate[0]
        i2h_dilate_ksize_w = 1 + (self._i2h_kernel[1] - 1) * self._i2h_dilate[1]
        self._batch_size, self._height, self._width = b_h_w
        # standard convolution output-size formula:
        # out = (in + 2*pad - effective_kernel) // stride + 1
        self._state_height = (self._height + 2 * self._i2h_pad[0]
                              - i2h_dilate_ksize_h) // self._i2h_stride[0] + 1
        self._state_width = (self._width + 2 * self._i2h_pad[1]
                             - i2h_dilate_ksize_w) // self._i2h_stride[1] + 1
        self._curr_states = None
        self._counter = 0
需要注意的是,当dilate(空洞率)大于1时,输出特征尺寸的计算方式与普通卷积不同;等于1时即为常用的计算方式。
当空洞率大于1时,卷积核的等效尺寸会变大:等效尺寸 = 1 + (kernel_size - 1) × dilation,例如kernel_size为3、dilation为2时,等效卷积核尺寸为1 + (3-1)×2 = 5。
虽然感受野变大了,但参数量没有发生变化,依然是3*3*channel。
trajGRU代码分步实现
接上一篇文章,trajGRU算法结构可以分为5个部分,下文将按照这5个部分逐一说明代码实现过程以及其对应的算法。
第一部分:input2hidden
# Part 1 (input-to-hidden): one convolution produces all three gate
# pre-activations at once, hence out_channels = num_filter * 3.
# FIX: attribute renamed from `izh` to `i2h` so the snippet matches the
# full TrajGRU implementation below.
self.i2h = nn.Conv2d(in_channels,num_filter*3,self._i2h_kernel,self._i2h_stride,self._i2h_pad,self._i2h_dilate)
输出的通道数乘3是因为GRU结构中需要同时计算3组门控特征图:重置门(reset gate)、更新门(update gate)和候选记忆(new memory),一次卷积同时产生这3组的预激活值,而不是对应3个时刻的状态。
第二部分:input2flow
# Part 2 (input-to-flow): 5x5 conv projecting the current input into a
# 32-channel intermediate feature map used to generate the flow fields.
self.i2f_conv1 = nn.Conv2d(in_channels=input_channel,
                           out_channels=32,
                           kernel_size=(5,5),
                           stride=1,
                           padding=(2,2),
                           dilation=(1,1))
输出通道数是32,这个跟论文中提到的配置一致。
第三部分:hidden2flow
# Part 3 (hidden-to-flow): 5x5 conv projecting the previous hidden state
# into the same 32-channel flow-feature space as i2f_conv1.
self.h2f_conv1 = nn.Conv2d(in_channels=self._num_filter,
                           out_channels=32,
                           kernel_size=(5,5),
                           stride=1,
                           padding=(2,2),
                           dilation=(1,1))
同样的,这里的输出通道数也是设置为32,和论文中保持一致。
第四部分:generate flow
# Part 4 (generate flow): maps the 32-channel flow features to L flow
# fields; each field has 2 channels (x- and y-displacement), hence L*2.
self.flows_conv = nn.Conv2d(in_channels=32,
                            out_channels=self._L*2,
                            kernel_size=(5,5),
                            stride=1,
                            padding=(2,2))
这里的结构是最重要的traj部分,也是网络能够捕捉时序运动的关键。
输入通道数32和输出通道数L*2与论文中的设置保持一致。这里的L即论文中提到的Link(连接)的数量。通过这个link结构,网络相比一般的卷积网络能更好地捕捉贴近真实情况的物体运动规律:link结构相当于一个可以通过反向传播不断学习参数的光流层,而图像或视频中的运动在计算机视觉领域通常以光流来表示。这也解释了为什么输出通道数是L*2——光流矢量可以分解为x方向和y方向两个分量,所以每个link对应2个通道,一个代表x方向的位移,一个代表y方向的位移。
第五部分 hidden2hidden
# Part 5 (hidden-to-hidden): 1x1 conv aggregating the L warped copies of
# the hidden state (num_filter channels each) into the three gate
# pre-activations (num_filter * 3 channels).
self.ret = nn.Conv2d(in_channels=self._num_filter*self._L,
                     out_channels=self._num_filter*3,
                     kernel_size=(1,1),
                     stride=1)
对于输入通道数,因为有L个连接,且每个连接的通道数都相同,所以是num_filter*L;输出通道数是num_filter*3,对应重置门、更新门和候选记忆这3组门控特征图。
到此,网络的基本结构已经分析完成,下面将介绍基于这些基本结构的网络完整结构以及网络训练部分的代码。
补充一下wrap函数方法,主要就是一个双线性插值
def wrap(input, flow):
    """Warp `input` backward along `flow` via bilinear sampling.

    Args:
        input: tensor of shape (B, C, H, W).
        flow: tensor of shape (B, 2, H, W); channel 0 is the x-displacement,
            channel 1 the y-displacement, in pixel units.
    Returns:
        Tensor of shape (B, C, H, W): input sampled at (grid + flow).
        Out-of-range locations receive grid_sample's default zero padding.
    """
    B, C, H, W = input.size()
    # Build the identity sampling grid on the same device as the input.
    # FIX: the original hard-coded .cuda(); deriving the device from `input`
    # also makes the function usable on CPU.
    device = input.device
    xx = torch.arange(0, W, device=device).view(1, -1).repeat(H, 1)
    yy = torch.arange(0, H, device=device).view(-1, 1).repeat(1, W)
    xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
    yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
    grid = torch.cat((xx, yy), 1).float()
    vgrid = grid + flow
    # Normalize pixel coordinates to [-1, 1] (align_corners=True convention:
    # -1 maps to pixel 0 and +1 to pixel W-1 / H-1).
    # FIX: the original wrote the normalized values through
    # .clone().detach(), which cut the gradient path from the sampling grid
    # back into `flow` — the flow-generating convolutions could not learn.
    # max(., 1) guards against division by zero when W or H is 1.
    vx = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0
    vy = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0
    vgrid = torch.stack((vx, vy), dim=1)
    # grid_sample expects the grid as N*H*W*2
    vgrid = vgrid.permute(0, 2, 3, 1)
    # align_corners=True matches the (W-1)/(H-1) normalization above and the
    # pre-1.3 PyTorch default the original code relied on.
    output = F.grid_sample(input, vgrid, align_corners=True)
    return output
trajGRU的完整代码
class TrajGRU(BaseConvRNN):
    """Trajectory GRU cell.

    Instead of ConvGRU's fixed local neighborhood, the cell generates L flow
    fields from (current input, previous hidden state), warps the previous
    hidden state along each of them, and aggregates the L warped copies with
    a 1x1 convolution to form the state-to-state transition.

    Args:
        input_channel: channels of the input feature map.
        num_filter: channels of the hidden state.
        b_h_w: (batch, height, width) of the input feature map.
        zoneout: zoneout probability for hidden-state regularization.
        L: number of links (flow fields).
    """
    # b_h_w: input feature map size
    def __init__(self, input_channel, num_filter, b_h_w, zoneout=0.2, L=5,
                 i2h_kernel=(3,3), i2h_stride=(1,1), i2h_pad=(1,1),
                 h2h_kernel=(5,5), h2h_dilate=(1,1), prefix='BaseConvRNN'):
        super(TrajGRU,self).__init__(num_filter=num_filter,
                                     b_h_w=b_h_w,
                                     h2h_kernel=h2h_kernel,
                                     h2h_dilate=h2h_dilate,
                                     i2h_kernel=i2h_kernel,
                                     i2h_stride=i2h_stride,
                                     i2h_pad=i2h_pad,
                                     prefix=prefix)
        self._L = L              # number of links (flow fields)
        self._zoneout = zoneout  # zoneout probability
        #self._act_type = F.leaky_relu()
        # Input-to-hidden: one conv produces all 3 gate pre-activations
        # (reset, update, new memory), hence out_channels = num_filter * 3.
        self.i2h = nn.Conv2d(in_channels=input_channel,
                             out_channels=self._num_filter*3,
                             kernel_size=self._i2h_kernel,
                             stride=self._i2h_stride,
                             padding=self._i2h_pad,
                             dilation=self._i2h_dilate)
        # inputs to flow: project input into the 32-channel flow-feature space
        self.i2f_conv1 = nn.Conv2d(in_channels=input_channel,
                                   out_channels=32,
                                   kernel_size=(5,5),
                                   stride=1,
                                   padding=(2,2),
                                   dilation=(1,1))
        # hidden to flow: project previous hidden state into the same space
        self.h2f_conv1 = nn.Conv2d(in_channels=self._num_filter,
                                   out_channels=32,
                                   kernel_size=(5,5),
                                   stride=1,
                                   padding=(2,2),
                                   dilation=(1,1))
        # generate flow: L fields x 2 channels (x/y displacement) each
        self.flows_conv = nn.Conv2d(in_channels=32,
                                    out_channels=self._L*2,
                                    kernel_size=(5,5),
                                    stride=1,
                                    padding=(2,2))
        # hh,hz,hr, 1*1 kernel: aggregate the L warped hidden-state copies
        # into the 3 gate pre-activations
        self.ret = nn.Conv2d(in_channels=self._num_filter*self._L,
                             out_channels=self._num_filter*3,
                             kernel_size=(1,1),
                             stride=1)

    # inputs: B C H W
    # flow comes from current inputs and forward states
    def _flow_generator(self, inputs, states):
        """Return a tuple of L flow tensors, each (B, 2, H, W).

        `inputs` may be None (decoder mode); then the flow is generated
        from the hidden state alone.
        """
        if inputs is not None:
            i2f_conv1 = self.i2f_conv1(inputs)
        else:
            i2f_conv1 = None
        h2f_conv1 = self.h2f_conv1(states)
        f_conv1 = i2f_conv1 + h2f_conv1 if i2f_conv1 is not None else h2f_conv1
        f_conv1 = F.leaky_relu(f_conv1, 0.2, inplace=True)
        flows = self.flows_conv(f_conv1)
        # channels L*2, split into chunks of 2 along the channel dim:
        # get L flow maps, each with 2 channels (x and y displacement)
        flows = torch.split(flows, 2, dim=1)
        return flows

    # inputs, states
    # inputs: S B C H W
    def forward(self, inputs=None, states=None, seq_len=5):
        """Run the cell for `seq_len` steps.

        Args:
            inputs: (S, B, C, H, W) tensor, or None in decoder mode (the cell
                then evolves from `states` alone for `seq_len` steps).
            states: initial hidden state (B, num_filter, state_h, state_w);
                zeros are used when None.
        Returns:
            (stacked outputs of shape (seq_len, B, num_filter, H', W'),
             last hidden state).
        """
        if states is None:
            # NOTE(review): hard-codes .cuda(); fails on CPU-only runs.
            # Also requires `inputs` to be non-None to read the batch size.
            states = torch.zeros((inputs.size(1), self._num_filter, self._state_height,
                                  self._state_width), dtype=torch.float).cuda()
        if inputs is not None:
            S, B, C, H, W = inputs.size()
            # apply i2h to every timestep at once by folding S into batch
            i2h = self.i2h(torch.reshape(inputs, (-1, C, H, W)))
            i2h = torch.reshape(i2h, (S, B, i2h.size(1), i2h.size(2), i2h.size(3)))
            # split channels into the 3 gate pre-activations
            i2h_slice = torch.split(i2h, self._num_filter, dim=2)
        else:
            i2h_slice = None
        prev_h = states
        outputs = []
        for i in range(seq_len):
            if inputs is not None:
                flows = self._flow_generator(inputs[i, ...], prev_h)
            else:
                flows = self._flow_generator(None, prev_h)
            wrapped_data = []
            for j in range(len(flows)):
                flow = flows[j]
                # warp the previous hidden state backward along each link's
                # flow (note the negative sign)
                wrapped_data.append(wrap(prev_h, -flow))
            wrapped_data = torch.cat(wrapped_data, dim=1)
            # 1x1 conv over the L concatenated warped states -> 3 gate maps
            h2h = self.ret(wrapped_data)
            h2h_slice = torch.split(h2h, self._num_filter, dim=1)
            if i2h_slice is not None:
                # standard GRU gating with conv pre-activations
                reset_gate = torch.sigmoid(i2h_slice[0][i, ...] + h2h_slice[0])
                update_gate = torch.sigmoid(i2h_slice[1][i, ...] + h2h_slice[1])
                new_mem = F.leaky_relu(i2h_slice[2][i, ...] + reset_gate * h2h_slice[2], 0.2, inplace=True)
            else:
                # decoder mode: gates from the hidden state only
                reset_gate = torch.sigmoid(h2h_slice[0])
                update_gate = torch.sigmoid(h2h_slice[1])
                new_mem = F.leaky_relu(reset_gate * h2h_slice[2], 0.2, inplace=True)
            next_h = update_gate * prev_h + (1 - update_gate) * new_mem
            if self._zoneout > 0.0:
                # NOTE(review): F.dropout2d applied to an all-zero tensor
                # returns all zeros, so `mask` appears to never select
                # `next_h` here; recent PyTorch also requires a bool
                # condition for torch.where. Confirm against the reference
                # zoneout implementation. (All configs below use zoneout=0.0,
                # so this branch is inactive in practice.)
                mask = F.dropout2d(torch.zeros_like(prev_h), p=self._zoneout)
                next_h = torch.where(mask, next_h, prev_h)
            outputs.append(next_h)
            prev_h = next_h
        return torch.stack(outputs), next_h
在forward方法中,需要注意的是输入序列长度的设置,如果你旨在通过当前1小时的数据外推未来3小时的数据,需要根据手头的数据时次设置输入序列的长度。
比如:训练数据是全国拼图,10分钟一次,在encoder阶段设置输入序列长度是1*60/10 = 6,在decoder阶段设置输入序列长度是3*60/10 = 18
Encoder-Forecaster结构
Encoder和Forecaster是对称的结构,在Encoder阶段完成特征学习和提取,在Forecaster阶段完成图像重建和外推。
Encoder参数设置
# Encoder configuration: 3 downsampling conv stages paired with 3 TrajGRU
# layers. Each OrderedDict entry maps a layer name to
# [in_channels, out_channels, kernel_size, stride, padding].
# The ER** spatial sizes come from an external config file (values not
# visible here) — presumably the per-stage feature-map heights/widths.
encoder_params = [
    [
        OrderedDict({'conv1_leaky_1':[1,8,7,5,1]}),# inputchannel,outputchannel,ks,stride,padding
        OrderedDict({'conv2_leaky_1':[64,192,5,3,1]}),
        OrderedDict({'conv3_leaky_1':[192,192,3,2,1]}),
    ],
    [
        # stage 1 RNN: 8 input channels (from conv1), 64-channel hidden state
        TrajGRU(input_channel=8,num_filter=64,b_h_w=(batch_size,int(config['Sets']['ER11']),int(config['Sets']['ER12'])),zoneout=0.0,L=13,
                i2h_kernel=(3,3),i2h_stride=(1,1),i2h_pad=(1,1),
                h2h_kernel=(5,5),h2h_dilate=(1,1)),
        # stage 2 RNN
        TrajGRU(input_channel=192, num_filter=192, b_h_w=(batch_size,int(config['Sets']['ER21']),int(config['Sets']['ER22'])), zoneout=0.0, L=13,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(5, 5), h2h_dilate=(1, 1)),
        # stage 3 RNN: fewer links (L=9) and a smaller h2h kernel at the
        # coarsest resolution
        TrajGRU(input_channel=192, num_filter=192, b_h_w=(batch_size,int(config['Sets']['ER31']),int(config['Sets']['ER32'])), zoneout=0.0, L=9,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(3, 3), h2h_dilate=(1, 1))
    ]
]
encoder结构包含3个卷积层和3个TrajGRU层,通过卷积层进行downsample,通过TrajGRU进行特征提取和光流学习。
Forecaster参数设置
# Forecaster configuration: mirror image of the encoder — 3 TrajGRU layers
# paired with 3 upsampling (deconv) stages. Entry format is the same:
# [in_channels, out_channels, kernel_size, stride, padding].
forecaster_params = [
    [
        OrderedDict({'deconv1_leaky_1': [192, 192, 4, 2, 1]}),
        OrderedDict({'deconv2_leaky_1': [192, 64, 5, 3, 1]}),
        # final stage: upsample, then two plain convs to refine and reduce
        # to the single-channel output image
        OrderedDict({
            'deconv3_leaky_1': [64, 8, 7, 5, 1],
            'conv3_leaky_2': [8, 8, 3, 1, 1],
            'conv3_3': [8, 1, 1, 1, 0]
        }),
    ],
    [
        # RNNs listed coarsest-first; Forecaster.__init__ re-numbers them so
        # rnn3 is applied first (see the Forecaster class below)
        TrajGRU(input_channel=192, num_filter=192, b_h_w=(batch_size,int(config['Sets']['ER31']),int(config['Sets']['ER32'])), zoneout=0.0, L=13,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                ),
        TrajGRU(input_channel=192, num_filter=192, b_h_w=(batch_size,int(config['Sets']['ER21']),int(config['Sets']['ER22'])), zoneout=0.0, L=13,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                ),
        TrajGRU(input_channel=64, num_filter=64, b_h_w=(batch_size,int(config['Sets']['ER11']),int(config['Sets']['ER12'])), zoneout=0.0, L=9,
                i2h_kernel=(3, 3), i2h_stride=(1, 1), i2h_pad=(1, 1),
                h2h_kernel=(5, 5), h2h_dilate=(1, 1),
                )
    ]
]
Forecaster和Encoder是对称的结构,同样包含3个卷积层和3个TrajGRU层,通过反卷积层进行upsample,通过TrajGRU层进行外推。
Encoder结构
class Encoder(nn.Module):
    """Stacks (downsampling subnet, TrajGRU) pairs.

    Each stage first applies its conv subnet frame-by-frame, then runs the
    paired RNN over the resulting sequence; the final hidden state of every
    stage is collected and handed to the Forecaster.
    """

    def __init__(self, subnets, rnns):
        super().__init__()
        assert len(subnets) == len(rnns)
        self.blocks = len(subnets)
        # register stages/RNNs as attributes stage1..stageN / rnn1..rnnN so
        # nn.Module tracks their parameters
        for idx, (params, rnn) in enumerate(zip(subnets, rnns), 1):
            setattr(self, f'stage{idx}', make_layers(params))
            setattr(self, f'rnn{idx}', rnn)

    def forward_by_stage(self, input, subnet, rnn):
        """Apply `subnet` per frame, then `rnn` over the whole sequence.

        `input` is 5D (S, B, C, H, W); returns (outputs, last hidden state).
        """
        seq_len, batch, channels, height, width = input.size()
        # fold time into the batch dimension for the per-frame subnet
        frames = torch.reshape(input, (-1, channels, height, width))
        frames = subnet(frames)
        sequence = torch.reshape(
            frames,
            (seq_len, batch, frames.size(1), frames.size(2), frames.size(3)))
        return rnn(sequence, None)

    ## input: 5D T*B*C*H*W
    def forward(self, input):
        hidden_states = []
        logging.debug(input.size())
        for stage_idx in range(1, self.blocks + 1):
            stage = getattr(self, f'stage{stage_idx}')
            rnn = getattr(self, f'rnn{stage_idx}')
            input, state = self.forward_by_stage(input, stage, rnn)
            hidden_states.append(state)
        return tuple(hidden_states)
Forecaster结构
class Forecaster(nn.Module):
    """Mirror of the Encoder: (TrajGRU, upsampling subnet) pairs.

    Stages are registered in reverse order (stageN..stage1) so that the
    coarsest-resolution RNN, seeded with the encoder's deepest hidden state,
    runs first and each subnet upsamples toward the output resolution.
    """

    def __init__(self, subnets, rnns):
        super().__init__()
        assert len(subnets) == len(rnns)
        self.blocks = len(subnets)
        ## use transposeConv to enlarge outputs
        #self.transposeConv2d = nn.ConvTranspose2d(in_channels=1, out_channels=1,kernel_size=5, stride=(5,7),padding=1)
        for idx, (params, rnn) in enumerate(zip(subnets, rnns)):
            stage_no = self.blocks - idx
            setattr(self, 'rnn' + str(stage_no), rnn)
            setattr(self, 'stage' + str(stage_no), make_layers(params))
        #self.conv2d = nn.Conv2d(in_channels=1,out_channels=1,kernel_size=3,padding=1)

    def forward_by_stage(self, input, state, subnet, rnn):
        """Extrapolate with `rnn` from `state`, then upsample frame-by-frame."""
        # seq_len=20 fixes the number of forecast frames — mind this setting
        outputs, _ = rnn(input, state, seq_len=20)
        seq_len, batch, channels, height, width = outputs.size()
        frames = torch.reshape(outputs, (-1, channels, height, width))
        frames = subnet(frames)
        return torch.reshape(
            frames,
            (seq_len, batch, frames.size(1), frames.size(2), frames.size(3)))

    def forward_transpose(self, input):
        """Optional output post-processing; currently an identity reshape
        (the transposed-conv / conv layers are commented out)."""
        seq_len, batch, channels, height, width = input.size()
        frames = torch.reshape(input, (-1, channels, height, width))
        ## enlarge outputs
        ## use conv to delete noise
        #input = self.transposeConv2d(input)
        #input = self.conv2d(input)
        return torch.reshape(
            frames,
            (seq_len, batch, frames.size(1), frames.size(2), frames.size(3)))

    def forward(self, hidden_states):
        # deepest stage first, seeded by the encoder's last hidden state
        input = self.forward_by_stage(None, hidden_states[-1],
                                      getattr(self, 'stage3'),
                                      getattr(self, 'rnn3'))
        for i in reversed(range(1, self.blocks)):
            input = self.forward_by_stage(input, hidden_states[i - 1],
                                          getattr(self, 'stage' + str(i)),
                                          getattr(self, 'rnn' + str(i)))
        return self.forward_transpose(input)
Encoder-Forecaster
class EF(nn.Module):
    """Encoder-Forecaster wrapper: encodes the input sequence into per-stage
    hidden states, then extrapolates future frames from them."""

    def __init__(self, encoder, forecaster):
        super().__init__()
        self.encoder = encoder
        self.forecaster = forecaster

    def forward(self, input):
        return self.forecaster(self.encoder(input))
Training
def train(*args):
    """Train `model` on `dataset` with Adam + MSE loss, saving a checkpoint
    after every epoch.

    Positional args (in order): model, epochs, LR (initial learning rate),
    dataset. Relies on `data_prefetcher`, `logger`, and `lr_scheduler`
    defined elsewhere in the file, and requires a CUDA device.
    """
    ## use mse + mae loss ..
    model, epochs, LR, dataset = args
    criterion = nn.MSELoss()
    l1 = nn.L1Loss()  # kept for the commented-out MSE+MAE loss variant below
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # decay LR by 10x at epochs 30 and 60
    mult_step_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)
    # resume learning: uncomment to restore a saved checkpoint
    #model.load_state_dict(torch.load('MSEModels/ckpt-maemse-80-38.089400.pth'))
    model.train()
    model.cuda()
    for i in range(0, epochs):
        #print(f'=====training epoch {i+1}======')
        logger.info(f'=====training epoch {i+1}======')
        epoch_start = time.time()
        train_loss = 0
        # data_prefetcher presumably loads (x, y) batches asynchronously and
        # returns (None, None) when exhausted — defined elsewhere, confirm.
        prefetcher = data_prefetcher(dataset)
        x, y = prefetcher.next()
        iteration = 0
        while x is not None:
            iteration += 1
            x = x.cuda()
            y = y.cuda()
            optimizer.zero_grad()
            outputs = model(x)
            #loss = criterion(outputs,y) + l1(outputs,y)
            loss = criterion(outputs, y)
            iterationLoss = loss.item()
            train_loss += iterationLoss
            loss.backward()
            optimizer.step()
            #if iteration % 20 == 0:
            #logger.info(f'epoch {i+1} index {iteration} mse_mae_loss {iterationLoss}')
            logger.info(f'epoch {i+1} index {iteration} mse_loss {iterationLoss}')
            x, y = prefetcher.next()
        mult_step_scheduler.step()
        epoch_end = time.time()
        # NOTE(review): train_loss/iteration raises ZeroDivisionError if the
        # dataset yields no batches (iteration stays 0).
        #logger.info(f'epoch {i+1} final mse_mae_loss {round(train_loss/iteration,3)} epoch uses {(epoch_end-epoch_start)//60}minutes')
        logger.info(f'epoch {i+1} final mse {round(train_loss/iteration,3)} epoch uses {(epoch_end-epoch_start)//60}minutes')
        # checkpoint name embeds epoch number and mean epoch loss
        torch.save(model.state_dict(), 'checkPoints/ckpt-mse-%d-%f.pth'%(i+1, round(train_loss/iteration, 3)))
        #torch.save(model.module.state_dict(),'ckpt-%d-%f.pth'%(i+1,round(train_loss/500,4)))
可以通过设置不同的loss函数,比较外推的效果。在下一篇文章中会讨论一下模型性能评估指标以及外推效果展示。