C++如何切3D矩阵、拼接3D矩阵
{
    // Body of the patch-wise mask inference routine (the function signature is
    // not part of this note). It cuts the 3D volume mpf_mumap into cubes of
    // i_patch_x * i_patch_y * i_patch_z voxels, runs m_module on each cube and
    // stitches the per-cube outputs back into mpf_mask at the same position.
    //
    // Linear layout used for the volume and for every patch (kept from the
    // original code): voxel (x, y, z) of an (nx, ny, nz) grid is stored at
    //     x * ny + y + nx * ny * z
    // i.e. y varies fastest and z slowest. The original used nx as the row
    // stride, which is only correct for square slices (nx == ny).
    // NOTE(review): confirm this matches how mpf_mumap / mpf_mask are stored
    // and the axis order the network was trained with.
    //
    // Returns false if the volume is not loaded, is smaller than one patch, or
    // the network output has an unexpected size; true once mpf_mask is filled.
    if (mpf_mumap == NULL) {
        std::cout << "Mumap is not loaded!" << std::endl;
        return false;
    }

    const int i_x = shapes[0];
    const int i_y = shapes[1];
    const int i_z = shapes[2];
    const int i_patch_x = 64;
    const int i_patch_y = 64;
    const int i_patch_z = 64;
    if (i_x < i_patch_x || i_y < i_patch_y || i_z < i_patch_z) {
        std::cout << "Volume is smaller than one patch!" << std::endl;
        return false;
    }

    // Ceil division so a trailing partial patch is processed instead of its
    // voxels in mpf_mask being silently left untouched.
    const int i_patch_num_x = (i_x + i_patch_x - 1) / i_patch_x;
    const int i_patch_num_y = (i_y + i_patch_y - 1) / i_patch_y;
    const int i_patch_num_z = (i_z + i_patch_z - 1) / i_patch_z;

    // 64-bit strides: large volumes overflow int indices.
    const int64_t i_slice_size = static_cast<int64_t>(i_x) * i_y;
    const int64_t i_patch_slice_size = static_cast<int64_t>(i_patch_x) * i_patch_y;
    const int64_t i_patch_size = i_patch_slice_size * i_patch_z;

    // Start of the idx-th patch along one axis. The last patch is shifted back
    // so it ends exactly at the border; it then overlaps its neighbour and the
    // overlapping voxels keep the prediction of the later patch.
    auto patch_start = [](int idx, int patch_len, int extent) {
        const int start = idx * patch_len;
        return start + patch_len <= extent ? start : extent - patch_len;
    };

    // Calls visit(volume_index, patch_index) for every voxel of the patch that
    // begins at (i_patch_start_x, i_patch_start_y, i_patch_start_z). The local
    // index is the global coordinate minus the patch start on each axis,
    // linearized with the patch's own extents. Loops run z, x, y (outer to
    // inner) so both buffers are walked in storage order.
    auto for_each_patch_voxel = [&](int i_patch_start_x, int i_patch_start_y,
                                    int i_patch_start_z, auto &&visit) {
        for (int kk = i_patch_start_z; kk < i_patch_start_z + i_patch_z; kk++) {
            for (int ii = i_patch_start_x; ii < i_patch_start_x + i_patch_x; ii++) {
                for (int jj = i_patch_start_y; jj < i_patch_start_y + i_patch_y; jj++) {
                    const int64_t i_mumap_idx = ii * static_cast<int64_t>(i_y) + jj
                                                + i_slice_size * kk;
                    const int64_t i_patch_idx = (ii - i_patch_start_x) * static_cast<int64_t>(i_patch_y)
                                                + (jj - i_patch_start_y)
                                                + i_patch_slice_size * (kk - i_patch_start_z);
                    visit(i_mumap_idx, i_patch_idx);
                }
            }
        }
    };

    // Inference only: do not record autograd history.
    torch::NoGradGuard no_grad;
    // One input tensor is reused for every patch and filled in place, so no
    // separate staging buffers (previously leaked VectorAlloc memory) are needed.
    at::Tensor tensor_input = torch::empty({1, 1, i_patch_x, i_patch_y, i_patch_z}, torch::kFloat32);
    float *mpf_patch_mumap = tensor_input.data_ptr<float>();

    for (int i = 0; i < i_patch_num_x; i++) {
        const int i_patch_start_x = patch_start(i, i_patch_x, i_x);
        for (int j = 0; j < i_patch_num_y; j++) {
            const int i_patch_start_y = patch_start(j, i_patch_y, i_y);
            for (int k = 0; k < i_patch_num_z; k++) {
                const int i_patch_start_z = patch_start(k, i_patch_z, i_z);

                // Cut: copy this patch of the volume into the network input.
                for_each_patch_voxel(i_patch_start_x, i_patch_start_y, i_patch_start_z,
                                     [&](int64_t i_mumap_idx, int64_t i_patch_idx) {
                                         mpf_patch_mumap[i_patch_idx] = mpf_mumap[i_mumap_idx];
                                     });

                at::Tensor tensor_output = m_module.forward({tensor_input}).toTensor();
                // Make the result contiguous CPU float32 so it can be read with
                // the same flat patch layout as the input.
                tensor_output = tensor_output.to(torch::kCPU, torch::kFloat32).contiguous();
                if (tensor_output.numel() != i_patch_size) {
                    std::cout << "Unexpected network output size!" << std::endl;
                    return false;
                }
                const float *mpf_patch_mask = tensor_output.data_ptr<float>();

                // Stitch: write the patch prediction back into the full mask.
                for_each_patch_voxel(i_patch_start_x, i_patch_start_y, i_patch_start_z,
                                     [&](int64_t i_mask_idx, int64_t i_patch_idx) {
                                         mpf_mask[i_mask_idx] = mpf_patch_mask[i_patch_idx];
                                     });
            }
        }
    }
    return true;
}
最核心的是下面这两行下标计算:
// Row stride is i_y (y varies fastest, z slowest); it equals i_x only for square slices.
int i_mumap_idx = ii * i_y + jj + i_x * i_y * kk;
// Same layout inside the patch: subtract the patch start on every axis, then use the patch's own extents.
int i_patch_idx = (ii - i_patch_start_x) * i_patch_y + (jj - i_patch_start_y)
        + i_patch_x * i_patch_y * (kk - i_patch_start_z);
mpf_patch_mumap[i_patch_idx] = mpf_mumap[i_mumap_idx];
对于每个小 patch,它内部的局部下标就是全局坐标在各个轴上分别减去该 patch 的起点(i_patch_start_x、i_patch_start_y、i_patch_start_z),再按 patch 自身的尺寸重新线性化。切分和拼接共用同一套下标映射,这一点非常巧妙。