keras下使用lambda搭建n维球坐标转换层

Table of Contents

n维球坐标转换公式

n维球坐标转换的代码实现

遇到问题


关于n维球坐标转换,网络上有公式,但是几乎没找到实现代码.在此将自己在keras下用lambda函数写球坐标转换层的实现过程记录下来,并且记下来一些中间遇到的坑.

n维球坐标转换公式

参考下图:

n维球坐标转换

n维球坐标转换的代码实现

具体环境:在keras (以tensorflow为backend)中,使用lambda层调用自己写的坐标转换函数.这里注意,keras中除了使用lambda,也可以参考此教程编写你自己的层.如果该层没有要参与训练的参数,就可以用lambda来写,比较简单;反之,要参考教程中的写法,本文暂时不涉及.

坐标变换函数代码:

def sphere_trans(inputs):
    # inputs, b = inputall
    # print(inputs)

    for t in range(inputs.shape[1]):
        # 将距离r放在输出tensor的第一个位置,后续都是变换得到的角度
        if t ==0:
            va = tf.norm(inputs[:,], axis = 1, ord='euclidean', keep_dims = True)
        else:
            c1 = tf.norm(inputs[:,t:], axis = 1, ord='euclidean', keep_dims = True)

            # c2 = inputs[:,t]
            # c2 = tf.identity(inputs[:,t-1], name=None)
            c2 = tf.expand_dims(last_itm, -1)
            va = tf.divide(c1, c2, name=None)
            va = tf.atan(va)

        last_itm = inputs[:,t]
        part1 = inputs[:,:t]
        part2 = inputs[:,t + 1:]
        
        # 因为tensorflow中不能对tensor进行指定位置元素进行修改,
        # 所以本文麻烦一点通过contat函数实现这个功能
        inputs = tf.concat([part1, va, part2], axis=1)
    return inputs

lambda函数搭建层的代码:

m = Lambda(sphere_trans, name='sphere_trans')(a)
model = Model(inputs=[a], outputs=[m])

遇到问题

  1. tensorflow中对tensor操作有些麻烦,不能对特定位置元素进行修改,只能统一来替换.在此我们使用了concat函数来实现.
  2. tensorflow中编写层时,要获取输入tensor的某个位置或某一段元素,一定要在None维加上冒号,即"tf.norm(inputs[:,]"中的第一个冒号.不加的话,模型搭建过程可能不会出错,但是模型在测试时会报错.这一点小问题之前耽误了我很久.

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值