Table of Contents
关于n维球坐标转换,网络上有公式,但是几乎没找到实现代码.在此将自己在keras下用lambda函数写球坐标转换层的实现过程记录下来,并且记下来一些中间遇到的坑.
n维球坐标转换公式
参考下图:
n维球坐标转换的代码实现
具体环境:在keras (以tensorflow为backend)中,使用lambda层调用自己写的坐标转换函数.这里注意,keras中除了使用lambda,也可以参考此教程编写你自己的层.如果该层没有要参与训练的参数,就可以用lambda来写,比较简单;反之,要参考教程中的写法,本文暂时不涉及.
坐标变换函数代码:
def sphere_trans(inputs):
# inputs, b = inputall
# print(inputs)
for t in range(inputs.shape[1]):
# 将距离r放在输出tensor的第一个位置,后续都是变换得到的角度
if t ==0:
va = tf.norm(inputs[:,], axis = 1, ord='euclidean', keep_dims = True)
else:
c1 = tf.norm(inputs[:,t:], axis = 1, ord='euclidean', keep_dims = True)
# c2 = inputs[:,t]
# c2 = tf.identity(inputs[:,t-1], name=None)
c2 = tf.expand_dims(last_itm, -1)
va = tf.divide(c1, c2, name=None)
va = tf.atan(va)
last_itm = inputs[:,t]
part1 = inputs[:,:t]
part2 = inputs[:,t + 1:]
# 因为tensorflow中不能对tensor进行指定位置元素进行修改,
# 所以本文麻烦一点通过contat函数实现这个功能
inputs = tf.concat([part1, va, part2], axis=1)
return inputs
lambda函数搭建层的代码:
m = Lambda(sphere_trans, name='sphere_trans')(a)
model = Model(inputs=[a], outputs=[m])
遇到问题
- tensorflow中对tensor操作有些麻烦,不能对特定位置元素进行修改,只能统一来替换.在此我们使用了concat函数来实现.
- tensorflow中编写层时,要获取输入tensor的某个位置或某一段元素,一定要在None维加上冒号,即"tf.norm(inputs[:,]"中的第一个冒号.不加的话,模型搭建过程可能不会出错,但是模型在测试时会报错.这一点小问题之前耽误了我很久.