pairdata求法向量

作用:改写了三个类,为的是在用geometric里面的pairdata的时候,使用一个类,传入pairdata,求出pairdata中的A和B的法向量啊,以及对坐标归一化啊之类的,

class Center(object):
    r"""Centers node positions around the origin."""

    def __call__(self, data):
        data.pos_A = data.pos_A - data.pos_A.mean(dim=-2, keepdim=True)
        data.pos_B = data.pos_B - data.pos_B.mean(dim=-2, keepdim=True)
        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

class Delaunay(object):
    r"""Computes the delaunay triangulation of a set of points."""
    def __call__(self, data):
        if data.pos_A.size(0) < 2:
            data.edge_index = torch.tensor([], dtype=torch.long,
                                           device=data.pos_A.device).view(2, 0)
        if data.pos_A.size(0) == 2:
            data.edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long,
                                           device=data.pos_A.device)
        elif data.pos_A.size(0) == 3:
            data.face_A = torch.tensor([[0], [1], [2]], dtype=torch.long,
                                     device=data.pos_A.device)
        if data.pos_A.size(0) > 3:
            pos_A = data.pos_A.cpu().numpy()
            tri = scipy.spatial.Delaunay(pos_A, qhull_options='QJ')
            face_A = torch.from_numpy(tri.simplices)

            data.face_A = face_A.t().contiguous().to(data.pos_A.device, torch.long)
        #----------chain B------------------------------
        if data.pos_B.size(0) < 2:
            data.edge_index = torch.tensor([], dtype=torch.long,
                                           device=data.pos_B.device).view(2, 0)
        if data.pos_B.size(0) == 2:
            data.edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long,
                                           device=data.pos_B.device)
        elif data.pos_B.size(0) == 3:
            data.face_B = torch.tensor([[0], [1], [2]], dtype=torch.long,
                                     device=data.pos_B.device)
        if data.pos_B.size(0) > 3:
            pos_B = data.pos_B.cpu().numpy()
            tri = scipy.spatial.Delaunay(pos_B, qhull_options='QJ')
            face_B = torch.from_numpy(tri.simplices)

            data.face_B = face_B.t().contiguous().to(data.pos_B.device, torch.long)

        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

class GenerateMeshNormals(object):
    r"""Generate normal vectors for each mesh node based on neighboring
    faces."""

    def __call__(self, data):
        assert 'face_A' in data
        pos_A, face_A = data.pos_A, data.face_A

        vec1 = pos_A[face_A[1]] - pos_A[face_A[0]]
        vec2 = pos_A[face_A[2]] - pos_A[face_A[0]]
        face_norm_A = F.normalize(vec1.cross(vec2), p=2, dim=-1)  # [F, 3]

        idx = torch.cat([face_A[0], face_A[1], face_A[2]], dim=0)
        face_norm_A = face_norm_A.repeat(3, 1)

        norm_A = scatter_add(face_norm_A, idx, dim=0, dim_size=pos_A.size(0))
        norm_A = F.normalize(norm_A, p=2, dim=-1)  # [N, 3]

        data.norm_A = norm_A

        #----------------------chain B----------------------
        assert 'face_B' in data
        pos_B, face_B = data.pos_B, data.face_B

        vec1 = pos_B[face_B[1]] - pos_B[face_B[0]]
        vec2 = pos_B[face_B[2]] - pos_B[face_B[0]]
        face_norm_B = F.normalize(vec1.cross(vec2), p=2, dim=-1)  # [F, 3]

        idx = torch.cat([face_B[0], face_B[1], face_B[2]], dim=0)
        face_norm_B = face_norm_B.repeat(3, 1)

        norm_B = scatter_add(face_norm_B, idx, dim=0, dim_size=pos_B.size(0))
        norm_B = F.normalize(norm_B, p=2, dim=-1)  # [N, 3]

        data.norm_B = norm_B

        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

这个人很懒,还没有设置昵称...

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值