nlf_train/NeuralLocalizerFields/nlf/tf/tfu.py
# NHWC <-> NCHW conversions
def nhwc_to_nchw(x):
    """Reorder ``x`` from channels-last (NHWC) to channels-first (NCHW) order.

    ``x`` may be either data or a description of a shape:

    * a ``tf.Tensor`` or NumPy array of rank 3 or 4: the axes are transposed
      (``[2, 0, 1]`` for rank 3, ``[0, 3, 1, 2]`` for rank 4);
    * a list, tuple or 1-D array with 3 or 4 entries: the entries themselves
      are reordered with the same permutation, and a container of the same
      type is returned.

    Rank-2 arrays/tensors and 2-entry sequences are returned unchanged (the
    very same object).

    Raises:
        ValueError: if the rank (or the number of entries) is not 2, 3 or 4.
    """
    # Permutation per rank / per entry count; None means "leave untouched".
    perms = {2: None, 3: [2, 0, 1], 4: [0, 3, 1, 2]}

    def _perm_for(n, what):
        if n not in perms:
            raise ValueError(
                f'nhwc_to_nchw supports only {what} 2, 3 or 4, got {n}')
        return perms[n]

    # Plain Python containers and NumPy arrays are checked first; they can
    # never be tf.Tensor instances, so the result does not depend on the order.
    if isinstance(x, (list, tuple)):
        perm = _perm_for(len(x), 'length')
        return x if perm is None else type(x)(tuple(x[i] for i in perm))

    if isinstance(x, np.ndarray):
        if x.ndim == 1:
            perm = _perm_for(len(x), 'length')
            # Index with the permutation instead of calling type(x)(...):
            # np.ndarray(values) would treat the values as a *shape* and
            # allocate an uninitialised array of that shape.
            return x if perm is None else x[perm]
        perm = _perm_for(x.ndim, 'rank')
        return x if perm is None else np.transpose(x, perm)

    if isinstance(x, tf.Tensor):
        # A statically unknown rank (None) is rejected like any other
        # unsupported rank.
        perm = _perm_for(x.shape.rank, 'rank')
        return x if perm is None else tf.transpose(x, perm)

    # Any other object exposing ``ndim``: 1-D objects are rebuilt through their
    # own type, everything else is transposed with NumPy.
    if x.ndim == 1:
        perm = _perm_for(len(x), 'length')
        return x if perm is None else type(x)(tuple(x[i] for i in perm))
    perm = _perm_for(x.ndim, 'rank')
    return x if perm is None else np.transpose(x, perm)
def nchw_to_nhwc(x):
    """Reorder ``x`` from channels-first (NCHW) to channels-last (NHWC) order.

    ``x`` may be either data or a description of a shape:

    * a ``tf.Tensor`` or NumPy array of rank 3 or 4: the axes are transposed
      (``[1, 2, 0]`` for rank 3, ``[0, 2, 3, 1]`` for rank 4);
    * a list, tuple or 1-D array with 3 or 4 entries: the entries themselves
      are reordered with the same permutation, and a container of the same
      type is returned.

    Rank-2 arrays/tensors and 2-entry sequences are returned unchanged (the
    very same object).

    Raises:
        ValueError: if the rank (or the number of entries) is not 2, 3 or 4.
    """
    # Permutation per rank / per entry count; None means "leave untouched".
    perms = {2: None, 3: [1, 2, 0], 4: [0, 2, 3, 1]}

    def _perm_for(n, what):
        if n not in perms:
            raise ValueError(
                f'nchw_to_nhwc supports only {what} 2, 3 or 4, got {n}')
        return perms[n]

    # Plain Python containers and NumPy arrays are checked first; they can
    # never be tf.Tensor instances, so the result does not depend on the order.
    if isinstance(x, (list, tuple)):
        perm = _perm_for(len(x), 'length')
        return x if perm is None else type(x)(tuple(x[i] for i in perm))

    if isinstance(x, np.ndarray):
        if x.ndim == 1:
            perm = _perm_for(len(x), 'length')
            # Index with the permutation instead of calling type(x)(...):
            # np.ndarray(values) would treat the values as a *shape* and
            # allocate an uninitialised array of that shape.
            return x if perm is None else x[perm]
        perm = _perm_for(x.ndim, 'rank')
        return x if perm is None else np.transpose(x, perm)

    if isinstance(x, tf.Tensor):
        # A statically unknown rank (None) is rejected like any other
        # unsupported rank.
        perm = _perm_for(x.shape.rank, 'rank')
        return x if perm is None else tf.transpose(x, perm)

    # Any other object exposing ``ndim``: 1-D objects are rebuilt through their
    # own type, everything else is transposed with NumPy.
    if x.ndim == 1:
        perm = _perm_for(len(x), 'length')
        return x if perm is None else type(x)(tuple(x[i] for i in perm))
    perm = _perm_for(x.ndim, 'rank')
    return x if perm is None else np.transpose(x, perm)