Remesh: remeshing and mesh smoothing

Contents

continuous-remeshing

QRemeshify:

Blender-CWF-Remesher

LocalRemesher

PSHuman includes a remeshing algorithm


continuous-remeshing

https://github.com/Profactor/continuous-remeshing
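Typical usage is a render-compare-remesh loop; the sketch below follows the same MeshOptimizer / calc_vertex_normals pattern used in the reconstruct.py listing later in this post. The renderer and target_normals here are placeholders standing in for whatever supervision you render against, and the icosphere is just a convenient starting mesh:

import torch
import trimesh
from core.opt import MeshOptimizer            # from the continuous-remeshing repo
from core.remesh import calc_vertex_normals

sphere = trimesh.creation.icosphere(subdivisions=2)   # start from a unit sphere
vertices = torch.tensor(sphere.vertices, dtype=torch.float32, device='cuda')
faces = torch.tensor(sphere.faces, dtype=torch.long, device='cuda')

opt = MeshOptimizer(vertices, faces, edge_len_lims=[0.01, 0.1])
vertices, faces = opt.vertices, opt.faces
for step in range(500):
    opt.zero_grad()
    normals = calc_vertex_normals(vertices, faces)                # per-vertex normals
    rendered = renderer.render(vertices, faces, normals=normals)  # placeholder multi-view renderer
    loss = (rendered[..., :3] - target_normals).abs().mean()      # L1 against target normal maps
    loss.backward()
    opt.step()                                                    # gradient step on vertex positions
    vertices, faces = opt.remesh()                                # adaptive edge split/collapse/flip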

QRemeshify:

GitHub - ksami/QRemeshify: A Blender extension for an easy-to-use remesher that outputs good-quality quad topology

Blender-CWF-Remesher

https://github.com/AIGODLIKE/Blender-CWF-Remesher

LocalRemesher

GitHub - negdo/LocalRemesher: Blender addon for remeshing selected area with awareness of surrounding geometry.

PSHuman includes a remeshing algorithm

reconstruct.py

import argparse
import os
import pickle
import zipfile
from glob import glob

import kornia
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import trimesh
import math
import pytorch3d.loss  # submodule import needed for pytorch3d.loss.chamfer_distance below
import cv2
import copy
from icecream import ic
from omegaconf import OmegaConf
from PIL import Image
from scipy.spatial import KDTree
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm

from core.opt import MeshOptimizer
from core.remesh import calc_vertex_normals
from lib.dataset.mesh_util import apply_vertex_mask, keep_largest, part_removal, poisson
from utils.func import make_round_views, make_sparse_camera
from utils.mesh_utils import rot6d_to_rotmat, tensor2variable, to_py3d_mesh
from utils.project_mesh import get_cameras_list, multiview_color_projection, project_color  # NOQA
from utils.render import NormalsRenderer
from utils.smpl_util import SMPLX
from utils.video_utils import write_video
from instant_texture import Converter, convert_vertex_color_to_texture, build_lama_model, lama_model_path, lama_config_path  # NOQA
from avatarizer import avatarizer
from econdataset import SMPLDataset
from pytorch3d.transforms import euler_angles_to_matrix, axis_angle_to_matrix

bg_color = np.array([1, 1, 1])


def split_mv(mv):
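    # pull out the rotations/translations of views 1 and 5 (the +45° and -45° cameras), the only ones being optimized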
    rot = mv[:, :3, :3]
    trans = mv[:, :3, 3:4]
    rot = rot[(1, 5), :, :].clone()
    trans = trans[(1, 5), :, :].clone()
    return rot, trans


def merge_mv(rot6d, trans, raw_mv):
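    # rebuild the 6-view model-view stack, substituting the optimized matrices back into slots 1 and 5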
    rot = rot6d
    # rot = rot6d_to_rotmat(rot6d)
    mv34 = torch.concat((rot, trans), dim=-1)
    additional_row = torch.tensor([[0, 0, 0, 1]], dtype=mv34.dtype, device=rot6d.device, requires_grad=False)
    additional_row = additional_row.unsqueeze(0).expand(mv34.shape[0], -1, -1)
    mv = torch.cat((mv34, additional_row), dim=1)
    mv = torch.stack((raw_mv[0], mv[0], raw_mv[2], raw_mv[3], raw_mv[4], mv[1]), dim=0)
    return mv


def angles_to_rotation_y(angles_degrees, device='cuda:0'):
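    # build rotation matrices about the y-axis from angles given in degrees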
    angles_rad = torch.tensor(angles_degrees, device=device) * math.pi / 180.0

    cos_theta = torch.cos(angles_rad)
    sin_theta = torch.sin(angles_rad)

    zeros = torch.zeros_like(cos_theta)
    ones = torch.ones_like(cos_theta)

    rotation_matrices = torch.stack([
        torch.stack([cos_theta, zeros, -sin_theta], dim=1),
        torch.stack([zeros, ones, zeros], dim=1),
        torch.stack([sin_theta, zeros, cos_theta], dim=1),
    ], dim=1)

    return rotation_matrices


def rotation_to_angles_y(rotation_matrices, device='cuda:0'):
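    # recover the y-axis rotation angle in degrees from rotation matrices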
    cos_theta = rotation_matrices[:, 0, 0]
    sin_theta = rotation_matrices[:, 2, 0]
    angles_rad = torch.atan2(sin_theta, cos_theta)

    angles_degrees = angles_rad * 180.0 / math.pi

    # keep the angles within [-180, 180]
    angles_degrees = torch.where(angles_degrees > 180, angles_degrees - 360, angles_degrees)
    angles_degrees = torch.where(angles_degrees < -180, angles_degrees + 360, angles_degrees)

    return angles_degrees


def remove_long_edge(mesh, opath, depth = 10):
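    # treat the mesh vertices as a point cloud and rebuild the surface with Open3D Poisson reconstruction, discarding long spurious edges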
    import open3d as o3d
    pcd_path = opath.replace(".obj", ".ply")
    assert (mesh.vertex_normals.shape[1] == 3)
    mesh.export(pcd_path)
    pcl = o3d.io.read_point_cloud(pcd_path)

    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error) as cm:
        mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
            pcl, depth=depth, n_threads=6
        )

    mesh2 = trimesh.Trimesh(np.array(mesh.vertices), np.array(mesh.triangles))
    if os.path.isfile(pcd_path):
        os.unlink(pcd_path)

    return mesh2



class colorModel(nn.Module):
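    # holds per-vertex colors as an nn.Parameter so they can be optimized against multi-view target renders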
    def __init__(self, renderer, v, f, c):
        super().__init__()
        self.renderer = renderer
        self.v = v
        self.f = f
        self.colors = nn.Parameter(c, requires_grad=True)
        self.bg_color = torch.from_numpy(bg_color).float().to(self.colors.device)

    def forward(self, return_mask=False):
        rgba = self.renderer.render(self.v, self.f, colors=self.colors)
        if return_mask:
            return rgba
        else:
            mask = rgba[..., 3:]
            return rgba[..., :3] * mask + self.bg_color * (1 - mask)


def scale_mesh(vert):
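    # center the vertices on their bounding box and return the scale that fits them inside a unit sphere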
    min_bbox, max_bbox = vert.min(0)[0], vert.max(0)[0]
    center = (min_bbox + max_bbox) / 2
    offset = -center
    vert = vert + offset

    max_dist = torch.max(torch.sqrt(torch.sum(vert**2, dim=1)))
    scale = 1.0 / max_dist
    return scale, offset


def save_mesh(save_name, vertices, faces, color=None):
    v = vertices.detach().cpu().numpy().squeeze()
    f = faces.detach().cpu().numpy().squeeze()
    vertex_colors = (color.detach().cpu().numpy() * 255).astype(np.uint8) if color is not None else None
    trimesh.Trimesh(v, f, vertex_colors=vertex_colors, process=False, maintain_order=True).export(save_name)


class ReMesh:
    def __init__(self, opt, econ_dataset: SMPLDataset):
        print(f"ReMesh init {opt}")
        self.opt = opt
        self.device = torch.device(f"cuda:{opt.gpu_id}" if torch.cuda.is_available() else "cpu")
        self.num_view = opt.num_view

        self.res_path = opt.res_path
        os.makedirs(self.res_path, exist_ok=True)
        self.resolution = opt.resolution
        self.views = ['front_face', 'front_right', 'right', 'back', 'left', 'front_left']
        self.weights = torch.Tensor([1.0, 0.1, 0.8, 1.0, 0.8, 0.1]).view(6, 1, 1, 1).to(self.device)

        self.renderer: NormalsRenderer = self.prepare_render()
        # pose prediction
        self.econ_dataset = econ_dataset
        self.smplx_face = torch.Tensor(econ_dataset.faces.astype(np.int64)).long().to(self.device)
        self.lama_model = build_lama_model(lama_config_path, lama_model_path, device=self.device)

    def prepare_render(self):
        ### ------------------- prepare camera and renderer----------------------
        # make_sparse_camera angles are 0, 45, 90, 180, -90, -45; angles_to_rotation_y generates the same matrices (verified on test data)
        mv, proj = make_sparse_camera(self.opt.cam_path, self.opt.scale, views=[0, 1, 2, 4, 6, 7], device=self.device)  # orthographic camera by default; mv[0, :3, :3] is a rotation matrix that gets optimized
        renderer = NormalsRenderer(mv, proj, [self.resolution, self.resolution], device=self.device)
        return renderer

    def proj_texture(self, fused_images, vertices, faces):
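        # bake the fused multi-view images onto the mesh as per-vertex colors via weighted projection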
        mesh = to_py3d_mesh(vertices, faces)
        mesh = mesh.to(self.device)
        camera_focal = 1 / 2
        cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal)
        mesh = multiview_color_projection(
            mesh,
            fused_images,
            camera_focal=camera_focal,
            resolution=self.resolution,
            weights=self.weights.squeeze().cpu().numpy(),
            device=self.device,
            complete_unseen=True,
            confidence_threshold=0.2,
            cameras_list=cameras_list,
        )
        return mesh

    def get_invisible_idx(self, imgs, vertices, faces):
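        # project each view's colors onto the vertices; vertices whose accumulated view weight stays below 1 are treated as invisible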
        mesh = to_py3d_mesh(vertices, faces)
        mesh = mesh.to(self.device)
        camera_focal = 1 / 2
        if self.num_view == 6:
            cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal)
        elif self.num_view == 4:
            cameras_list = get_cameras_list([0, 90, 180, 270], device=self.device, focal=camera_focal)
        valid_vert_id = []
        vertices_colors = torch.zeros((vertices.shape[0], 3)).float().to(self.device)
        valid_cnt = torch.zeros((vertices.shape[0])).to(self.device)
        for cam, img, weight in zip(cameras_list, imgs, self.weights.squeeze()):
            ret = project_color(mesh, cam, img, eps=0.01, resolution=self.resolution, device=self.device)
            # print(ret['valid_colors'].shape)
            valid_cnt[ret['valid_verts']] += weight
            vertices_colors[ret['valid_verts']] += ret['valid_colors'] * weight
        valid_mask = valid_cnt > 1
        invalid_mask = valid_cnt < 1
        vertices_colors[valid_mask] /= valid_cnt[valid_mask][:, None]

        # visibility
        invisible_vert = valid_cnt < 1
        invisible_vert_indices = torch.nonzero(invisible_vert).squeeze()
        # vertices_colors[invalid_vert] = torch.tensor([1.0, 0.0, 0.0]).float().to("cuda")
        return vertices_colors, invisible_vert_indices

    def inpaint_missed_colors(self, all_vertices, all_colors, missing_indices):
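        # fill in colors of invisible vertices from their nearest colored neighbor via a KD-tree lookup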
        all_vertices = all_vertices.detach().cpu().numpy()
        all_colors = all_colors.detach().cpu().numpy()
        missing_indices = missing_indices.detach().cpu().numpy()

        non_missing_indices = np.setdiff1d(np.arange(len(all_vertices)), missing_indices)

        kdtree = KDTree(all_vertices[non_missing_indices])

        for missing_index in missing_indices:
            missing_vertex = all_vertices[missing_index]

            _, nearest_index = kdtree.query(missing_vertex.reshape(1, -1))

            interpolated_color = all_colors[non_missing_indices[nearest_index]]

            all_colors[missing_index] = interpolated_color

        return torch.from_numpy(all_colors).to(self.device)

    def load_training_data(self, case):
        ###------------------ load target images -------------------------------
        kernel = torch.ones(3, 3)
        erode_iters = 2
        normals = []
        masks = []
        colors = []
        for idx, view in enumerate(self.views):
            # for idx  in [0,2,3,4]:
            normal = Image.open(f'{self.opt.mv_path}/{case}/normals_{view}_masked.png')
            # normal = Image.open(f'{data_path}/{case}/normals/{idx:02d}_rgba.png')
            normal = normal.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR)
            normal = np.array(normal).astype(np.float32) / 255.0
            mask = normal[..., 3:]  # alpha
            mask_torch = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)  # (1, 1, H, W) layout for kornia
            for _ in range(erode_iters):
                mask_torch = kornia.morphology.erosion(mask_torch, kernel)
            mask_erode = mask_torch.squeeze(0).permute(1, 2, 0).numpy()
            masks.append(mask_erode)
            normal = normal[..., :3] * mask_erode
            normals.append(normal)

            color = Image.open(f'{self.opt.mv_path}/{case}/color_{view}_masked.png')
            color = color.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR)
            color = np.array(color).astype(np.float32) / 255.0
            color_mask = color[..., 3:]  # alpha
            # color_dilate = color[..., :3] * color_mask  + bg_color * (1 - color_mask)
            color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode)
            colors.append(color_dilate)

        masks = np.stack(masks, 0)
        masks = torch.from_numpy(masks).to(self.device)
        normals = np.stack(normals, 0)
        target_normals = torch.from_numpy(normals).to(self.device)
        colors = np.stack(colors, 0)
        target_colors = torch.from_numpy(colors).to(self.device)
        return masks, target_colors, target_normals

    def preprocess(self, color_pils, normal_pils):
        ###------------------ load target images -------------------------------
        kernel = torch.ones(3, 3)
        erode_iters = 2
        normals = []
        masks = []
        colors = []
        for normal, color in zip(normal_pils[:6], color_pils[:6]):
            normal = normal.resize((self.resolution, self.resolution), Image.BILINEAR)
            normal = np.array(normal).astype(np.float32) / 255.0
            mask = normal[..., 3:]  # alpha
            mask_torch = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)  # (1, 1, H, W) layout for kornia
            for _ in range(erode_iters):
                mask_torch = kornia.morphology.erosion(mask_torch, kernel)
            mask_erode = mask_torch.squeeze(0).permute(1, 2, 0).numpy()
            masks.append(mask_erode)
            normal = normal[..., :3] * mask_erode
            normals.append(normal)

            color = color.resize((self.resolution, self.resolution), Image.BILINEAR)
            color = np.array(color).astype(np.float32) / 255.0
            color_mask = color[..., 3:]  # alpha
            # color_dilate = color[..., :3] * color_mask  + bg_color * (1 - color_mask)
            color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode)
            colors.append(color_dilate)

        masks = np.stack(masks, 0)
        masks = torch.from_numpy(masks).to(self.device)
        normals = np.stack(normals, 0)
        target_normals = torch.from_numpy(normals).to(self.device)
        colors = np.stack(colors, 0)
        target_colors = torch.from_numpy(colors).to(self.device)
        return masks, target_colors, target_normals

    def optimize_case(self, case, pose, clr_img, nrm_img, opti_texture=True):
        case_path = self.res_path
        os.makedirs(case_path, exist_ok=True)
        self.renderer: NormalsRenderer = self.prepare_render()  # re-create the renderer

        if clr_img is not None:
            masks, target_colors, target_normals = self.preprocess(clr_img, nrm_img)  # stacked: 6 x 1024 x 1024 x 1; 6 x 1024 x 1024 x 3; 6 x 1024 x 1024 x 3
        else:
            masks, target_colors, target_normals = self.load_training_data(case)

        torchvision.utils.save_image(target_normals.permute(0, 3, 1, 2), os.path.join(case_path, "optimize_case_input_target_normals.jpg"))
        torchvision.utils.save_image(target_colors.permute(0, 3, 1, 2), os.path.join(case_path, "optimize_case_input_target_colors.jpg"))
        torchvision.utils.save_image(masks.permute(0, 3, 1, 2), os.path.join(case_path, "optimize_case_input_masks.jpg"))

        # rotation
        rz = R.from_euler('z', 180, degrees=True).as_matrix()  # [-1, 0, 0] [0, -1, 0] [0, 0, 1]
        ry = R.from_euler('y', 180, degrees=True).as_matrix()  # [-1, 0, 0] [0, 1, 0] [0, 0, -1]
        rz = torch.from_numpy(rz).float().to(self.device)
        ry = torch.from_numpy(ry).float().to(self.device)

        scale, offset = None, None

        global_orient = pose["global_orient"]  # pymaf_res[idx]['smplx_params']['body_pose'][:, :1, :, :2].to(device).reshape(1, 1, -1) # data["global_orient"]  # 1 x 1 x 6
        body_pose = pose["body_pose"]  # pymaf_res[idx]['smplx_params']['body_pose'][:, 1:22, :, :2].to(device).reshape(1, 21, -1) # data["body_pose"]           # 1 x 21 x 6
        left_hand_pose = pose["left_hand_pose"]  # pymaf_res[idx]['smplx_params']['left_hand_pose'][:, :, :, :2].to(device).reshape(1, 15, -1)                   # 1 x 15 x 3 x 3
        right_hand_pose = pose["right_hand_pose"]  # pymaf_res[idx]['smplx_params']['right_hand_pose'][:, :, :, :2].to(device).reshape(1, 15, -1)                # 1 x 15 x 3 x 3
        beta = pose["betas"]  # 1 x 200

        # The optimizer and variables
        optimed_pose = torch.tensor(body_pose, device=self.device, requires_grad=True)  # [1,21,6]
        optimed_trans = torch.tensor(pose["trans"], device=self.device, requires_grad=True)  # [3]
        optimed_betas = torch.tensor(beta, device=self.device, requires_grad=True)  # [1,200]
        optimed_orient = torch.tensor(global_orient, device=self.device, requires_grad=True)  # [1,1,6]
        optimed_rhand = torch.tensor(right_hand_pose, device=self.device, requires_grad=True)
        optimed_lhand = torch.tensor(left_hand_pose, device=self.device, requires_grad=True)

        # split mv use for opt
        mv_rot, mv_trans = split_mv(self.renderer._mv.clone())
        optimed_mv_rot = torch.tensor(mv_rot, device=self.device, requires_grad=True)

        optimed_params = [
            {'params': [optimed_lhand, optimed_rhand], 'lr': 1e-3},
            {'params': [optimed_betas, optimed_trans, optimed_orient, optimed_pose], 'lr': 3e-3},
        ]
        optimizer_smpl = torch.optim.Adam(
            optimed_params,
            amsgrad=True,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            verbose=0,
            min_lr=1e-5,
            patience=5,
        )

        optimizer_pose = torch.optim.Adam(
            [{'params': [optimed_mv_rot], 'lr': 1e-2}],
            amsgrad=True,
        )
        scheduler_pose = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_pose,
            mode="min",
            factor=0.5,
            verbose=0,
            min_lr=1e-5,
            patience=5,
        )

###############  step1 opt smpl #################
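        # step 1: fit SMPL-X shape/pose/hands/orientation/translation plus the two free camera rotations to the target masks and normal maps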
        smpl_steps = 200
        smpl_loss = torch.Tensor([0.0]).to(self.device)
        pbar = tqdm(range(smpl_steps), mininterval=1)
        for i in pbar:         
            optimizer_smpl.zero_grad()
            optimizer_pose.zero_grad()

            # 6d_rot to rot_mat
            optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).unsqueeze(0)
            optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).unsqueeze(0)

            smpl_verts, smpl_landmarks, smpl_joints = self.econ_dataset.smpl_model(
                shape_params=optimed_betas,
                expression_params=tensor2variable(pose["exp"], self.device),
                body_pose=optimed_pose_mat,
                global_pose=optimed_orient_mat,
                jaw_pose=tensor2variable(pose["jaw_pose"], self.device),
                left_hand_pose=optimed_lhand,
                right_hand_pose=optimed_rhand,
            )

            smpl_verts = smpl_verts + optimed_trans

            v_smpl = torch.matmul(torch.matmul(smpl_verts.squeeze(0), rz.T), ry.T)
            if scale is None:
                scale, offset = scale_mesh(v_smpl.detach())  # scale, center
            v_smpl = (v_smpl + offset) * scale * 2  # 2 std

            normals = calc_vertex_normals(v_smpl, self.smplx_face)
            mv = merge_mv(optimed_mv_rot, mv_trans, self.renderer._mv)
            nrm = self.renderer.render(v_smpl, self.smplx_face, normals=normals, mv=mv)

            masks_ = nrm[..., 3:]
            smpl_mask_loss = ((masks_ - masks) * self.weights).abs().mean()
            smpl_nrm_loss = ((nrm[..., :3] - target_normals) * self.weights).abs().mean()

            smpl_loss = smpl_mask_loss + smpl_nrm_loss
            # smpl_loss =  smpl_mask_loss
            smpl_loss.backward()
            if i < 150:
                optimizer_smpl.step()
                scheduler_smpl.step(smpl_loss)

            if i % 5 == 0 or i >= 150:
                optimizer_pose.step()
                scheduler_pose.step(smpl_loss)

            pbar.set_description(f"step1 loss: {smpl_loss.item():.4f}")
            if (i == 0 or i == pbar.total - 1) and self.opt.debug:
                opath = f'{case_path}/{case}_step1_opt_{i:03d}.obj'
                save_mesh(opath, v_smpl, self.smplx_face)
                torchvision.utils.save_image(nrm[..., :3].clone().permute(0, 3, 1, 2), opath.replace(".obj", "_normal.jpg"), nrow=6)

        mesh_smpl = trimesh.Trimesh(vertices=v_smpl.detach().cpu().numpy(), faces=self.smplx_face.detach().cpu().numpy())
        if self.opt.debug:
            mesh_smpl.export(f'{case_path}/{case}_smpl_opt.obj')
            pose["faces"] = self.smplx_face.detach().cpu()
            pose_raw = copy.deepcopy(pose)
            pose["betas"] = optimed_betas.detach()
            pose["body_pose"] = optimed_pose.detach()
            pose["global_orient"] = optimed_orient.detach()
            pose["left_hand_pose"] = optimed_lhand.detach()
            pose["right_hand_pose"] = optimed_rhand.detach()
            pose["scale_mesh_scale"] = scale
            pose["scale_mesh_offset"] = offset
            pose["trans"] = optimed_trans.detach()
            smpl_param_path = f'{case_path}/{case}_step1_smpl.pkl'
            pickle.dump((pose, pose_raw), open(smpl_param_path, 'wb'))

        self.renderer._mv = mv.detach()
        print("优化前后的pose角度", rotation_to_angles_y(mv_rot), rotation_to_angles_y(optimed_mv_rot))
        nrm_opt = MeshOptimizer(v_smpl.detach(), self.smplx_face.detach(), edge_len_lims=[0.01, 0.1])
        vertices, faces = nrm_opt.vertices, nrm_opt.faces

# ### step2----------------------- optimization iterations-------------------------------------
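        # step 2: free-form geometry optimization with continuous remeshing, starting from the fitted SMPL-X surface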
        pbar = tqdm(range(self.opt.iters), mininterval=1)
        for i in pbar:
            nrm_opt.zero_grad()

            normals = calc_vertex_normals(vertices, faces)
            nrm = self.renderer.render(vertices, faces, normals=normals)
            normals = nrm[..., :3]
            loss = ((normals - target_normals) * self.weights).abs().mean()
            alpha = nrm[..., 3:]
            loss += ((alpha - masks) * self.weights).abs().mean()

            loss.backward()

            nrm_opt.step()
            torch.cuda.synchronize()
            vertices, faces = nrm_opt.remesh()

            pbar.set_description(f"step2 loss: {loss.item():.4f}")
            opath = f'{case_path}/{case}_step2_opt_{i:03d}.obj'
            if self.opt.debug and (i % self.opt.snapshot_step == 0 or i == pbar.total - 1):
                save_mesh(opath, vertices, faces)
                torchvision.utils.save_image(normals.permute(0, 3, 1, 2), opath.replace(".obj", "_normal.jpg"), nrow=6)

        mesh_remeshed = trimesh.Trimesh(vertices=vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy())
        if True:  # self.opt.debug:
            remeshed_obj_path = f'{case_path}/{case}_remeshed.obj'
            mesh_remeshed.export(remeshed_obj_path)
            torchvision.utils.save_image(target_normals.permute(0, 3, 1, 2), opath.replace(".obj", "_target_normals.jpg"), nrow=6)

        # refine depth
        from lib.dataset.mesh_util import get_closest_points
        mesh_remeshed1 = remove_long_edge(mesh_remeshed, remeshed_obj_path)
        mesh_list = mesh_remeshed1.split(only_watertight=False)
        mesh_list = sorted(mesh_list, key=lambda x: -x.vertices.shape[0])
        largest_mesh = mesh_list[0].copy()

        mesh_list1 = []
        for k, mesh in enumerate(mesh_list):
            if mesh.vertices.shape[0] > 100:
                # mesh.export(f"{case_path}/part_{k}.obj")
                mesh.vertices[:, 2] += np.median(largest_mesh.vertices[:, 2]) - np.median(mesh.vertices[:, 2])
                mesh_list1.append(mesh)
                # mesh.export(f"{case_path}/part_{k}_1.obj")
            

        mesh_remeshed2 = sum(mesh_list1)
        mesh_remeshed2 = poisson(mesh_remeshed2, remeshed_obj_path + ".obj", decimation=False, depth=9, face_count=100000, only_largest=False)
        mesh_remeshed2.export(remeshed_obj_path)
        target_vertices = torch.tensor(mesh_remeshed2.vertices, dtype=torch.float32, requires_grad=False).to(self.device)

###############  step3: keep optimizing the SMPL parameters so the body fits mesh_remeshed2 more closely #################
        pbar = tqdm(range(smpl_steps // 2), mininterval=1)
        for i in pbar:          # step3
            optimizer_smpl.zero_grad()
            optimizer_pose.zero_grad()

            # 6d_rot to rot_mat
            optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).unsqueeze(0)
            optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).unsqueeze(0)

            smpl_verts, smpl_landmarks, smpl_joints = self.econ_dataset.smpl_model(
                shape_params=optimed_betas,
                expression_params=tensor2variable(pose["exp"], self.device),
                body_pose=optimed_pose_mat,
                global_pose=optimed_orient_mat,
                jaw_pose=tensor2variable(pose["jaw_pose"], self.device),
                left_hand_pose=optimed_lhand,
                right_hand_pose=optimed_rhand,
            )

            smpl_verts = smpl_verts + optimed_trans

            v_smpl = torch.matmul(torch.matmul(smpl_verts.squeeze(0), rz.T), ry.T)
            if scale is None:
                scale, offset = scale_mesh(v_smpl.detach())  # scale, center
            v_smpl = (v_smpl + offset) * scale * 2  # 2 std

            loss, _ = pytorch3d.loss.chamfer_distance(v_smpl[None, :, :], target_vertices[None, :, :])
            loss.backward()
            optimizer_smpl.step()

            pbar.set_description(f"step3 loss: {loss.item():.4f}")
            if self.opt.debug and (i % self.opt.snapshot_step == 0 or i == pbar.total - 1):
                opath = f'{case_path}/{case}_step3_opt_{i:03d}.obj'
                save_mesh(opath, v_smpl, self.smplx_face)

        mesh_smpl = trimesh.Trimesh(vertices=v_smpl.detach().cpu().numpy(), faces=self.smplx_face.detach().cpu().numpy())
        if True:  # self.opt.debug:
            mesh_smpl.export(f'{case_path}/{case}_smpl_opt.obj')
            pose["faces"] = self.smplx_face.detach().cpu()
            pose_raw = copy.deepcopy(pose)
            pose["betas"] = optimed_betas.detach()
            pose["body_pose"] = optimed_pose.detach()
            pose["global_orient"] = optimed_orient.detach()
            pose["left_hand_pose"] = optimed_lhand.detach()
            pose["right_hand_pose"] = optimed_rhand.detach()
            pose["scale_mesh_scale"] = scale
            pose["scale_mesh_offset"] = offset
            pose["trans"] = optimed_trans.detach()
            smpl_param_path = f'{case_path}/{case}_smpl.pkl'
            pickle.dump(pose, open(smpl_param_path, 'wb'))

        # avatarizer
        # econ_pose, econ_dict = avatarizer(remeshed_obj_path, smpl_param_path, f"{case_path}/avatar", with_dress=False, replace_hand=False)
        econ_pose = trimesh.load_mesh(remeshed_obj_path, process=False, maintain_order=True)
        econ_dict = {}

        # vertex-colored model
        vertices = torch.from_numpy(econ_pose.vertices).float().to(self.device)
        faces = torch.from_numpy(econ_pose.faces).long().to(self.device)
        masked_color = []
        for tmp in clr_img:
            # tmp = Image.open(f'{self.opt.mv_path}/{case}/color_{view}_masked.png')
            tmp = tmp.resize((self.resolution, self.resolution), Image.BILINEAR)
            tmp = np.array(tmp).astype(np.float32) / 255.0
            masked_color.append(torch.from_numpy(tmp).permute(2, 0, 1).to(self.device))

        meshes = self.proj_texture(masked_color, vertices, faces)
        vertices = meshes.verts_packed().float()
        faces = meshes.faces_packed().long()
        colors = meshes.textures.verts_features_packed().float()

        opath = f"{case_path}/{case}.obj"
        save_mesh(opath, vertices, faces, colors)
        if self.opt.debug:
            self.evaluate(vertices, colors, faces, save_path=f'{opath}.mp4', save_nrm=True)

        # convert to a textured model
        print("convert_vertex_color_to_texture ...")
        vmapping, new_faces = convert_vertex_color_to_texture(opath, opath, texture_size=2048, lama_model=self.lama_model)
        econ_pkl = f"{case_path}/{case}_bind.pkl"
        econ_dict["vmapping"] = vmapping                    # 存储uv后的点的映射
        econ_dict["new_faces"] = new_faces                  # 新的faces
        pickle.dump(econ_dict, open(econ_pkl, "wb"))

        # save the result files
        zip_path = opath + ".zip"
        material_files = glob(os.path.join(os.path.dirname(opath), "material*"))
        extra_files = [remeshed_obj_path, smpl_param_path, econ_pkl]
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for file in [opath] + material_files + extra_files:
                zipf.write(file, os.path.basename(file))

        return zip_path

    #### replace hand
    def replace_hand(self, pose, mesh_smpl, mesh_remeshed):
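        # when hands are visible, cut the corresponding region out of the reconstructed body and stitch in the SMPL-X hands via Poisson fusion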
        smpl_data = SMPLX()
        if self.opt.replace_hand and True in pose['hands_visibility'][0]:
            hand_mask = torch.zeros(smpl_data.smplx_verts.shape[0])

            if pose['hands_visibility'][0][0]:
                hand_mask.index_fill_(0, torch.tensor(smpl_data.smplx_mano_vid_dict["left_hand"]), 1.0)
            if pose['hands_visibility'][0][1]:
                hand_mask.index_fill_(0, torch.tensor(smpl_data.smplx_mano_vid_dict["right_hand"]), 1.0)

            hand_mesh = apply_vertex_mask(mesh_smpl.copy(), hand_mask)
            body_mesh = part_removal(mesh_remeshed.copy(), hand_mesh, 0.08, self.device, mesh_smpl.copy(), region="hand")
            final = poisson(sum([hand_mesh, body_mesh]), f'temp/temp_final.obj', 10, False)
        else:
            final = poisson(mesh_remeshed, f'temp/temp_final.obj', 10, False)

        return final

    def evaluate(self, target_vertices, target_colors, target_faces, save_path=None, save_nrm=False):
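        # render a 60-view turntable of the mesh (color, plus normals if save_nrm) and optionally write it out as a video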
        mv, proj = make_round_views(60, self.opt.scale, device=self.device)
        renderer = NormalsRenderer(mv, proj, [512, 512], device=self.device)

        target_images = renderer.render(target_vertices, target_faces, colors=target_colors)
        target_images = target_images.detach().cpu().numpy()
        target_images = target_images[..., :3] * target_images[..., 3:4] + bg_color * (1 - target_images[..., 3:4])
        target_images = (target_images.clip(0, 1) * 255).astype(np.uint8)

        if save_nrm:
            target_normals = calc_vertex_normals(target_vertices, target_faces)
            # target_normals[:, 2] *= -1
            target_normals = renderer.render(target_vertices, target_faces, normals=target_normals)
            target_normals = target_normals.detach().cpu().numpy()
            target_normals = target_normals[..., :3] * target_normals[..., 3:4] + bg_color * (1 - target_normals[..., 3:4])
            target_normals = (target_normals.clip(0, 1) * 255).astype(np.uint8)
            frames = [np.concatenate([img, nrm], 1) for img, nrm in zip(target_images, target_normals)]
        else:
            frames = [img for img in target_images]
        if save_path is not None:
            write_video(frames, fps=25, save_path=save_path)
        return frames

    def run(self):
        cases = sorted(os.listdir(self.opt.imgs_path))
        for idx in range(len(cases)):
            case = cases[idx].split('.')[0]
            print(f'Processing {case}')
            pose = self.econ_dataset.__getitem__(idx)
            zip_path = self.optimize_case(case, pose, None, None, opti_texture=True)  # saves, evaluates (in debug mode) and zips the results internally
            print(f'Saved {zip_path}')


if __name__ == '__main__':
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--config", help="path to the yaml configs file", default='config.yaml')
    # args, extras = parser.parse_known_args()
    # opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    cfg = OmegaConf.load("args.yaml")
    dataset_param = {'image_dir': cfg.validation_dataset.root_dir, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'}
    econdata = SMPLDataset(dataset_param, device='cuda')
    EHuman = ReMesh(cfg.recon_opt, econdata)
    EHuman.case_path = "examples/3D人体图片0120_out/joker/recon"
    (scene, pose, colors, normals) = pickle.load(open("examples/3D人体图片0120_out4/沙滩风外国小男孩1/carving_input.pkl", "rb"))
    zip_path = EHuman.optimize_case(scene, pose, colors[:6], normals[:6])
    print(zip_path)

Triangle mesh optimization:

Issues in Delaunay triangle-mesh optimization (converting a Delaunay triangulation to a mesh) - CSDN blog
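A minimal sketch of the "Delaunay triangulation to mesh" conversion the linked post covers, assuming 2.5D data (points that project cleanly onto the XY plane); the random points are a made-up stand-in for real scan data:

import numpy as np
import trimesh
from scipy.spatial import Delaunay

points = np.random.rand(1000, 3)       # hypothetical scattered (x, y, z) samples
tri = Delaunay(points[:, :2])          # triangulate the XY projection
mesh = trimesh.Trimesh(vertices=points, faces=tri.simplices, process=False)
mesh.export("delaunay_mesh.obj")       # lift the 2D triangulation back to a 3D mesh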

Remesh (remeshing), 2019:

Remeshing (Remesh) - Tencent Cloud Developer Community
