前提已有数据:比如2个点云文件,点云文件有4点
需求数据:
1.将对个点云cat拼接到一个 [ 2,4,3 ]->[ 2*4,3 ]
2.每个点有一个索引:[0,0,0,0,1,1,1,1]
参考地址:
https://developer.moduyun.com/column/detail/82643.html
import torch
from torch_geometric.data import Data, Batch
def switchupdata(batch_input):
temparry = []
for i in range(batch_input.shape[0]):
onetorch = batch_input[i].squeeze() # 2维度的数据
temparry.append(Data(onetorch)) # 转成List
print(onetorch.shape)
switch = Batch().from_data_list(temparry)
x = switch.x
batch = switch.batch
edge_index = switch.edge_index
return x,batch,edge_index
if __name__ == '__main__':
# 准备一个 batch 点云数据(3维)
data_tsdim3 = torch.tensor([[[1, 1, 1], [2, 2, 2], [3, 3, 3],
[4, 4, 4], [5, 5, 5], [6, 6, 6]]],
dtype=torch.float32)
xdata_G,batch_G,edge_index = switchupdata(data_tsdim3)
print(xdata_G)
print(xdata_G.shape)
print(batch_G)
print(edge_index)