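# Purpose: adapts three torch_geometric transforms (Center, Delaunay, GenerateMeshNormals) to paired data (PairData). Pass a PairData object to one of these classes to center (normalize) the coordinates of chains A and B, triangulate them, and compute their per-node normal vectors.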
class Center:
    r"""Translate both chains so that each point cloud's mean sits at the origin.

    ``data.pos_A`` and ``data.pos_B`` are each shifted by their own mean,
    taken over dimension ``-2`` (the point axis, assuming ``[N, D]``
    positions). The two chains are centred independently of each other.
    The attributes are rebound on ``data``, and the same object is returned.
    """

    def __call__(self, data):
        for attr in ('pos_A', 'pos_B'):
            points = getattr(data, attr)
            centroid = points.mean(dim=-2, keepdim=True)
            setattr(data, attr, points - centroid)
        return data

    def __repr__(self):
        return '{}()'.format(type(self).__name__)
class Delaunay(object):
    r"""Computes the Delaunay triangulation of each chain of a paired graph.

    For each chain suffix ``X`` in ``('A', 'B')``, the positions
    ``data.pos_X`` are triangulated independently:

    * fewer than 2 points: ``data.edge_index_X`` is an empty ``[2, 0]``
      long tensor;
    * exactly 2 points: ``data.edge_index_X`` connects the two points in
      both directions;
    * exactly 3 points: ``data.face_X`` is the single face
      ``[[0], [1], [2]]``;
    * more than 3 points: ``data.face_X`` holds the transposed
      ``simplices`` from :class:`scipy.spatial.Delaunay`, with one column
      per simplex. SciPy simplices have ``ndim + 1`` vertices, so 3-D
      positions yield 4 indices per column.

    ``data`` is modified in place and returned.
    """

    # Attribute suffixes of the two chains stored in the paired data object.
    _CHAINS = ('A', 'B')

    def __call__(self, data):
        for chain in self._CHAINS:
            edge_index, face = self._triangulate(getattr(data, 'pos_' + chain))
            # Per-chain attribute names, so chain B can never overwrite
            # chain A's result.
            if edge_index is not None:
                setattr(data, 'edge_index_' + chain, edge_index)
            if face is not None:
                setattr(data, 'face_' + chain, face)
        return data

    @staticmethod
    def _triangulate(pos):
        """Return ``(edge_index, face)`` for one point set; the unused one is None."""
        num_nodes = pos.size(0)
        if num_nodes < 2:
            return torch.empty((2, 0), dtype=torch.long, device=pos.device), None
        if num_nodes == 2:
            edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long,
                                      device=pos.device)
            return edge_index, None
        if num_nodes == 3:
            face = torch.tensor([[0], [1], [2]], dtype=torch.long,
                                device=pos.device)
            return None, face
        # 'QJ' joggles the input so that Qhull always produces simplices,
        # even for degenerate (e.g. coplanar) inputs.
        tri = scipy.spatial.Delaunay(pos.cpu().numpy(), qhull_options='QJ')
        face = torch.from_numpy(tri.simplices).t().contiguous()
        return None, face.to(pos.device, torch.long)

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
class GenerateMeshNormals(object):
    r"""Generate per-node normal vectors for both chains of a paired mesh.

    For each chain suffix ``X`` in ``('A', 'B')``:

    * every face in ``data.face_X`` gets a unit normal from the cross
      product of two of its edges (only the first three rows of ``face_X``
      are used);
    * each node of ``data.pos_X`` then receives the re-normalised sum of
      the normals of the faces it belongs to;
    * the result is stored as ``data.norm_X`` (``[N, 3]``).

    Nodes that belong to no face get a zero vector. ``data`` is modified in
    place and returned.

    Raises:
        AssertionError: if ``face_A`` or ``face_B`` is missing from ``data``.
    """

    # Attribute suffixes of the two chains stored in the paired data object.
    _CHAINS = ('A', 'B')

    def __call__(self, data):
        for chain in self._CHAINS:
            face_key = 'face_' + chain
            assert face_key in data, '{} is required'.format(face_key)
            pos = getattr(data, 'pos_' + chain)
            face = getattr(data, face_key)
            setattr(data, 'norm_' + chain, self._node_normals(pos, face))
        return data

    @staticmethod
    def _node_normals(pos, face):
        """Return unit per-node normals ``[N, 3]`` summed over incident faces."""
        vec1 = pos[face[1]] - pos[face[0]]
        vec2 = pos[face[2]] - pos[face[0]]
        # Explicit dim: without it, torch.cross uses the first size-3 dim,
        # which is the face dim whenever there are exactly 3 faces.
        face_norm = F.normalize(torch.cross(vec1, vec2, dim=-1), p=2, dim=-1)  # [F, 3]
        idx = torch.cat([face[0], face[1], face[2]], dim=0)
        face_norm = face_norm.repeat(3, 1)
        # Accumulate each face normal into its three vertices.
        norm = face_norm.new_zeros((pos.size(0), face_norm.size(-1)))
        norm.index_add_(0, idx, face_norm)
        return F.normalize(norm, p=2, dim=-1)  # [N, 3]

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)