网络结构为：conv(3,3) + residual block（N 个 residual block 层）+ conv(3,3) + upsample block。
Residual block:
def resBlock(x, channels=64, kernel_size=(3, 3), scale=1):
    """Apply a residual block: conv -> ReLU -> conv, scale, then add the input back.

    Args:
        x: Input feature tensor. It is added element-wise to the branch
            output, so it presumably must already have ``channels`` channels
            (and matching spatial size) -- TODO confirm at call sites.
        channels: Number of output filters for both convolutions.
        kernel_size: Kernel size of both convolutions. An immutable tuple is
            used as the default so it can never be shared-and-mutated across
            calls; any length-2 sequence is accepted.
        scale: Multiplier applied to the residual branch before the skip
            addition (e.g. EDSR-style residual scaling; 1 leaves it unchanged).

    Returns:
        ``x + scale * branch(x)``.
    """
    # First conv is created without an activation; ReLU is applied explicitly.
    branch = slim.conv2d(x, channels, kernel_size, activation_fn=None)
    branch = tf.nn.relu(branch)
    # Second conv stays linear: no activation before the skip addition.
    branch = slim.conv2d(branch, channels, kernel_size, activation_fn=None)
    branch *= scale
    # Identity skip connection.
    return x + branch
Upsample block:
def upsample(x,scale=2,features=64,activation=tf.nn.relu):
assert scale in [2,3,4]
x = slim.conv2d(x,features,[3,3],activation_fn=activation)
if scale == 2:
ps_features = 3*(sca