diff --git a/README.md b/README.md
index bd21011..d897c25 100644
--- a/README.md
+++ b/README.md
@@ -63,13 +63,13 @@ Please refer to [env/README.md](env/README.md) for detailed environment setup in
 Run demo with different modes:
 ```bash
 # Reconstruct the input image
-python run_demo.py --render_mode reconstruct
+python run_demo.py --render_mode reconstruct --low_ram
 
 # Generate novel poses (animation)
-python run_demo.py --render_mode novel_pose
+python run_demo.py --render_mode novel_pose --low_ram
 
 # Generate 360-degree view
-python run_demo.py --render_mode novel_pose_A
+python run_demo.py --render_mode novel_pose_A --low_ram
 ```
 
 ### Training
diff --git a/configs/idol_v0.yaml b/configs/idol_v0.yaml
index 27d3130..bdfceef 100644
--- a/configs/idol_v0.yaml
+++ b/configs/idol_v0.yaml
@@ -1,144 +1,144 @@
-
-debug: True
-# code_size: [32, 256, 256]
-code_size: [32, 1024, 1024]
-model:
-  # base_learning_rate: 2.0e-04 # yy Need to check
-  target: lib.SapiensGS_SA_v1
-  params:
-    # optimizer add
-    # use_bf16: true
-    max_steps: 100_000
-    warmup_steps: 10_000 #12_000
-    use_checkpoint: true
-    lambda_depth_tv: 0.05
-    lambda_lpips: 10 #2.0
-    lambda_mse: 20 #1.0
-    lambda_offset: 1 #offset_weight: 50 mse 20, lpips 0.1
-    neck_learning_rate: 5e-4
-    decoder_learning_rate: 5e-4
-
-
-    output_hidden_states: true # if True, will output the hidden states from sapiens shallow layer, for the neck decoder
-
-    loss_coef: 0.5
-    init_iter: 500
-    scale_weight: 0.01
-    smplx_path: 'work_dirs/demo_data/Ways_to_Catch_360_clip1.json'
-
-    code_reshape: [32, 96, 96]
-    patch_size: 1
-    code_activation:
-      type: tanh
-      mean: 0.0
-      std: 0.5
-      clip_range: 2
-    grid_size: 64
-    encoder:
-      target: lib.models.sapiens.SapiensWrapper_ts
-      params:
-        # model_path: work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2
-        model_path: /apdcephfs_cq8/share_1367250/harriswen/projects/sapiens_convert/checkpoints//sapiens_1b_epoch_173_torchscript.pt2
-        layer_num: 40
-        img_size: [1024, 736]
-        freeze: True
-    neck:
-      target: lib.models.transformer_sa.neck_SA_v3_skip # TODO!! add a self attention version
-      params:
-        patch_size: 4 #4,
-        in_chans: 32 #32, # the uv code dims
-        num_patches: 9216 #4096 #num_patches #,#4096, # 16*16
-        embed_dim: 1536 # 1920 # 1920 for sapiens encoder2 #1024 # the feature extrators outputs
-        decoder_embed_dim: 1536 # 1024
-        decoder_depth: 16 # 8
-        decoder_num_heads: 16 #16,
-        total_num_hidden_states: 40
-        mlp_ratio: 4.
-    decoder:
-      target: lib.models.decoders.UVNDecoder_gender
-      params:
-        interp_mode: bilinear
-        base_layers: [16, 64]
-        density_layers: [64, 1]
-        color_layers: [16, 128, 9]
-        offset_layers: [64, 3]
-        use_dir_enc: false
-        dir_layers: [16, 64]
-        activation: silu
-        bg_color: 1
-        sigma_activation: sigmoid
-        sigmoid_saturation: 0.001
-        gender: neutral
-        is_sub2: true ## update, make it into 10w gs points
-        multires: 0
-        image_size: [640, 896]
-        superres: false
-        focal: 1120
-        up_cnn_in_channels: 1536 # be the same as decoder_embed_dim
-        reshape_type: VitHead
-        vithead_param:
-          in_channels: 1536 # be the same as decoder_embed_dim
-          out_channels: 32
-          deconv_out_channels: [512, 512, 512, 256]
-          deconv_kernel_sizes: [4, 4, 4, 4]
-          conv_out_channels: [128, 128]
-          conv_kernel_sizes: [3, 3]
-        fix_sigma: true
-
-dataset:
-  target: lib.datasets.dataloader.DataModuleFromConfig
-  params:
-    batch_size: 1 #16 # 6 for lpips
-    num_workers: 2 #2
-    # working when in debug mode
-    debug_cache_path: ./processed_data/flux_batch1_5000_test_50_local.npy
-
-    train:
-      target: lib.datasets.AvatarDataset
-      params:
-        data_prefix: None
-
-        cache_path: [
-          ./processed_data/deepfashion_train_140_local.npy,
-          ./processed_data/flux_batch1_5000_train_140_local.npy
-        ]
-
-        specific_observation_num: 5
-        better_range: true
-        first_is_front: true
-        if_include_video_ref_img: true
-        prob_include_video_ref_img: 0.5
-        img_res: [640, 896]
-    validation:
-      target: lib.datasets.AvatarDataset
-      params:
-        data_prefix: None
-        load_imgs: true
-        specific_observation_num: 5
-        better_range: true
-        first_is_front: true
-        img_res: [640, 896]
-        cache_path: [
-          ./processed_data/deepfashion_val_10_local.npy,
-          ./processed_data/flux_batch1_5000_val_10_local.npy
-        ]
-
-
-
-lightning:
-  modelcheckpoint:
-    params:
-      every_n_train_steps: 4000 #2000
-      save_top_k: -1
-      save_last: true
-      monitor: 'train/loss_mse' # ADD this logging in the wrapper_sa
-      mode: "min"
-      filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}'
-  callbacks: {}
-  trainer:
-    num_sanity_val_steps: 0
-    accumulate_grad_batches: 1
-    gradient_clip_val: 10.0
-    max_steps: 80000
-    check_val_every_n_epoch: 1 ## check validation set every 1 training batches in the current epoch
-    benchmark: true
+
+debug: True
+# code_size: [32, 256, 256]
+code_size: [32, 1024, 1024]
+model:
+  # base_learning_rate: 2.0e-04 # yy Need to check
+  target: lib.SapiensGS_SA_v1
+  params:
+    # optimizer add
+    # use_bf16: true
+    max_steps: 100_000
+    warmup_steps: 10_000 #12_000
+    use_checkpoint: true
+    lambda_depth_tv: 0.05
+    lambda_lpips: 10 #2.0
+    lambda_mse: 20 #1.0
+    lambda_offset: 1 #offset_weight: 50 mse 20, lpips 0.1
+    neck_learning_rate: 5e-4
+    decoder_learning_rate: 5e-4
+
+
+    output_hidden_states: true # if True, will output the hidden states from sapiens shallow layer, for the neck decoder
+
+    loss_coef: 0.5
+    init_iter: 500
+    scale_weight: 0.01
+    smplx_path: 'work_dirs/demo_data/Ways_to_Catch_360_clip1.json'
+
+    code_reshape: [32, 96, 96]
+    patch_size: 1
+    code_activation:
+      type: tanh
+      mean: 0.0
+      std: 0.5
+      clip_range: 2
+    grid_size: 64
+    encoder:
+      target: lib.models.sapiens.SapiensWrapper_ts
+      params:
+        model_path: work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2
+        # model_path: /apdcephfs_cq8/share_1367250/harriswen/projects/sapiens_convert/checkpoints//sapiens_1b_epoch_173_torchscript.pt2
+        layer_num: 40
+        img_size: [1024, 736]
+        freeze: True
+    neck:
+      target: lib.models.transformer_sa.neck_SA_v3_skip # TODO!! add a self attention version
+      params:
+        patch_size: 4 #4,
+        in_chans: 32 #32, # the uv code dims
+        num_patches: 9216 #4096 #num_patches #,#4096, # 16*16
+        embed_dim: 1536 # 1920 # 1920 for sapiens encoder2 #1024 # the feature extrators outputs
+        decoder_embed_dim: 1536 # 1024
+        decoder_depth: 16 # 8
+        decoder_num_heads: 16 #16,
+        total_num_hidden_states: 40
+        mlp_ratio: 4.
+    decoder:
+      target: lib.models.decoders.UVNDecoder_gender
+      params:
+        interp_mode: bilinear
+        base_layers: [16, 64]
+        density_layers: [64, 1]
+        color_layers: [16, 128, 9]
+        offset_layers: [64, 3]
+        use_dir_enc: false
+        dir_layers: [16, 64]
+        activation: silu
+        bg_color: 1
+        sigma_activation: sigmoid
+        sigmoid_saturation: 0.001
+        gender: neutral
+        is_sub2: true ## update, make it into 10w gs points
+        multires: 0
+        image_size: [640, 896]
+        superres: false
+        focal: 1120
+        up_cnn_in_channels: 1536 # be the same as decoder_embed_dim
+        reshape_type: VitHead
+        vithead_param:
+          in_channels: 1536 # be the same as decoder_embed_dim
+          out_channels: 32
+          deconv_out_channels: [512, 512, 512, 256]
+          deconv_kernel_sizes: [4, 4, 4, 4]
+          conv_out_channels: [128, 128]
+          conv_kernel_sizes: [3, 3]
+        fix_sigma: true
+
+dataset:
+  target: lib.datasets.dataloader.DataModuleFromConfig
+  params:
+    batch_size: 1 #16 # 6 for lpips
+    num_workers: 2 #2
+    # working when in debug mode
+    debug_cache_path: ./processed_data/flux_batch1_5000_test_50_local.npy
+
+    train:
+      target: lib.datasets.AvatarDataset
+      params:
+        data_prefix: None
+
+        cache_path: [
+          ./processed_data/deepfashion_train_140_local.npy,
+          ./processed_data/flux_batch1_5000_train_140_local.npy
+        ]
+
+        specific_observation_num: 5
+        better_range: true
+        first_is_front: true
+        if_include_video_ref_img: true
+        prob_include_video_ref_img: 0.5
+        img_res: [640, 896]
+    validation:
+      target: lib.datasets.AvatarDataset
+      params:
+        data_prefix: None
+        load_imgs: true
+        specific_observation_num: 5
+        better_range: true
+        first_is_front: true
+        img_res: [640, 896]
+        cache_path: [
+          ./processed_data/deepfashion_val_10_local.npy,
+          ./processed_data/flux_batch1_5000_val_10_local.npy
+        ]
+
+
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 4000 #2000
+      save_top_k: -1
+      save_last: true
+      monitor: 'train/loss_mse' # ADD this logging in the wrapper_sa
+      mode: "min"
+      filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}'
+  callbacks: {}
+  trainer:
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    gradient_clip_val: 10.0
+    max_steps: 80000
+    check_val_every_n_epoch: 1 ## check validation set every 1 training batches in the current epoch
+    benchmark: true
\ No newline at end of file
diff --git a/lib/humanlrm_wrapper_sa_v1.py b/lib/humanlrm_wrapper_sa_v1.py
index a519450..2f9561c 100644
--- a/lib/humanlrm_wrapper_sa_v1.py
+++ b/lib/humanlrm_wrapper_sa_v1.py
@@ -1,774 +1,783 @@
-
-import os
-import math
-import json
-from torch.optim import Adam
-from torch.nn.parallel.distributed import DistributedDataParallel
-import torch
-import torch.nn.functional as F
-from torchvision.transforms import InterpolationMode
-from torchvision.utils import make_grid, save_image
-from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
-from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
-import pytorch_lightning as pl
-from pytorch_lightning.utilities.grads import grad_norm
-from einops import rearrange, repeat
-
-from lib.utils.train_util import instantiate_from_config
-from lib.ops.activation import TruncExp
-import time
-import matplotlib.pyplot as plt
-
-from PIL import Image - -import numpy as np -from lib.utils.train_util import main_print - -from typing import List, Optional, Tuple, Union -def get_1d_rotary_pos_embed( - dim: int, - pos: Union[torch.Tensor, int], - theta: float = 10000.0, - use_real=False, - linear_factor=1.0, - ntk_factor=1.0, - repeat_interleave_real=True, -): - """ - Precompute the frequency tensor for complex exponentials (cis) with given dimensions. - - This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end - index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 - data type. - - Args: - dim (`int`): Dimension of the frequency tensor. - pos (`torch.Tensor` or `int`): Position indices for the frequency tensor. [S] or scalar - theta (`float`, *optional*, defaults to 10000.0): - Scaling factor for frequency computation. Defaults to 10000.0. - use_real (`bool`, *optional*): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - linear_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the context extrapolation. Defaults to 1.0. - ntk_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. - repeat_interleave_real (`bool`, *optional*, defaults to `True`): - If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. - Otherwise, they are concateanted with themselves. - Returns: - `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] - """ - assert dim % 2 == 0 - - if isinstance(pos, int): - pos = torch.arange(pos) - theta = theta * ntk_factor - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2] - t = pos # torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] - freqs = freqs.to(device=t.device, dtype=t.dtype) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] - if use_real and repeat_interleave_real: - freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] - freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] - return freqs_cos, freqs_sin - elif use_real: - freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D] - freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D] - return freqs_cos, freqs_sin - else: - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] - return freqs_cis -class FluxPosEmbed(torch.nn.Module): - # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 - def __init__(self, theta: int, axes_dim: [int]): - super().__init__() - self.theta = theta - self.axes_dim = axes_dim - - def forward(self, ids: torch.Tensor) -> torch.Tensor: - n_axes = ids.shape[-1] - cos_out = [] - sin_out = [] - pos = ids.float() - is_mps = ids.device.type == "mps" - freqs_dtype = torch.float32 if is_mps else torch.float64 - for i in range(n_axes): - cos, sin = get_1d_rotary_pos_embed( - self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True#, freqs_dtype=freqs_dtype - ) - cos_out.append(cos) - sin_out.append(sin) - freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) - freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) - return freqs_cos, freqs_sin - -class SapiensGS_SA_v1(pl.LightningModule): - - def __init__( - self, - encoder=dict(type='mmpretrain.VisionTransformer'), - 
neck=dict(type='mmpretrain.VisionTransformer'), - decoder=dict(), - diffusion_use_ema=True, - freeze_decoder=False, - image_cond=False, - code_permute=None, - code_reshape=None, - autocast_dtype=None, - ortho=True, - return_norm=False, - # reshape_type='reshape', # 'cnn' - code_size=None, - decoder_use_ema=None, - bg_color=1, - training_mode=None, # stage2's flag, default None for stage1 - - patch_size: int = 4, - - warmup_steps: int = 12_000, - use_checkpoint: bool = True, - lambda_depth_tv: float = 0.05, - lambda_lpips: float = 2.0, - lambda_mse: float = 1.0, - lambda_l1: float=0, - lambda_ssim: float=0, - neck_learning_rate: float = 5e-4, - decoder_learning_rate: float = 1e-3, - encoder_learning_rate: float=0, - max_steps: int = 100_000, - loss_coef: float = 0.5, - init_iter: int = 500, - lambda_offset: int = 50, # offset_weight: 50 - scale_weight: float = 0.01, - is_debug: bool = False, # if debug, then it will not returns lpips - code_activation: dict=None, - output_hidden_states: bool=False, # if True, will output the hidden states from sapiens shallow layer, for the neck decoder - loss_weights_views: List = [], # the loss weights for the views, if empty, will use the same weights for all the views - **kwargs - ): - super(SapiensGS_SA_v1, self).__init__() - ## ========== part -- Add the code to save this parameters for optimizers ======== - self.warmup_steps = warmup_steps - self.use_checkpoint = use_checkpoint - self.lambda_depth_tv = lambda_depth_tv - self.lambda_lpips = lambda_lpips - self.lambda_mse = lambda_mse - self.lambda_l1 = lambda_l1 - self.lambda_ssim = lambda_ssim - self.neck_learning_rate = neck_learning_rate - self.decoder_learning_rate = decoder_learning_rate - self.encoder_learning_rate = encoder_learning_rate - self.max_steps = max_steps - - self.loss_coef = loss_coef - self.init_iter = init_iter - self.lambda_offset = lambda_offset - self.scale_weight = scale_weight - - self.is_debug = is_debug - ## ========== end part ======== - - - - self.code_size = code_size - if code_activation['type'] == 'tanh': - self.code_activation = torch.nn.Tanh() - else: - self.code_activation = TruncExp() #build_module(code_activation) - # self.grid_size = grid_size - self.decoder = instantiate_from_config(decoder) - self.decoder_use_ema = decoder_use_ema - if decoder_use_ema: - raise NotImplementedError("decoder_use_ema has not been implemented") - if self.decoder_use_ema: - self.decoder_ema = deepcopy(self.decoder) - self.encoder = instantiate_from_config(encoder) - # get_obj_from_str(config["target"]) - - self.code_size = code_reshape - self.code_clip_range = [-1,1] - - # ============= begin config ============= - # transformer from class MAEPretrainDecoder(BaseModule): - # compress the token number of the uv code - self.patch_size = patch_size - self.code_patch_size = self.patch_size - self.num_patches_axis = code_reshape[-1]//self.patch_size # reshape it for the upsampling - self.num_patches = self.num_patches_axis ** 2 - self.code_feat_dims = code_reshape[0] # only used for the upsampling of 'reshape' type - self.code_resolution = code_reshape[-1] # only used for the upsampling of 'reshape' type - - self.reshape_type = self.decoder.reshape_type - - - self.inputs_front_only = True - self.render_loss_all_view = True - self.if_include_video_ref_img = True - - self.training_mode = training_mode - - self.loss_weights_views = torch.Tensor(loss_weights_views).reshape(-1) / sum(loss_weights_views) # normalize the weights - - - - # ========== config meaning =========== - self.neck = 
instantiate_from_config(neck) - - self.ids_restore = torch.arange(0, self.num_patches).unsqueeze(0) - self.freeze_decoder = freeze_decoder - if self.freeze_decoder: - self.decoder.requires_grad_(False) - if self.decoder_use_ema: - self.decoder_ema.requires_grad_(False) - self.image_cond = image_cond - self.code_permute = code_permute - self.code_reshape = code_reshape - self.code_reshape_inv = [self.code_size[axis] for axis in self.code_permute] if code_permute is not None \ - else self.code_size - self.code_permute_inv = [self.code_permute.index(axis) for axis in range(len(self.code_permute))] \ - if code_permute is not None else None - - self.autocast_dtype = autocast_dtype - self.ortho = ortho - self.return_norm = return_norm - - '''add a flag for the skip connection from sapiens shallow layer''' - self.output_hidden_states = output_hidden_states - - ''' add the in-the-wild images visualization''' - if self.lambda_lpips > 0: - self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg') - else: - self.lpips = None - - self.ssim = StructuralSimilarityIndexMeasure() - - - self.validation_step_outputs = [] - self.validation_step_code_outputs = [] # saving the code - self.validation_step_nvPose_outputs = [] - self.validation_metrics = [] - - - # loading the smplx for the nv pose - import json - import numpy as np - # evaluate the animation - smplx_path = './work_dirs/demo_data/Ways_to_Catch_360_clip1.json' - with open(smplx_path, 'r') as f: - smplx_pose_param = json.load(f) - smplx_param_list = [] - for par in smplx_pose_param['annotations']: - k = par['smplx_params'] - for i in k.keys(): - k[i] = np.array(k[i]) - left_hands = np.array([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, - -0.6652, -0.7290, 0.0084, -0.4818]) - betas = torch.zeros((10)) - smplx_param = \ - np.concatenate([np.array([1]), np.array([0,0.,0]), np.array([0, -1, 0])*k['root_orient'], \ - k['pose_body'],betas, \ - k['pose_hand'], k['pose_jaw'], np.zeros(6), k['face_expr'][:10]], axis=0).reshape(1,-1) - # print(smplx_param.shape) - smplx_param_list.append(smplx_param) - smplx_params = np.concatenate(smplx_param_list, 0) - self.smplx_params = torch.Tensor(smplx_params).cuda() - def get_default_smplx_params(self): - A_pose = torch.Tensor([[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.1047, 0.0000, 0.0000, -0.1047, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.7854, 0.0000, - 0.0000, 0.7854, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.7470, 1.0966, - 0.0169, -0.0534, -0.0212, 0.0782, -0.0348, 0.0260, 0.0060, 0.0118, - -0.1117, -0.0429, 0.4164, -0.1088, 0.0660, 0.7562, 0.0964, 0.0909, - 0.1885, 0.1181, -0.0509, 0.5296, 0.1437, -0.0552, 0.7049, 0.0192, - 0.0923, 0.3379, 0.4570, 0.1963, 0.6255, 0.2147, 0.0660, 0.5069, - 0.3697, 0.0603, 0.0795, 0.1419, 0.0859, 0.6355, 0.3033, 0.0579, - 0.6314, 0.1761, 0.1321, 0.3734, -0.8510, -0.2769, 0.0915, 0.4998, - -0.0266, -0.0529, -0.5356, -0.0460, 0.2774, -0.1117, 0.0429, -0.4164, - -0.1088, -0.0660, -0.7562, 0.0964, -0.0909, -0.1885, 0.1181, 0.0509, - -0.5296, 0.1437, 0.0552, -0.7049, 0.0192, -0.0923, -0.3379, 0.4570, - -0.1963, -0.6255, 0.2147, -0.0660, -0.5069, 0.3697, -0.0603, 
-0.0795, - 0.1419, -0.0859, -0.6355, 0.3033, -0.0579, -0.6314, 0.1761, -0.1321, - -0.3734, -0.8510, 0.2769, -0.0915, 0.4998, 0.0266, 0.0529, -0.5356, - 0.0460, -0.2774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]) - return A_pose - def forward_decoder(self, decoder, code, target_rgbs, cameras, - smpl_params=None, return_decoder_loss=False, init=False): - decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder - num_imgs = target_rgbs.shape[1] - outputs = decoder( - code, smpl_params, cameras, - num_imgs, return_loss=return_decoder_loss, init=init, return_norm=False) - return outputs - - def on_fit_start(self): - if self.global_rank == 0: - os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) - os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) - os.makedirs(os.path.join(self.logdir, 'images_val_code'), exist_ok=True) - - - - def forward(self, data): - # print("iter") - - num_scenes = len(data['scene_id']) # 8 - if 'cond_imgs' in data: - cond_imgs = data['cond_imgs'] # (num_scenes, num_imgs, h, w, 3) - cond_intrinsics = data['cond_intrinsics'] # (num_scenes, num_imgs, 4), in [fx, fy, cx, cy] - cond_poses = data['cond_poses'] # (num_scenes, num_imgs, 4, 4) - smpl_params = data['cond_smpl_param'] # (num_scenes, c) - # if 'cond_norm' in data:cond_norm = data['cond_norm'] else: cond_norm = None - num_scenes, num_imgs, h, w, _ = cond_imgs.size() - cameras = torch.cat([cond_intrinsics, cond_poses.reshape(num_scenes, num_imgs, -1)], dim=-1) - if self.if_include_video_ref_img: # new! we render to all view, including the input image; # don't compute this loss - cameras = cameras[:,1:] - num_imgs = num_imgs - 1 - - - if self.inputs_front_only: # default setting to use the first image as input - inputs_img_idx = [0] - else: - raise NotImplementedError("inputs_front_only is False") - inputs_img = cond_imgs[:,inputs_img_idx[0],...].permute([0,3,1,2]) # - - target_imgs = cond_imgs[:, 1:] - assert cameras.shape[1] == target_imgs.shape[1] - - - if self.is_debug: - try: - code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation - except Exception as e: # OOM - main_print(e) - code = torch.zeros([num_scenes, 32, 256, 256]).to(inputs_img.dtype).to(inputs_img.device) - else: - code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation - - decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder - # uvmaps_decoder_gender's forward - output = decoder( - code, smpl_params, cameras, - num_imgs, return_loss=False, init=(self.global_step < self.init_iter), return_norm=False) #(['scales', 'norm', 'image', 'offset']) - - output['code'] = code - output['target_imgs'] = target_imgs - output['inputs_img'] = cond_imgs[:,[0],...] 
- - # for visualization - if self.global_rank == 0 and self.global_step % 200 == 0 and self.is_debug: - overlay_imgs = 0.5 * target_imgs + 0.5 * output['image'] - overlay_imgs = rearrange(overlay_imgs, 'b n h w c -> b h n w c') - overlay_imgs = rearrange(overlay_imgs, ' b h n w c -> (b h) (n w) c') - overlay_imgs = overlay_imgs.to(torch.float32).detach().cpu().numpy() - overlay_imgs = (overlay_imgs * 255).astype(np.uint8) - Image.fromarray(overlay_imgs).save(f'debug_{self.global_step}.jpg') - - return output - - def forward_image_to_uv(self, inputs_img, is_training=True): - ''' - inputs_img: torch.Tensor, bs, 3, H, W - return - code : bs, 256, 256, 32 - ''' - if self.decoder_learning_rate <= 0: - with torch.no_grad(): - features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states) - else: - features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states) - - if self.ids_restore.device !=features_flatten.device: - self.ids_restore = self.ids_restore.to(features_flatten.device) - ids_restore = self.ids_restore.expand([features_flatten.shape[0], -1]) - uv_code = self.neck(features_flatten, ids_restore) - batch_size, token_num, dims_feature = uv_code.shape - - if self.reshape_type=='reshape': - feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,\ - self.code_feat_dims, self.code_patch_size, self.code_patch_size) # torch.Size([1, 64, 64, 32, 4, 4, ]) - feature_map = feature_map.permute(0, 3, 1, 4, 2, 5) # ([1, 32, 64, 4, 64, 4]) - feature_map = feature_map.reshape(batch_size, self.code_feat_dims, self.code_resolution, self.code_resolution) # torch.Size([1, 32, 256, 256]) - code = feature_map # [1, 32, 256, 256] - else: - feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,dims_feature) # torch.Size([1, 64, 64, 512, ]) - if isinstance(self.decoder, DistributedDataParallel): - code = self.decoder.module.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256]) - else: - code = self.decoder.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256]) - - code = self.code_activation(code) - return code - - def compute_loss(self, render_out): - render_images = render_out['image'] # .Size([1, 5, 896, 640, 3]), range [0, 1] - target_images = render_out['target_imgs'] - target_images =target_images.to(render_images) - if self.is_debug: - render_images_tmp= rearrange(render_images, 'b n h w c -> (b n) c h w') - target_images_tmp = rearrange(target_images, 'b n h w c -> (b n) c h w') - all_images = torch.cat([render_images_tmp, target_images_tmp], dim=2) - all_images = render_images_tmp*0.5 + target_images_tmp*0.5 - grid = make_grid(all_images, nrow=4, normalize=True, value_range=(0, 1)) - save_image(grid, "./debug.png") - main_print("saving into ./debug.png") - - - render_images = rearrange(render_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0 - target_images = rearrange(target_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0 - if self.lambda_mse<=0: - loss_mse = 0 - else: - if self.loss_weights_views.numel() != 0: - b, n, _, _, _ = render_out['image'].shape - loss_weights_views = self.loss_weights_views.unsqueeze(0).to(render_images.device) - loss_weights_views = loss_weights_views.repeat(b,1).reshape(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - loss_mse = weighted_mse_loss(render_images, target_images, loss_weights_views) - main_print("weighted sum mse") - else: - loss_mse = 
F.mse_loss(render_images, target_images) - - if self.lambda_l1<=0: - loss_l1 = 0 - else: - loss_l1 = F.l1_loss(render_images, target_images) - - if self.lambda_ssim <= 0: - loss_ssim = 0 - else: - loss_ssim = 1 - self.ssim(render_images, target_images) - if not self.is_debug: - if self.lambda_lpips<=0: - loss_lpips = 0 - else: - if self.loss_weights_views.numel() != 0: - with torch.cuda.amp.autocast(): - loss_lpips = self.lpips(render_images.clamp(-1, 1), target_images) - else: - loss_lpips = 0 - with torch.cuda.amp.autocast(): - for img_idx in range(render_images.shape[0]): - loss_lpips += self.lpips(render_images[[img_idx]].clamp(-1, 1), target_images[[img_idx]]) - loss_lpips /= render_images.shape[0] - - else: - loss_lpips = 0 - loss_gs_offset = render_out['offset'] - loss = loss_mse * self.lambda_mse \ - + loss_l1 * self.lambda_l1 \ - + loss_ssim * self.lambda_ssim \ - + loss_lpips * self.lambda_lpips \ - + loss_gs_offset * self.lambda_offset - - prefix = 'train' - loss_dict = {} - loss_dict.update({f'{prefix}/loss_mse': loss_mse}) - loss_dict.update({f'{prefix}/loss_lpips': loss_lpips}) - loss_dict.update({f'{prefix}/loss_gs_offset': loss_gs_offset}) - loss_dict.update({f'{prefix}/loss_ssim': loss_ssim}) - loss_dict.update({f'{prefix}/loss_l1': loss_l1}) - loss_dict.update({f'{prefix}/loss': loss}) - - return loss, loss_dict - - def compute_metrics(self, render_out): - # NOTE: all the rgb value range is [0, 1] - # render_out.keys = (['scales', 'norm', 'image', 'offset', 'code', 'target_imgs']) - render_images = render_out['image'].clamp(0, 1) # .Size([1, 5, 896, 640, 3]), range [0, 1] - target_images = render_out['target_imgs'] - if target_images.dtype!=render_images.dtype: - target_images = target_images.to(render_images.dtype) - - render_images = rearrange(render_images, 'b n h w c -> (b n) c h w') - target_images = rearrange(target_images, 'b n h w c -> (b n) c h w').to(render_images) - - mse = F.mse_loss(render_images, target_images).mean() - psnr = 10 * torch.log10(1.0 / mse) - ssim = self.ssim(render_images, target_images) - - render_images = render_images * 2.0 - 1.0 - target_images = target_images * 2.0 - 1.0 - - if self.lambda_lpips<=0: - lpips = torch.Tensor([0]).to(render_images.device).to(render_images.dtype) - else: - with torch.cuda.amp.autocast(): - lpips = self.lpips(render_images, target_images) - - metrics = { - 'val/mse': mse, - 'val/pnsr': psnr, - 'val/ssim': ssim, - 'val/lpips': lpips, - } - return metrics - - def new_on_before_optimizer_step(self): - norms = grad_norm(self.neck, norm_type=2) - if 'grad_2.0_norm_total' in norms: - self.log_dict({'grad_norm/lrm_generator': norms['grad_2.0_norm_total']}) - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - render_out = self.forward(batch) - - metrics = self.compute_metrics(render_out) - self.validation_metrics.append(metrics) - render_images = render_out['image'] - render_images = rearrange(render_images, 'b n h w c -> b c h (n w)') - gt_images = render_out['target_imgs'] - gt_images = rearrange(gt_images, 'b n h w c-> b c h (n w)') - log_images = torch.cat([render_images, gt_images], dim=-2) - self.validation_step_outputs.append(log_images) - - self.validation_step_code_outputs.append( render_out['code']) - - render_out_comb = self.forward_nvPose(batch, smplx_given=None) - self.validation_step_nvPose_outputs.append(render_out_comb) - - - def forward_nvPose(self, batch, smplx_given): - ''' - smplx_given: torch.Tensor, bs, 189 - it will returns images with cameras_num * poses_num - ''' - _, num_img, 
_,_ = batch['cond_poses'].shape - # write a code to seperately input the smplx_params - if smplx_given == None: - step_pose = self.smplx_params.shape[0] // num_img - smplx_given = self.smplx_params - else: - step_pose = 1 - render_out_list = [] - for i in range(num_img): - target_pose = smplx_given[[i*step_pose]] - bk = batch['cond_smpl_param'].clone() - batch['cond_smpl_param'][:, 7:70] = target_pose[:, 7:70] # copy body_pose - batch['cond_smpl_param'][:, 80:80+93] = target_pose[:, 80:80+93]# copy pose_hand + pose_jaw - batch['cond_smpl_param'][:, 179:189] = target_pose[:, 179:189]# copy face expression - render_out_new = self.forward(batch) - render_out_list.append(render_out_new['image']) - render_out_comb = torch.cat(render_out_list, dim=2) # stack in the H axis - render_out_comb = rearrange(render_out_comb, 'b n h w c -> b c h (n w)') - return render_out_comb - - - def on_validation_epoch_end(self): # - images = torch.cat(self.validation_step_outputs, dim=-1) - all_images = self.all_gather(images).cpu() - all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') - - # nv pose - images_pose = torch.cat(self.validation_step_nvPose_outputs, dim=-1) - all_images_pose = self.all_gather(images_pose).cpu() - all_images_pose = rearrange(all_images_pose, 'r b c h w -> (r b) c h w') - - if self.global_rank == 0: - image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png') - - grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1)) - save_image(grid, image_path) - main_print(f"Saved image to {image_path}") - - metrics = {} - for key in self.validation_metrics[0].keys(): - metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean() - self.log_dict(metrics, prog_bar=True, logger=True, on_step=False, on_epoch=True) - - - # code for saving the nvPose images - image_path_nvPose = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}_nvPose.png') - grid_nvPose = make_grid(all_images_pose, nrow=1, normalize=True, value_range=(0, 1)) - save_image(grid_nvPose, image_path_nvPose) - main_print(f"Saved image to {image_path_nvPose}") - - - # code for saving the code images - for i, code in enumerate(self.validation_step_code_outputs): - image_path = os.path.join(self.logdir, 'images_val_code') - - num_scenes, num_chn, h, w = code.size() - code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy() - code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w) - for j, code_viz_single in enumerate(code_viz): - plt.imsave(os.path.join(image_path, f'val_{self.global_step:07d}_{i*num_scenes+j:04d}' + '.png'), code_viz_single, - vmin=self.code_clip_range[0], vmax=self.code_clip_range[1]) - self.validation_step_outputs.clear() - self.validation_step_nvPose_outputs.clear() - self.validation_metrics.clear() - self.validation_step_code_outputs.clear() - - def on_test_start(self): - if self.global_rank == 0: - os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True) - - def on_test_epoch_end(self): - metrics = {} - metrics_mean = {} - metrics_var = {} - for key in self.validation_metrics[0].keys(): - tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy() - metrics_mean[key] = tmp.mean() - metrics_var[key] = tmp.var() - - formatted_metrics = {} - for key in metrics_mean.keys(): - formatted_metrics[key] = f"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}" - - for key in self.validation_metrics[0].keys(): - metrics[key] = torch.stack([m[key] for m in 
self.validation_metrics]).cpu().numpy().tolist() - - - final_dict = {"average": formatted_metrics, - 'details': metrics} - - metric_path = os.path.join(self.logdir, f'metrics.json') - with open(metric_path, 'w') as f: - json.dump(final_dict, f, indent=4) - main_print(f"Saved metrics to {metric_path}") - - for key in metrics.keys(): - metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean() - main_print(metrics) - - self.validation_metrics.clear() - - def configure_optimizers(self): - # define the optimizer and the scheduler for neck and decoder - main_print("WARNING currently, we only support the single optimizer for both neck and decoder") - - learning_rate = self.neck_learning_rate - params= [ - {'params': self.neck.parameters(), 'lr': self.neck_learning_rate, }, - {'params': self.decoder.parameters(), 'lr': self.decoder_learning_rate}, - ] - if hasattr(self, "encoder_learning_rate") and self.encoder_learning_rate>0: - params.append({'params': self.encoder.parameters(), 'lr': self.encoder_learning_rate}) - main_print("============add the encoder into the optimizer============") - optimizer = torch.optim.Adam( - params - ) - T_warmup, T_max, eta_min = self.warmup_steps, self.max_steps, 0.001 - lr_lambda = lambda step: \ - eta_min + (1 - math.cos(math.pi * step / T_warmup)) * (1 - eta_min) * 0.5 if step < T_warmup else \ - eta_min + (1 + math.cos(math.pi * (step - T_warmup) / (T_max - T_warmup))) * (1 - eta_min) * 0.5 - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) - return {'optimizer': optimizer, 'lr_scheduler': scheduler} - - def training_step(self, batch, batch_idx): - scheduler = self.lr_schedulers() - scheduler.step() - render_gt = None #? - render_out = self.forward(batch) - loss, loss_dict = self.compute_loss(render_out) - - - self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) - - if self.global_step % 200 == 0 and self.global_rank == 0: - self.new_on_before_optimizer_step() # log the norm - if self.global_step % 200 == 0 and self.global_rank == 0: - if self.if_include_video_ref_img and self.training: - render_images = torch.cat([ torch.ones_like(render_out['image'][:,0:1]), render_out['image']], dim=1) - target_images = torch.cat([ render_out['inputs_img'], render_out['target_imgs']], dim=1) - - target_images = rearrange( - target_images, 'b n h w c -> b c h (n w)') - render_images = rearrange( - render_images, 'b n h w c-> b c h (n w)') - - - grid = torch.cat([ - target_images, render_images, 0.5*render_images + 0.5*target_images, - - ], dim=-2) - grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) - - image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.jpg') - save_image(grid, image_path) - main_print(f"Saved image to {image_path}") - - return loss - - @torch.no_grad() - def test_step(self, batch, batch_idx): - # input_dict, render_gt = self.prepare_validation_batch_data(batch) - render_out = self.forward(batch) - render_gt = render_out['target_imgs'] - render_img = render_out['image'] - # Compute metrics - metrics = self.compute_metrics(render_out) - self.validation_metrics.append(metrics) - - # Save images - target_images = rearrange( - render_gt, 'b n h w c -> b c h (n w)') - render_images = rearrange( - render_img, 'b n h w c -> b c h (n w)') - - - grid = torch.cat([ - target_images, render_images, - ], dim=-2) - grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) - # self.logger.log_image('train/render', 
[grid], step=self.global_step) - image_path = os.path.join(self.logdir, 'images_test', f'{batch_idx:07d}.png') - save_image(grid, image_path) - - # code visualize - code = render_out['code'] - self.decoder.visualize(code, batch['scene_name'], - os.path.dirname(image_path), code_range=self.code_clip_range) - - print(f"Saved image to {image_path}") - - def on_test_start(self): - if self.global_rank == 0: - os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True) - - def on_test_epoch_end(self): - metrics = {} - metrics_mean = {} - metrics_var = {} - for key in self.validation_metrics[0].keys(): - tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy() - metrics_mean[key] = tmp.mean() - metrics_var[key] = tmp.var() - - # trans format into "mean±var" - formatted_metrics = {} - for key in metrics_mean.keys(): - formatted_metrics[key] = f"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}" - - for key in self.validation_metrics[0].keys(): - metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy().tolist() - - - # saving into a dictionary - final_dict = {"average": formatted_metrics, - 'details': metrics} - - metric_path = os.path.join(self.logdir, f'metrics.json') - with open(metric_path, 'w') as f: - json.dump(final_dict, f, indent=4) - print(f"Saved metrics to {metric_path}") - - for key in metrics.keys(): - metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean() - print(metrics) - - self.validation_metrics.clear() - -def weighted_mse_loss(render_images, target_images, weights): - squared_diff = (render_images - target_images) ** 2 - main_print(squared_diff.shape, weights.shape) - weighted_squared_diff = squared_diff * weights - loss_mse_weighted = weighted_squared_diff.mean() + +import os +import math +import json +from torch.optim import Adam +from torch.nn.parallel.distributed import DistributedDataParallel +import torch +import torch.nn.functional as F +from torchvision.transforms import InterpolationMode +from torchvision.utils import make_grid, save_image +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure +import pytorch_lightning as pl +from pytorch_lightning.utilities.grads import grad_norm +from einops import rearrange, repeat + +from lib.utils.train_util import instantiate_from_config +from lib.ops.activation import TruncExp +import time +import matplotlib.pyplot as plt + +from PIL import Image + +import numpy as np +from lib.utils.train_util import main_print + +from typing import List, Optional, Tuple, Union +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[torch.Tensor, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`torch.Tensor` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. 
Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = torch.arange(pos) + theta = theta * ntk_factor + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2] + t = pos # torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] + freqs = freqs.to(device=t.device, dtype=t.dtype) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + elif use_real: + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis +class FluxPosEmbed(torch.nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: [int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + freqs_dtype = torch.float32 if is_mps else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True#, freqs_dtype=freqs_dtype + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + +class SapiensGS_SA_v1(pl.LightningModule): + + def __init__( + self, + encoder=dict(type='mmpretrain.VisionTransformer'), + neck=dict(type='mmpretrain.VisionTransformer'), + decoder=dict(), + diffusion_use_ema=True, + freeze_decoder=False, + image_cond=False, + code_permute=None, + code_reshape=None, + autocast_dtype=None, + ortho=True, + return_norm=False, + # reshape_type='reshape', # 'cnn' + code_size=None, + decoder_use_ema=None, + bg_color=1, + training_mode=None, # stage2's flag, default None for stage1 + + patch_size: int = 4, + + warmup_steps: int = 12_000, + use_checkpoint: bool = True, + lambda_depth_tv: float = 0.05, + lambda_lpips: float = 2.0, + lambda_mse: float = 1.0, + lambda_l1: float=0, + lambda_ssim: float=0, + neck_learning_rate: float = 5e-4, + decoder_learning_rate: float = 1e-3, + encoder_learning_rate: float=0, + max_steps: int = 100_000, + loss_coef: float = 0.5, + init_iter: int = 500, + lambda_offset: int = 50, # offset_weight: 50 + scale_weight: float = 0.01, + is_debug: bool = False, # if debug, then it will not returns lpips + code_activation: dict=None, + output_hidden_states: bool=False, # if True, will output the hidden 
states from sapiens shallow layer, for the neck decoder + loss_weights_views: List = [], # the loss weights for the views, if empty, will use the same weights for all the views + **kwargs + ): + super(SapiensGS_SA_v1, self).__init__() + ## ========== part -- Add the code to save this parameters for optimizers ======== + self.warmup_steps = warmup_steps + self.use_checkpoint = use_checkpoint + self.lambda_depth_tv = lambda_depth_tv + self.lambda_lpips = lambda_lpips + self.lambda_mse = lambda_mse + self.lambda_l1 = lambda_l1 + self.lambda_ssim = lambda_ssim + self.neck_learning_rate = neck_learning_rate + self.decoder_learning_rate = decoder_learning_rate + self.encoder_learning_rate = encoder_learning_rate + self.max_steps = max_steps + + self.loss_coef = loss_coef + self.init_iter = init_iter + self.lambda_offset = lambda_offset + self.scale_weight = scale_weight + + self.is_debug = is_debug + ## ========== end part ======== + + + + self.code_size = code_size + if code_activation['type'] == 'tanh': + self.code_activation = torch.nn.Tanh() + else: + self.code_activation = TruncExp() #build_module(code_activation) + # self.grid_size = grid_size + self.decoder = instantiate_from_config(decoder) + self.decoder_use_ema = decoder_use_ema + if decoder_use_ema: + raise NotImplementedError("decoder_use_ema has not been implemented") + if self.decoder_use_ema: + self.decoder_ema = deepcopy(self.decoder) + self.encoder = instantiate_from_config(encoder) + # get_obj_from_str(config["target"]) + + self.code_size = code_reshape + self.code_clip_range = [-1,1] + + # ============= begin config ============= + # transformer from class MAEPretrainDecoder(BaseModule): + # compress the token number of the uv code + self.patch_size = patch_size + self.code_patch_size = self.patch_size + self.num_patches_axis = code_reshape[-1]//self.patch_size # reshape it for the upsampling + self.num_patches = self.num_patches_axis ** 2 + self.code_feat_dims = code_reshape[0] # only used for the upsampling of 'reshape' type + self.code_resolution = code_reshape[-1] # only used for the upsampling of 'reshape' type + + self.reshape_type = self.decoder.reshape_type + + + self.inputs_front_only = True + self.render_loss_all_view = True + self.if_include_video_ref_img = True + + self.training_mode = training_mode + + self.loss_weights_views = torch.Tensor(loss_weights_views).reshape(-1) / sum(loss_weights_views) # normalize the weights + + + + # ========== config meaning =========== + self.neck = instantiate_from_config(neck) + + self.ids_restore = torch.arange(0, self.num_patches).unsqueeze(0) + self.freeze_decoder = freeze_decoder + if self.freeze_decoder: + self.decoder.requires_grad_(False) + if self.decoder_use_ema: + self.decoder_ema.requires_grad_(False) + self.image_cond = image_cond + self.code_permute = code_permute + self.code_reshape = code_reshape + self.code_reshape_inv = [self.code_size[axis] for axis in self.code_permute] if code_permute is not None \ + else self.code_size + self.code_permute_inv = [self.code_permute.index(axis) for axis in range(len(self.code_permute))] \ + if code_permute is not None else None + + self.autocast_dtype = autocast_dtype + self.ortho = ortho + self.return_norm = return_norm + + '''add a flag for the skip connection from sapiens shallow layer''' + self.output_hidden_states = output_hidden_states + + ''' add the in-the-wild images visualization''' + if self.lambda_lpips > 0: + self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg') + else: + self.lpips = None + + self.ssim 
= StructuralSimilarityIndexMeasure() + + + self.validation_step_outputs = [] + self.validation_step_code_outputs = [] # saving the code + self.validation_step_nvPose_outputs = [] + self.validation_metrics = [] + + + # loading the smplx for the nv pose + import json + import numpy as np + # evaluate the animation + smplx_path = './work_dirs/demo_data/Ways_to_Catch_360_clip1.json' + with open(smplx_path, 'r') as f: + smplx_pose_param = json.load(f) + smplx_param_list = [] + for par in smplx_pose_param['annotations']: + k = par['smplx_params'] + for i in k.keys(): + k[i] = np.array(k[i]) + left_hands = np.array([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, + -0.6652, -0.7290, 0.0084, -0.4818]) + betas = torch.zeros((10)) + smplx_param = \ + np.concatenate([np.array([1]), np.array([0,0.,0]), np.array([0, -1, 0])*k['root_orient'], \ + k['pose_body'],betas, \ + k['pose_hand'], k['pose_jaw'], np.zeros(6), k['face_expr'][:10]], axis=0).reshape(1,-1) + # print(smplx_param.shape) + smplx_param_list.append(smplx_param) + smplx_params = np.concatenate(smplx_param_list, 0) + self.smplx_params = torch.Tensor(smplx_params).cuda() + def get_default_smplx_params(self): + A_pose = torch.Tensor([[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.1047, 0.0000, 0.0000, -0.1047, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.7854, 0.0000, + 0.0000, 0.7854, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.7470, 1.0966, + 0.0169, -0.0534, -0.0212, 0.0782, -0.0348, 0.0260, 0.0060, 0.0118, + -0.1117, -0.0429, 0.4164, -0.1088, 0.0660, 0.7562, 0.0964, 0.0909, + 0.1885, 0.1181, -0.0509, 0.5296, 0.1437, -0.0552, 0.7049, 0.0192, + 0.0923, 0.3379, 0.4570, 0.1963, 0.6255, 0.2147, 0.0660, 0.5069, + 0.3697, 0.0603, 0.0795, 0.1419, 0.0859, 0.6355, 0.3033, 0.0579, + 0.6314, 0.1761, 0.1321, 0.3734, -0.8510, -0.2769, 0.0915, 0.4998, + -0.0266, -0.0529, -0.5356, -0.0460, 0.2774, -0.1117, 0.0429, -0.4164, + -0.1088, -0.0660, -0.7562, 0.0964, -0.0909, -0.1885, 0.1181, 0.0509, + -0.5296, 0.1437, 0.0552, -0.7049, 0.0192, -0.0923, -0.3379, 0.4570, + -0.1963, -0.6255, 0.2147, -0.0660, -0.5069, 0.3697, -0.0603, -0.0795, + 0.1419, -0.0859, -0.6355, 0.3033, -0.0579, -0.6314, 0.1761, -0.1321, + -0.3734, -0.8510, 0.2769, -0.0915, 0.4998, 0.0266, 0.0529, -0.5356, + 0.0460, -0.2774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]) + return A_pose + def forward_decoder(self, decoder, code, target_rgbs, cameras, + smpl_params=None, return_decoder_loss=False, init=False): + decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder + num_imgs = target_rgbs.shape[1] + outputs = decoder( + code, smpl_params, cameras, + num_imgs, return_loss=return_decoder_loss, init=init, return_norm=False) + return outputs + + def on_fit_start(self): + if self.global_rank == 0: + os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) + os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) + os.makedirs(os.path.join(self.logdir, 'images_val_code'), exist_ok=True) + + + + def forward(self, data): + # 
print("iter") + + num_scenes = len(data['scene_id']) # 8 + if 'cond_imgs' in data: + cond_imgs = data['cond_imgs'] # (num_scenes, num_imgs, h, w, 3) + cond_intrinsics = data['cond_intrinsics'] # (num_scenes, num_imgs, 4), in [fx, fy, cx, cy] + cond_poses = data['cond_poses'] # (num_scenes, num_imgs, 4, 4) + smpl_params = data['cond_smpl_param'] # (num_scenes, c) + # if 'cond_norm' in data:cond_norm = data['cond_norm'] else: cond_norm = None + num_scenes, num_imgs, h, w, _ = cond_imgs.size() + cameras = torch.cat([cond_intrinsics, cond_poses.reshape(num_scenes, num_imgs, -1)], dim=-1) + if self.if_include_video_ref_img: # new! we render to all view, including the input image; # don't compute this loss + cameras = cameras[:,1:] + num_imgs = num_imgs - 1 + + + if self.inputs_front_only: # default setting to use the first image as input + inputs_img_idx = [0] + else: + raise NotImplementedError("inputs_front_only is False") + inputs_img = cond_imgs[:,inputs_img_idx[0],...].permute([0,3,1,2]) # + + target_imgs = cond_imgs[:, 1:] + assert cameras.shape[1] == target_imgs.shape[1] + + + if self.is_debug: + try: + code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation + except Exception as e: # OOM + main_print(e) + code = torch.zeros([num_scenes, 32, 256, 256]).to(inputs_img.dtype).to(inputs_img.device) + else: + code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation + + decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder + # uvmaps_decoder_gender's forward + output = decoder( + code, smpl_params, cameras, + num_imgs, return_loss=False, init=(self.global_step < self.init_iter), return_norm=False) #(['scales', 'norm', 'image', 'offset']) + + output['code'] = code + output['target_imgs'] = target_imgs + output['inputs_img'] = cond_imgs[:,[0],...] 
+ + # for visualization + if self.global_rank == 0 and self.global_step % 200 == 0 and self.is_debug: + overlay_imgs = 0.5 * target_imgs + 0.5 * output['image'] + overlay_imgs = rearrange(overlay_imgs, 'b n h w c -> b h n w c') + overlay_imgs = rearrange(overlay_imgs, ' b h n w c -> (b h) (n w) c') + overlay_imgs = overlay_imgs.to(torch.float32).detach().cpu().numpy() + overlay_imgs = (overlay_imgs * 255).astype(np.uint8) + Image.fromarray(overlay_imgs).save(f'debug_{self.global_step}.jpg') + + return output + + def forward_image_to_uv(self, inputs_img, is_training=True, low_ram=False): + ''' + inputs_img: torch.Tensor, bs, 3, H, W + return + code : bs, 256, 256, 32 + ''' + if self.decoder_learning_rate <= 0: + with torch.no_grad(): + features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states) + else: + features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states) + + if low_ram: + del self.encoder + torch.cuda.empty_cache() + + if self.ids_restore.device !=features_flatten.device: + self.ids_restore = self.ids_restore.to(features_flatten.device) + ids_restore = self.ids_restore.expand([features_flatten.shape[0], -1]) + uv_code = self.neck(features_flatten, ids_restore) + + if low_ram: + del self.neck + torch.cuda.empty_cache() + + batch_size, token_num, dims_feature = uv_code.shape + + if self.reshape_type=='reshape': + feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,\ + self.code_feat_dims, self.code_patch_size, self.code_patch_size) # torch.Size([1, 64, 64, 32, 4, 4, ]) + feature_map = feature_map.permute(0, 3, 1, 4, 2, 5) # ([1, 32, 64, 4, 64, 4]) + feature_map = feature_map.reshape(batch_size, self.code_feat_dims, self.code_resolution, self.code_resolution) # torch.Size([1, 32, 256, 256]) + code = feature_map # [1, 32, 256, 256] + else: + feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,dims_feature) # torch.Size([1, 64, 64, 512, ]) + if isinstance(self.decoder, DistributedDataParallel): + code = self.decoder.module.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256]) + else: + code = self.decoder.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256]) + + code = self.code_activation(code) + return code + + def compute_loss(self, render_out): + render_images = render_out['image'] # .Size([1, 5, 896, 640, 3]), range [0, 1] + target_images = render_out['target_imgs'] + target_images =target_images.to(render_images) + if self.is_debug: + render_images_tmp= rearrange(render_images, 'b n h w c -> (b n) c h w') + target_images_tmp = rearrange(target_images, 'b n h w c -> (b n) c h w') + all_images = torch.cat([render_images_tmp, target_images_tmp], dim=2) + all_images = render_images_tmp*0.5 + target_images_tmp*0.5 + grid = make_grid(all_images, nrow=4, normalize=True, value_range=(0, 1)) + save_image(grid, "./debug.png") + main_print("saving into ./debug.png") + + + render_images = rearrange(render_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0 + target_images = rearrange(target_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0 + if self.lambda_mse<=0: + loss_mse = 0 + else: + if self.loss_weights_views.numel() != 0: + b, n, _, _, _ = render_out['image'].shape + loss_weights_views = self.loss_weights_views.unsqueeze(0).to(render_images.device) + loss_weights_views = loss_weights_views.repeat(b,1).reshape(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + loss_mse = 
weighted_mse_loss(render_images, target_images, loss_weights_views) + main_print("weighted sum mse") + else: + loss_mse = F.mse_loss(render_images, target_images) + + if self.lambda_l1<=0: + loss_l1 = 0 + else: + loss_l1 = F.l1_loss(render_images, target_images) + + if self.lambda_ssim <= 0: + loss_ssim = 0 + else: + loss_ssim = 1 - self.ssim(render_images, target_images) + if not self.is_debug: + if self.lambda_lpips<=0: + loss_lpips = 0 + else: + if self.loss_weights_views.numel() != 0: + with torch.cuda.amp.autocast(): + loss_lpips = self.lpips(render_images.clamp(-1, 1), target_images) + else: + loss_lpips = 0 + with torch.cuda.amp.autocast(): + for img_idx in range(render_images.shape[0]): + loss_lpips += self.lpips(render_images[[img_idx]].clamp(-1, 1), target_images[[img_idx]]) + loss_lpips /= render_images.shape[0] + + else: + loss_lpips = 0 + loss_gs_offset = render_out['offset'] + loss = loss_mse * self.lambda_mse \ + + loss_l1 * self.lambda_l1 \ + + loss_ssim * self.lambda_ssim \ + + loss_lpips * self.lambda_lpips \ + + loss_gs_offset * self.lambda_offset + + prefix = 'train' + loss_dict = {} + loss_dict.update({f'{prefix}/loss_mse': loss_mse}) + loss_dict.update({f'{prefix}/loss_lpips': loss_lpips}) + loss_dict.update({f'{prefix}/loss_gs_offset': loss_gs_offset}) + loss_dict.update({f'{prefix}/loss_ssim': loss_ssim}) + loss_dict.update({f'{prefix}/loss_l1': loss_l1}) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def compute_metrics(self, render_out): + # NOTE: all the rgb value range is [0, 1] + # render_out.keys = (['scales', 'norm', 'image', 'offset', 'code', 'target_imgs']) + render_images = render_out['image'].clamp(0, 1) # .Size([1, 5, 896, 640, 3]), range [0, 1] + target_images = render_out['target_imgs'] + if target_images.dtype!=render_images.dtype: + target_images = target_images.to(render_images.dtype) + + render_images = rearrange(render_images, 'b n h w c -> (b n) c h w') + target_images = rearrange(target_images, 'b n h w c -> (b n) c h w').to(render_images) + + mse = F.mse_loss(render_images, target_images).mean() + psnr = 10 * torch.log10(1.0 / mse) + ssim = self.ssim(render_images, target_images) + + render_images = render_images * 2.0 - 1.0 + target_images = target_images * 2.0 - 1.0 + + if self.lambda_lpips<=0: + lpips = torch.Tensor([0]).to(render_images.device).to(render_images.dtype) + else: + with torch.cuda.amp.autocast(): + lpips = self.lpips(render_images, target_images) + + metrics = { + 'val/mse': mse, + 'val/pnsr': psnr, + 'val/ssim': ssim, + 'val/lpips': lpips, + } + return metrics + + def new_on_before_optimizer_step(self): + norms = grad_norm(self.neck, norm_type=2) + if 'grad_2.0_norm_total' in norms: + self.log_dict({'grad_norm/lrm_generator': norms['grad_2.0_norm_total']}) + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + render_out = self.forward(batch) + + metrics = self.compute_metrics(render_out) + self.validation_metrics.append(metrics) + render_images = render_out['image'] + render_images = rearrange(render_images, 'b n h w c -> b c h (n w)') + gt_images = render_out['target_imgs'] + gt_images = rearrange(gt_images, 'b n h w c-> b c h (n w)') + log_images = torch.cat([render_images, gt_images], dim=-2) + self.validation_step_outputs.append(log_images) + + self.validation_step_code_outputs.append( render_out['code']) + + render_out_comb = self.forward_nvPose(batch, smplx_given=None) + self.validation_step_nvPose_outputs.append(render_out_comb) + + + def forward_nvPose(self, batch, 
+        '''
+        smplx_given: torch.Tensor, bs, 189
+        Returns images for cameras_num * poses_num combinations.
+        '''
+        _, num_img, _, _ = batch['cond_poses'].shape
+        # feed the smplx_params pose by pose
+        if smplx_given is None:
+            step_pose = self.smplx_params.shape[0] // num_img
+            smplx_given = self.smplx_params
+        else:
+            step_pose = 1
+        render_out_list = []
+        for i in range(num_img):
+            target_pose = smplx_given[[i * step_pose]]
+            bk = batch['cond_smpl_param'].clone()
+            batch['cond_smpl_param'][:, 7:70] = target_pose[:, 7:70]  # copy body_pose
+            batch['cond_smpl_param'][:, 80:80 + 93] = target_pose[:, 80:80 + 93]  # copy hand pose + jaw pose
+            batch['cond_smpl_param'][:, 179:189] = target_pose[:, 179:189]  # copy face expression
+            render_out_new = self.forward(batch)
+            render_out_list.append(render_out_new['image'])
+        render_out_comb = torch.cat(render_out_list, dim=2)  # stack along the H axis
+        render_out_comb = rearrange(render_out_comb, 'b n h w c -> b c h (n w)')
+        return render_out_comb
+
+    def on_validation_epoch_end(self):
+        images = torch.cat(self.validation_step_outputs, dim=-1)
+        all_images = self.all_gather(images).cpu()
+        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
+
+        # novel-pose renderings
+        images_pose = torch.cat(self.validation_step_nvPose_outputs, dim=-1)
+        all_images_pose = self.all_gather(images_pose).cpu()
+        all_images_pose = rearrange(all_images_pose, 'r b c h w -> (r b) c h w')
+
+        if self.global_rank == 0:
+            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
+            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
+            save_image(grid, image_path)
+            main_print(f"Saved image to {image_path}")
+
+            metrics = {}
+            for key in self.validation_metrics[0].keys():
+                metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean()
+            self.log_dict(metrics, prog_bar=True, logger=True, on_step=False, on_epoch=True)
+
+            # save the novel-pose images
+            image_path_nvPose = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}_nvPose.png')
+            grid_nvPose = make_grid(all_images_pose, nrow=1, normalize=True, value_range=(0, 1))
+            save_image(grid_nvPose, image_path_nvPose)
+            main_print(f"Saved image to {image_path_nvPose}")
+
+            # save visualizations of the UV codes
+            for i, code in enumerate(self.validation_step_code_outputs):
+                image_path = os.path.join(self.logdir, 'images_val_code')
+                num_scenes, num_chn, h, w = code.size()
+                code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy()
+                code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w)
+                for j, code_viz_single in enumerate(code_viz):
+                    plt.imsave(os.path.join(image_path, f'val_{self.global_step:07d}_{i * num_scenes + j:04d}.png'),
+                               code_viz_single, vmin=self.code_clip_range[0], vmax=self.code_clip_range[1])
+        self.validation_step_outputs.clear()
+        self.validation_step_nvPose_outputs.clear()
+        self.validation_metrics.clear()
+        self.validation_step_code_outputs.clear()
+
+    def on_test_start(self):
+        if self.global_rank == 0:
+            os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True)
+
+    def on_test_epoch_end(self):
+        metrics = {}
+        metrics_mean = {}
+        metrics_var = {}
+        for key in self.validation_metrics[0].keys():
+            tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy()
+            metrics_mean[key] = tmp.mean()
+            metrics_var[key] = tmp.var()
+
+        # format each metric as "mean±var"
+        formatted_metrics = {}
+        for key in metrics_mean.keys():
+            formatted_metrics[key] = f"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}"
f"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}" + + for key in self.validation_metrics[0].keys(): + metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy().tolist() + + + final_dict = {"average": formatted_metrics, + 'details': metrics} + + metric_path = os.path.join(self.logdir, f'metrics.json') + with open(metric_path, 'w') as f: + json.dump(final_dict, f, indent=4) + main_print(f"Saved metrics to {metric_path}") + + for key in metrics.keys(): + metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean() + main_print(metrics) + + self.validation_metrics.clear() + + def configure_optimizers(self): + # define the optimizer and the scheduler for neck and decoder + main_print("WARNING currently, we only support the single optimizer for both neck and decoder") + + learning_rate = self.neck_learning_rate + params= [ + {'params': self.neck.parameters(), 'lr': self.neck_learning_rate, }, + {'params': self.decoder.parameters(), 'lr': self.decoder_learning_rate}, + ] + if hasattr(self, "encoder_learning_rate") and self.encoder_learning_rate>0: + params.append({'params': self.encoder.parameters(), 'lr': self.encoder_learning_rate}) + main_print("============add the encoder into the optimizer============") + optimizer = torch.optim.Adam( + params + ) + T_warmup, T_max, eta_min = self.warmup_steps, self.max_steps, 0.001 + lr_lambda = lambda step: \ + eta_min + (1 - math.cos(math.pi * step / T_warmup)) * (1 - eta_min) * 0.5 if step < T_warmup else \ + eta_min + (1 + math.cos(math.pi * (step - T_warmup) / (T_max - T_warmup))) * (1 - eta_min) * 0.5 + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + return {'optimizer': optimizer, 'lr_scheduler': scheduler} + + def training_step(self, batch, batch_idx): + scheduler = self.lr_schedulers() + scheduler.step() + render_gt = None #? 
+ render_out = self.forward(batch) + loss, loss_dict = self.compute_loss(render_out) + + + self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) + + if self.global_step % 200 == 0 and self.global_rank == 0: + self.new_on_before_optimizer_step() # log the norm + if self.global_step % 200 == 0 and self.global_rank == 0: + if self.if_include_video_ref_img and self.training: + render_images = torch.cat([ torch.ones_like(render_out['image'][:,0:1]), render_out['image']], dim=1) + target_images = torch.cat([ render_out['inputs_img'], render_out['target_imgs']], dim=1) + + target_images = rearrange( + target_images, 'b n h w c -> b c h (n w)') + render_images = rearrange( + render_images, 'b n h w c-> b c h (n w)') + + + grid = torch.cat([ + target_images, render_images, 0.5*render_images + 0.5*target_images, + + ], dim=-2) + grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) + + image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.jpg') + save_image(grid, image_path) + main_print(f"Saved image to {image_path}") + + return loss + + @torch.no_grad() + def test_step(self, batch, batch_idx): + # input_dict, render_gt = self.prepare_validation_batch_data(batch) + render_out = self.forward(batch) + render_gt = render_out['target_imgs'] + render_img = render_out['image'] + # Compute metrics + metrics = self.compute_metrics(render_out) + self.validation_metrics.append(metrics) + + # Save images + target_images = rearrange( + render_gt, 'b n h w c -> b c h (n w)') + render_images = rearrange( + render_img, 'b n h w c -> b c h (n w)') + + + grid = torch.cat([ + target_images, render_images, + ], dim=-2) + grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) + # self.logger.log_image('train/render', [grid], step=self.global_step) + image_path = os.path.join(self.logdir, 'images_test', f'{batch_idx:07d}.png') + save_image(grid, image_path) + + # code visualize + code = render_out['code'] + self.decoder.visualize(code, batch['scene_name'], + os.path.dirname(image_path), code_range=self.code_clip_range) + + print(f"Saved image to {image_path}") + + def on_test_start(self): + if self.global_rank == 0: + os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True) + + def on_test_epoch_end(self): + metrics = {} + metrics_mean = {} + metrics_var = {} + for key in self.validation_metrics[0].keys(): + tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy() + metrics_mean[key] = tmp.mean() + metrics_var[key] = tmp.var() + + # trans format into "mean±var" + formatted_metrics = {} + for key in metrics_mean.keys(): + formatted_metrics[key] = f"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}" + + for key in self.validation_metrics[0].keys(): + metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy().tolist() + + + # saving into a dictionary + final_dict = {"average": formatted_metrics, + 'details': metrics} + + metric_path = os.path.join(self.logdir, f'metrics.json') + with open(metric_path, 'w') as f: + json.dump(final_dict, f, indent=4) + print(f"Saved metrics to {metric_path}") + + for key in metrics.keys(): + metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean() + print(metrics) + + self.validation_metrics.clear() + +def weighted_mse_loss(render_images, target_images, weights): + squared_diff = (render_images - target_images) ** 2 + main_print(squared_diff.shape, weights.shape) + weighted_squared_diff = squared_diff 
* weights + loss_mse_weighted = weighted_squared_diff.mean() return loss_mse_weighted \ No newline at end of file diff --git a/lib/models/decoders/uvmaps_decoder_gender.py b/lib/models/decoders/uvmaps_decoder_gender.py index f0cb33e..7ebd34b 100644 --- a/lib/models/decoders/uvmaps_decoder_gender.py +++ b/lib/models/decoders/uvmaps_decoder_gender.py @@ -1,679 +1,700 @@ -import os -import matplotlib.pyplot as plt - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import einsum -from pytorch3d import ops - -from lib.mmutils import xavier_init, constant_init -import numpy as np -import time -import cv2 -import math -from simple_knn._C import distCUDA2 -from pytorch3d.transforms import quaternion_to_matrix - -from ..deformers import SMPLXDeformer_gender -from ..renderers import GRenderer, get_covariance, batch_rodrigues -from lib.ops import TruncExp -import torchvision - -from lib.utils.train_util import main_print - -def ensure_dtype(input_tensor, target_dtype=torch.float32): - """ - Ensure tensor dtype matches target dtype. - If not, convert it. - """ - if input_tensor.dtype != target_dtype: - input_tensor = input_tensor.to(dtype=target_dtype) - return input_tensor - - -class UVNDecoder_gender(nn.Module): - - activation_dict = { - 'relu': nn.ReLU, - 'silu': nn.SiLU, - 'softplus': nn.Softplus, - 'trunc_exp': TruncExp, - 'sigmoid': nn.Sigmoid} - - def __init__(self, - *args, - interp_mode='bilinear', - base_layers=[3 * 32, 128], - density_layers=[128, 1], - color_layers=[128, 128, 3], - offset_layers=[128, 3], - scale_layers=[128, 3], - radius_layers=[128, 3], - use_dir_enc=True, - dir_layers=None, - scene_base_size=None, - scene_rand_dims=(0, 1), - activation='silu', - sigma_activation='sigmoid', - sigmoid_saturation=0.001, - code_dropout=0.0, - flip_z=False, - extend_z=False, - gender='neutral', - multires=0, - bg_color=0, - image_size=1024, - superres=False, - focal = 1280, # the default focal defination - reshape_type=None, # if true, it will create a cnn layers to upsample the uv features - fix_sigma=False, # if true, the density of GS will be fixed - up_cnn_in_channels = None, # the channel number of the upsample cnn - vithead_param=None, # the vit head for decode to uv features - is_sub2=False, # if true, will use the sub2 uv map - **kwargs): - super().__init__() - self.interp_mode = interp_mode - self.in_chn = base_layers[0] - self.use_dir_enc = use_dir_enc - if scene_base_size is None: - self.scene_base = None - else: - rand_size = [1 for _ in scene_base_size] - for dim in scene_rand_dims: - rand_size[dim] = scene_base_size[dim] - init_base = torch.randn(rand_size).expand(scene_base_size).clone() - self.scene_base = nn.Parameter(init_base) - self.dir_encoder = None - self.sigmoid_saturation = sigmoid_saturation - self.deformer = SMPLXDeformer_gender(gender, is_sub2=is_sub2) - - self.renderer = GRenderer(image_size=image_size, bg_color=bg_color, f=focal) - if superres: - self.superres = None - else: - self.superres = None - self.gender= gender - self.reshape_type = reshape_type - if reshape_type=='cnn': - self.upsample_conv = torch.nn.ConvTranspose2d(512, 32, kernel_size=4, stride=4,).cuda() - - elif reshape_type == 'VitHead': # changes the up block's layernorm into the feature channel norm instead of the full image norm - from lib.models.decoders.vit_head import VitHead - self.upsample_conv = VitHead(**vithead_param) - # 256, 128, 128 -> 128, 256, 256 -> 64, 512, 512, ->32, 1024, 1024 - - base_cache_dir = 'work_dirs/cache' - if is_sub2: - base_cache_dir 
= 'work_dirs/cache_sub2' - # main_print("!!!!!!!!!!!!!!!!!!! using the sub2 uv map !!!!!!!!!!!!!!!!!!!") - if gender == 'neutral': - select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_newNeutral.npy')) - self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.) - - init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_newNeutral.npy')) - self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1 - elif gender == 'male': - assert NotImplementedError("Haven't create the init_uv_smplx_thu in v_template") - select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_thu.npy')) - self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.) - - init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_thu.npy')) - self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1 - self.num_init = self.init_pcd.shape[1] - main_print(f"!!!!!!!!!!!!!!!!!!! cur points number are {self.num_init} !!!!!!!!!!!!!!!!!!!") - - self.init_pcd = self.init_pcd - - self.multires = multires # 0 Haven't - if multires > 0: - uv_map = torch.as_tensor(np.load(base_cache_dir+'/init_uvmap_smplx_thu.npy')) - pcd_map = torch.as_tensor(np.load(base_cache_dir+'/init_posmap_smplx_thu.npy')) - input_coord = torch.cat([pcd_map, uv_map], dim=1) - self.register_buffer('input_freq', input_coord, persistent=False) - base_layers[0] += 5 - color_layers[0] += 5 - else: - self.init_uv = None - - activation_layer = self.activation_dict[activation.lower()] - - - base_net = [] # linear (in=18, out=64, bias=True) - for i in range(len(base_layers) - 1): - base_net.append(nn.Conv2d(base_layers[i], base_layers[i + 1], 3, padding=1)) - if i != len(base_layers) - 2: - base_net.append(nn.BatchNorm2d(base_layers[i+1])) - base_net.append(activation_layer()) - self.base_net = nn.Sequential(*base_net) - self.base_bn = nn.BatchNorm2d(base_layers[-1]) - self.base_activation = activation_layer() - - density_net = [] # linear(in=64, out=1, bias=True), sigmoid - for i in range(len(density_layers) - 1): - density_net.append(nn.Conv2d(density_layers[i], density_layers[i + 1], 1)) - if i != len(density_layers) - 2: - density_net.append(nn.BatchNorm2d(density_layers[i+1])) - density_net.append(activation_layer()) - density_net.append(self.activation_dict[sigma_activation.lower()]()) - self.density_net = nn.Sequential(*density_net) - - offset_net = [] # linear(in=64, out=1, bias=True), sigmoid - for i in range(len(offset_layers) - 1): - offset_net.append(nn.Conv2d(offset_layers[i], offset_layers[i + 1], 1)) - if i != len(offset_layers) - 2: - offset_net.append(nn.BatchNorm2d(offset_layers[i+1])) - offset_net.append(activation_layer()) - self.offset_net = nn.Sequential(*offset_net) - - self.dir_net = None - color_net = [] # linear(in=64, out=3, bias=True), sigmoid - for i in range(len(color_layers) - 2): - color_net.append(nn.Conv2d(color_layers[i], color_layers[i + 1], kernel_size=3, padding=1)) - color_net.append(nn.BatchNorm2d(color_layers[i+1])) - color_net.append(activation_layer()) - color_net.append(nn.Conv2d(color_layers[-2], color_layers[-1], kernel_size=1)) - color_net.append(nn.Sigmoid()) - self.color_net = nn.Sequential(*color_net) - self.code_dropout = nn.Dropout2d(code_dropout) if code_dropout > 0 else None - - self.flip_z = flip_z - self.extend_z = extend_z - - if self.gender == 'neutral': - init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_newNeutral.npy')) - self.register_buffer('init_rot', init_rot, persistent=False) - 
- face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu_newNeutral.npy')) - self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False) - - hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu_newNeutral.npy')) - self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False) - - outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu_newNeutral.npy')) - self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False) - else: - assert NotImplementedError("Haven't create the init_rot in v_template") - init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_thu.npy')) - self.register_buffer('init_rot', init_rot, persistent=False) - - face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu.npy')) - self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False) - - hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu.npy')) - self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False) - - outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu.npy')) - self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False) - - self.iter = 0 - self.init_weights() - self.if_rotate_gaussian = False - self.fix_sigma = fix_sigma - def init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Linear): - xavier_init(m, distribution='uniform') - if self.dir_net is not None: - constant_init(self.dir_net[-1], 0) - if self.offset_net is not None: - self.offset_net[-1].weight.data.uniform_(-1e-5, 1e-5) - self.offset_net[-1].bias.data.zero_() - - - def extract_pcd(self, code, smpl_params, init=False, zeros_hands_off=False): - ''' - Args: - B == num_scenes - code (tensor): latent code. shape: [B, C, H, W] - smpl_params (tensor): SMPL parameters. shape: [B_pose, 189] - init (bool): Not used - Returns: - defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3] - sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C] - tfs(tensor): deformation matrics. shape: [B, N, C] - ''' - if isinstance(code, list): - num_scenes, _, h, w = code[0].size() - else: - num_scenes, n_channels, h, w = code.size() - init_pcd = self.init_pcd.repeat(num_scenes, 1, 1) # T-posed space points, for computing the skinning weights - - sigmas, rgbs, radius, rot, offset = self._decode(code, init=init) # the person-specify attributes of GS - if self.fix_sigma: - sigmas = torch.ones_like(sigmas) - if zeros_hands_off: - offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = 0 - canon_pcd = init_pcd + offset - - self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device) - defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian) - return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot - - def deform_pcd(self, code, smpl_params, init=False, zeros_hands_off=False, value=0.1): - ''' - Args: - B == num_scenes - code (List): list of data - smpl_params (tensor): SMPL parameters. shape: [B_pose, 189] - init (bool): Not used - Returns: - defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3] - sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C] - tfs(tensor): deformation matrics. 
shape: [B, N, C] - ''' - sigmas, rgbs, radius, rot, offset = code - num_scenes = sigmas.shape[0] - init_pcd = self.init_pcd.repeat(num_scenes, 1, 1) #T-posed space points, for computing the skinning weights - - if self.fix_sigma: - sigmas = torch.ones_like(sigmas) - if zeros_hands_off: - offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = torch.clamp(offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)], -value, value) - canon_pcd = init_pcd + offset - self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device) - defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian) - return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot - - - - def _sample_feature(self,results,): - # outputs, sigma_uv, offset_uv, rgbs_uv, radius_uv, rot_uv = results['output'], results['sigma'], results['offset'], results['rgbs'], results['radius'], results['rot'] - sigma = results['sigma'] - outputs = results['output'] - if isinstance(sigma, list): - num_scenes, _, h, w = sigma[0].shape - select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) - elif sigma.dim() == 4: - num_scenes, n_channels, h, w = sigma.shape - select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) - else: - assert False - output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1) - sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2) - - if self.sigmoid_saturation > 0: - rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation - - radius = (radius - 0.5) * 2 - rot = (rot - 0.5) * np.pi - - return sigma, rgbs, radius, rot, offset - - def _decode_feature(self, point_code, init=False): - if isinstance(point_code, list): - num_scenes, _, h, w = point_code[0].shape - geo_code, tex_code = point_code - # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) - if self.multires != 0: - input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) - elif point_code.dim() == 4: - num_scenes, n_channels, h, w = point_code.shape - # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) - if self.multires != 0: - input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) - geo_code, tex_code = point_code.split(16, dim=1) - else: - assert False - - base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1) - base_x = self.base_net(base_in) - base_x_act = self.base_activation(self.base_bn(base_x)) - - sigma = self.density_net(base_x_act) - offset = self.offset_net(base_x_act) - color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1) - rgbs_radius_rot = self.color_net(color_in) - - outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1) - main_print(outputs.shape) - sigma, offset, rgbs, radius, rot = outputs.split([1, 3, 3, 3, 3], dim=1) - results = {'output':outputs, 'sigma': sigma, 'offset': offset, 'rgbs': rgbs, 'radius': radius, 'rot': rot} - - return results - def _decode(self, point_code, init=False): - if isinstance(point_code, list): - num_scenes, _, h, w = point_code[0].shape - geo_code, tex_code = point_code - select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) - if self.multires != 0: - input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) - elif point_code.dim() == 4: - num_scenes, n_channels, h, w = 
point_code.shape - select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) - if self.multires != 0: - input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) - geo_code, tex_code = point_code.split(16, dim=1) - else: - assert False - - base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1) - base_x = self.base_net(base_in) - base_x_act = self.base_activation(self.base_bn(base_x)) - - sigma = self.density_net(base_x_act) - offset = self.offset_net(base_x_act) - color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1) - rgbs_radius_rot = self.color_net(color_in) - - outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1) - output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1) - sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2) - - if self.sigmoid_saturation > 0: - rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation - - radius = (radius - 0.5) * 2 - rot = (rot - 0.5) * np.pi - - return sigma, rgbs, radius, rot, offset - - def gaussian_render(self, pcd, sigmas, rgbs, normals, rot, num_scenes, num_imgs, cameras, use_scale=False, radius=None, \ - return_norm=False, return_viz=False, mask=None): - # add mask or visible points to images or select ind to images - ''' - render the gaussian to images - return_norm: return the normals of the gaussian (haven't been used) - return_viz: return the mask of the gaussian - mask: the mask of the gaussian - ''' - assert num_scenes == 1 - - pcd = pcd.reshape(-1, 3) - if use_scale: - dist2 = distCUDA2(pcd) - dist2 = torch.clamp_min((dist2), 0.0000001) - scales = torch.sqrt(dist2)[...,None].repeat(1, 3).detach() # distence between different points - scale = (radius+1)*scales # scaling_modifier # radius[-1--1], scale of GS - cov3D = get_covariance(scale, rot).reshape(-1, 6) # inputs rot is the rotations - - images_all = [] - viz_masks = [] if return_viz else None - norm_all = [] if return_norm else None - - if mask != None: - pcd = pcd[mask] - rgbs = rgbs[mask] - sigmas = sigmas[mask] - cov3D = cov3D[mask] - normals = normals[mask] - if 1: - for i in range(num_imgs): - self.renderer.prepare(cameras[i]) - - image = self.renderer.render_gaussian(means3D=pcd, colors_precomp=rgbs, - rotations=None, opacities=sigmas, scales=None, cov3D_precomp=cov3D) - images_all.append(image) - if return_viz: - viz_mask = self.renderer.render_gaussian(means3D=pcd, colors_precomp=pcd.clone(), - rotations=None, opacities=sigmas*0+1, scales=None, cov3D_precomp=cov3D) - viz_masks.append(viz_mask) - - - images_all = torch.stack(images_all, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2) - if return_viz: - viz_masks = torch.stack(viz_masks, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2).reshape(1, -1, 3) - dist_sq, idx, neighbors = ops.knn_points(pcd.unsqueeze(0), viz_masks[:, ::10], K=1, return_nn=True) - viz_masks = (dist_sq < 0.0001)[0] - # ===== END the original code for batch size = 1 ===== - if use_scale: - return images_all, norm_all, viz_masks, scale - else: - return images_all, norm_all, viz_masks, None - - def visualize(self, code, scene_name, viz_dir, code_range=[-1, 1]): - num_scenes, num_chn, h, w = code.size() - code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy() - if not self.flip_z: - code_viz = code_viz[..., ::-1, :] - code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w) - for 
code_single, code_viz_single, scene_name_single in zip(code, code_viz, scene_name): - plt.imsave(os.path.join(viz_dir, 'a_scene_' + scene_name_single + '.png'), code_viz_single, - vmin=code_range[0], vmax=code_range[1]) - - def forward(self, code, smpl_params, cameras, num_imgs, - return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False): - """ - Args: - - - density_bitfield: Shape (num_scenes, griz_size**3 // 8) - YY: - grid_size, dt_gamma, perturb, T_thresh are deleted - code: Shape (num_scenes, *code_size) - cameras: Shape (num_scenes, num_imgs, 19(3+16)) - smpl_params: Shape (num_scenes, 189) - - """ - # import ipdb; ipdb.set_trace() - if isinstance(code, list): - num_scenes = len(code[0]) - else: - num_scenes = len(code) - assert num_scenes > 0 - self.iter+=1 - - image = [] - scales = [] - norm = [] if return_norm else None - viz_masks = [] if not self.training else None - - xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off) - - if zeros_hands_off: - main_print('zeros_hands_off is on!') - main_print('zeros_hands_off is on!') - offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 - rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 - rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 - R_delta = batch_rodrigues(rot.reshape(-1, 3)) - R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta) - R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R) - normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3) - R_def_batch = R_def.reshape(num_scenes, -1, 3, 3) - - return_to_bfloat16 = True if xyzs.dtype==torch.bfloat16 else False ####### ============ translate the output to BF16 ================= - # return_to_bfloat16 = False # I don't want to trans it back to bf16 - if return_to_bfloat16: - main_print("changes the return_to_bfloat16") - cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)] - # with torch.amp.autocast(enabled=False, device_type='cuda'): - if 1: - for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius): - image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \ - radius=radius_single, return_norm=return_norm, return_viz=not self.training) - image.append(image_single) - scales.append(scale.unsqueeze(0)) - if return_norm: - norm.append(norm_single) - if not self.training: - viz_masks.append(viz_mask) - image = torch.cat(image, dim=0) - scales = torch.cat(scales, dim=0) - - norm = torch.cat(norm, dim=0) if return_norm else None - viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None - - - main_print("not trans the rendered results to float16") - if False: - image = image.to(torch.bfloat16) - scales = scales.to(torch.bfloat16) - if return_norm: - norm = norm.to(torch.bfloat16) - if viz_masks is not None: - viz_masks = viz_masks.to(torch.bfloat16) - offsets = offsets.to(torch.bfloat16) - - if self.training: - offset_dist = offsets ** 2 - weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)]) - else: - weighted_offset = offsets - - - results = dict( - 
viz_masks=viz_masks, - scales=scales, - norm=norm, - image=image, - offset=weighted_offset) - - if return_loss: - results.update(decoder_reg_loss=self.loss()) - - return results - - def forward_render(self, code, cameras, num_imgs, - return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False): - """ - Args: - - - density_bitfield: Shape (num_scenes, griz_size**3 // 8) - YY: - grid_size, dt_gamma, perturb, T_thresh are deleted - code: Shape (num_scenes, *code_size) - cameras: Shape (num_scenes, num_imgs, 19(3+16)) - smpl_params: Shape (num_scenes, 189) - - """ - image = [] - scales = [] - norm = [] if return_norm else None - viz_masks = [] if not self.training else None - - xyzs, sigmas, rgbs, offsets, radius, tfs, rot = code - num_scenes = xyzs.shape[0] - if zeros_hands_off: - main_print('zeros_hands_off is on!') - main_print('zeros_hands_off is on!') - offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 - rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 - rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 - R_delta = batch_rodrigues(rot.reshape(-1, 3)) - R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta) - R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R) - normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3) - R_def_batch = R_def.reshape(num_scenes, -1, 3, 3) - # import ipdb; ipdb.set_trace() - - return_to_bfloat16 = True if xyzs.dtype==torch.bfloat16 else False ####### ============ translate the output to BF16 ================= - # return_to_bfloat16 = False # I don't want to trans it back to bf16 - if return_to_bfloat16: - main_print("changes the return_to_bfloat16") - cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)] - - if 1: - for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius): - image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \ - radius=radius_single, return_norm=return_norm, return_viz=not self.training) - image.append(image_single) - scales.append(scale.unsqueeze(0)) - if return_norm: - norm.append(norm_single) - if not self.training: - viz_masks.append(viz_mask) - image = torch.cat(image, dim=0) - scales = torch.cat(scales, dim=0) - - norm = torch.cat(norm, dim=0) if return_norm else None - viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None - - - main_print("not trans the rendered results to float16") - if False: - image = image.to(torch.bfloat16) - scales = scales.to(torch.bfloat16) - if return_norm: - norm = norm.to(torch.bfloat16) - if viz_masks is not None: - viz_masks = viz_masks.to(torch.bfloat16) - offsets = offsets.to(torch.bfloat16) - - if self.training: - offset_dist = offsets ** 2 - weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)]) - else: - weighted_offset = offsets - - - results = dict( - viz_masks=viz_masks, - scales=scales, - norm=norm, - image=image, - offset=weighted_offset) - if return_loss: - results.update(decoder_reg_loss=self.loss()) - - return results - - - def forward_testing_time(self, code, smpl_params, cameras, num_imgs, - return_loss=False, return_norm=False, 
init=False, mask=None, zeros_hands_off=False): - """ - Args: - - - density_bitfield: Shape (num_scenes, griz_size**3 // 8) - YY: - grid_size, dt_gamma, perturb, T_thresh are deleted - code: Shape (num_scenes, *code_size) - cameras: Shape (num_scenes, num_imgs, 19(3+16)) - smpl_params: Shape (num_scenes, 189) - - """ - if isinstance(code, list): - num_scenes = len(code[0]) - else: - num_scenes = len(code) - assert num_scenes > 0 - self.iter+=1 - - image = [] - scales = [] - norm = [] if return_norm else None - viz_masks = [] if not self.training else None - start_time = time.time() - xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off) - end_time_to_3D = time.time() - time_code_to_3d = end_time_to_3D- start_time - - R_delta = batch_rodrigues(rot.reshape(-1, 3)) - R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta) - R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R) - normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3) - R_def_batch = R_def.reshape(num_scenes, -1, 3, 3) - if 1: - for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius): - image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \ - radius=radius_single, return_norm=False, return_viz=not self.training) - image.append(image_single) - scales.append(scale.unsqueeze(0)) - if return_norm: - norm.append(norm_single) - if not self.training: - viz_masks.append(viz_mask) - image = torch.cat(image, dim=0) - scales = torch.cat(scales, dim=0) - - norm = torch.cat(norm, dim=0) if return_norm else None - viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None - - time_3D_to_img = time.time() - end_time_to_3D - - - if False: - image = image.to(torch.bfloat16) - scales = scales.to(torch.bfloat16) - if return_norm: - norm = norm.to(torch.bfloat16) - if viz_masks is not None: - viz_masks = viz_masks.to(torch.bfloat16) - offsets = offsets.to(torch.bfloat16) - - results = dict( - image=image) - +import os +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +from pytorch3d import ops + +from lib.mmutils import xavier_init, constant_init +import numpy as np +import time +import cv2 +import math +from simple_knn._C import distCUDA2 +from pytorch3d.transforms import quaternion_to_matrix + +from ..deformers import SMPLXDeformer_gender +from ..renderers import GRenderer, get_covariance, batch_rodrigues +from lib.ops import TruncExp +import torchvision + +from lib.utils.train_util import main_print + +def ensure_dtype(input_tensor, target_dtype=torch.float32): + """ + Ensure tensor dtype matches target dtype. + If not, convert it. 
+ """ + if input_tensor.dtype != target_dtype: + input_tensor = input_tensor.to(dtype=target_dtype) + return input_tensor + + +class UVNDecoder_gender(nn.Module): + + activation_dict = { + 'relu': nn.ReLU, + 'silu': nn.SiLU, + 'softplus': nn.Softplus, + 'trunc_exp': TruncExp, + 'sigmoid': nn.Sigmoid} + + def __init__(self, + *args, + interp_mode='bilinear', + base_layers=[3 * 32, 128], + density_layers=[128, 1], + color_layers=[128, 128, 3], + offset_layers=[128, 3], + scale_layers=[128, 3], + radius_layers=[128, 3], + use_dir_enc=True, + dir_layers=None, + scene_base_size=None, + scene_rand_dims=(0, 1), + activation='silu', + sigma_activation='sigmoid', + sigmoid_saturation=0.001, + code_dropout=0.0, + flip_z=False, + extend_z=False, + gender='neutral', + multires=0, + bg_color=0, + image_size=1024, + superres=False, + focal = 1280, # the default focal defination + reshape_type=None, # if true, it will create a cnn layers to upsample the uv features + fix_sigma=False, # if true, the density of GS will be fixed + up_cnn_in_channels = None, # the channel number of the upsample cnn + vithead_param=None, # the vit head for decode to uv features + is_sub2=False, # if true, will use the sub2 uv map + **kwargs): + super().__init__() + self.interp_mode = interp_mode + self.in_chn = base_layers[0] + self.use_dir_enc = use_dir_enc + if scene_base_size is None: + self.scene_base = None + else: + rand_size = [1 for _ in scene_base_size] + for dim in scene_rand_dims: + rand_size[dim] = scene_base_size[dim] + init_base = torch.randn(rand_size).expand(scene_base_size).clone() + self.scene_base = nn.Parameter(init_base) + self.dir_encoder = None + self.sigmoid_saturation = sigmoid_saturation + self.deformer = SMPLXDeformer_gender(gender, is_sub2=is_sub2) + + self.renderer = GRenderer(image_size=image_size, bg_color=bg_color, f=focal) + if superres: + self.superres = None + else: + self.superres = None + self.gender= gender + self.reshape_type = reshape_type + if reshape_type=='cnn': + self.upsample_conv = torch.nn.ConvTranspose2d(512, 32, kernel_size=4, stride=4,).cuda() + + elif reshape_type == 'VitHead': # changes the up block's layernorm into the feature channel norm instead of the full image norm + from lib.models.decoders.vit_head import VitHead + self.upsample_conv = VitHead(**vithead_param) + # 256, 128, 128 -> 128, 256, 256 -> 64, 512, 512, ->32, 1024, 1024 + + base_cache_dir = 'work_dirs/cache' + if is_sub2: + base_cache_dir = 'work_dirs/cache_sub2' + # main_print("!!!!!!!!!!!!!!!!!!! using the sub2 uv map !!!!!!!!!!!!!!!!!!!") + if gender == 'neutral': + select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_newNeutral.npy')) + self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.) + + init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_newNeutral.npy')) + self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1 + elif gender == 'male': + assert NotImplementedError("Haven't create the init_uv_smplx_thu in v_template") + select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_thu.npy')) + self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.) + + init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_thu.npy')) + self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1 + self.num_init = self.init_pcd.shape[1] + main_print(f"!!!!!!!!!!!!!!!!!!! 
cur points number are {self.num_init} !!!!!!!!!!!!!!!!!!!") + + self.init_pcd = self.init_pcd + + self.multires = multires # 0 Haven't + if multires > 0: + uv_map = torch.as_tensor(np.load(base_cache_dir+'/init_uvmap_smplx_thu.npy')) + pcd_map = torch.as_tensor(np.load(base_cache_dir+'/init_posmap_smplx_thu.npy')) + input_coord = torch.cat([pcd_map, uv_map], dim=1) + self.register_buffer('input_freq', input_coord, persistent=False) + base_layers[0] += 5 + color_layers[0] += 5 + else: + self.init_uv = None + + activation_layer = self.activation_dict[activation.lower()] + + + base_net = [] # linear (in=18, out=64, bias=True) + for i in range(len(base_layers) - 1): + base_net.append(nn.Conv2d(base_layers[i], base_layers[i + 1], 3, padding=1)) + if i != len(base_layers) - 2: + base_net.append(nn.BatchNorm2d(base_layers[i+1])) + base_net.append(activation_layer()) + self.base_net = nn.Sequential(*base_net) + self.base_bn = nn.BatchNorm2d(base_layers[-1]) + self.base_activation = activation_layer() + + density_net = [] # linear(in=64, out=1, bias=True), sigmoid + for i in range(len(density_layers) - 1): + density_net.append(nn.Conv2d(density_layers[i], density_layers[i + 1], 1)) + if i != len(density_layers) - 2: + density_net.append(nn.BatchNorm2d(density_layers[i+1])) + density_net.append(activation_layer()) + density_net.append(self.activation_dict[sigma_activation.lower()]()) + self.density_net = nn.Sequential(*density_net) + + offset_net = [] # linear(in=64, out=1, bias=True), sigmoid + for i in range(len(offset_layers) - 1): + offset_net.append(nn.Conv2d(offset_layers[i], offset_layers[i + 1], 1)) + if i != len(offset_layers) - 2: + offset_net.append(nn.BatchNorm2d(offset_layers[i+1])) + offset_net.append(activation_layer()) + self.offset_net = nn.Sequential(*offset_net) + + self.dir_net = None + color_net = [] # linear(in=64, out=3, bias=True), sigmoid + for i in range(len(color_layers) - 2): + color_net.append(nn.Conv2d(color_layers[i], color_layers[i + 1], kernel_size=3, padding=1)) + color_net.append(nn.BatchNorm2d(color_layers[i+1])) + color_net.append(activation_layer()) + color_net.append(nn.Conv2d(color_layers[-2], color_layers[-1], kernel_size=1)) + color_net.append(nn.Sigmoid()) + self.color_net = nn.Sequential(*color_net) + self.code_dropout = nn.Dropout2d(code_dropout) if code_dropout > 0 else None + + self.flip_z = flip_z + self.extend_z = extend_z + + if self.gender == 'neutral': + init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_newNeutral.npy')) + self.register_buffer('init_rot', init_rot, persistent=False) + + face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu_newNeutral.npy')) + self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False) + + hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu_newNeutral.npy')) + self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False) + + outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu_newNeutral.npy')) + self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False) + else: + assert NotImplementedError("Haven't create the init_rot in v_template") + init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_thu.npy')) + self.register_buffer('init_rot', init_rot, persistent=False) + + face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu.npy')) + self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False) + + hands_mask = 
torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu.npy')) + self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False) + + outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu.npy')) + self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False) + + self.iter = 0 + # self.init_weights() + self.if_rotate_gaussian = False + self.fix_sigma = fix_sigma + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + xavier_init(m, distribution='uniform') + if self.dir_net is not None: + constant_init(self.dir_net[-1], 0) + if self.offset_net is not None: + self.offset_net[-1].weight.data.uniform_(-1e-5, 1e-5) + self.offset_net[-1].bias.data.zero_() + + + def extract_pcd(self, code, smpl_params, init=False, zeros_hands_off=False): + ''' + Args: + B == num_scenes + code (tensor): latent code. shape: [B, C, H, W] + smpl_params (tensor): SMPL parameters. shape: [B_pose, 189] + init (bool): Not used + Returns: + defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3] + sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C] + tfs(tensor): deformation matrics. shape: [B, N, C] + ''' + if isinstance(code, list): + num_scenes, _, h, w = code[0].size() + else: + num_scenes, n_channels, h, w = code.size() + init_pcd = self.init_pcd.repeat(num_scenes, 1, 1) # T-posed space points, for computing the skinning weights + + sigmas, rgbs, radius, rot, offset = self._decode(code, init=init) # the person-specify attributes of GS + if self.fix_sigma: + sigmas = torch.ones_like(sigmas) + if zeros_hands_off: + offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = 0 + canon_pcd = init_pcd + offset + + self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device) + defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian) + return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot + + def deform_pcd(self, code, smpl_params, init=False, zeros_hands_off=False, value=0.1): + ''' + Args: + B == num_scenes + code (List): list of data + smpl_params (tensor): SMPL parameters. shape: [B_pose, 189] + init (bool): Not used + Returns: + defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3] + sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C] + tfs(tensor): deformation matrics. 
shape: [B, N, C] + ''' + sigmas, rgbs, radius, rot, offset = code + num_scenes = sigmas.shape[0] + init_pcd = self.init_pcd.repeat(num_scenes, 1, 1) #T-posed space points, for computing the skinning weights + + if self.fix_sigma: + sigmas = torch.ones_like(sigmas) + if zeros_hands_off: + offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = torch.clamp(offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)], -value, value) + canon_pcd = init_pcd + offset + self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device) + defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian) + return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot + + + + def _sample_feature(self,results,): + # outputs, sigma_uv, offset_uv, rgbs_uv, radius_uv, rot_uv = results['output'], results['sigma'], results['offset'], results['rgbs'], results['radius'], results['rot'] + sigma = results['sigma'] + outputs = results['output'] + if isinstance(sigma, list): + num_scenes, _, h, w = sigma[0].shape + select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) + elif sigma.dim() == 4: + num_scenes, n_channels, h, w = sigma.shape + select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) + else: + assert False + output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1) + sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2) + + if self.sigmoid_saturation > 0: + rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation + + radius = (radius - 0.5) * 2 + rot = (rot - 0.5) * np.pi + + return sigma, rgbs, radius, rot, offset + + def _decode_feature(self, point_code, low_ram=False): + if isinstance(point_code, list): + num_scenes, _, h, w = point_code[0].shape + geo_code, tex_code = point_code + # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) + if self.multires != 0: + input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) + elif point_code.dim() == 4: + num_scenes, n_channels, h, w = point_code.shape + # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) + if self.multires != 0: + input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) + geo_code, tex_code = point_code.split(16, dim=1) + else: + assert False + + base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1) + base_x = self.base_net(base_in) + base_bn = self.base_bn(base_x) + + if low_ram: + del base_x + torch.cuda.empty_cache() + + base_x_act = self.base_activation(base_bn) + + if low_ram: + del base_bn + torch.cuda.empty_cache() + + sigma = self.density_net(base_x_act) + offset = self.offset_net(base_x_act) + + if low_ram: + del base_x_act + torch.cuda.empty_cache() + + color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1) + rgbs_radius_rot = self.color_net(color_in) + + outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1) + + if low_ram: + del color_in + del rgbs_radius_rot + torch.cuda.empty_cache() + + main_print(outputs.shape) + sigma, offset, rgbs, radius, rot = outputs.split([1, 3, 3, 3, 3], dim=1) + results = {'output':outputs, 'sigma': sigma, 'offset': offset, 'rgbs': rgbs, 'radius': radius, 'rot': rot} + + return results + def _decode(self, point_code, init=False): + if isinstance(point_code, list): + num_scenes, _, h, w = 
point_code[0].shape + geo_code, tex_code = point_code + select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) + if self.multires != 0: + input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) + elif point_code.dim() == 4: + num_scenes, n_channels, h, w = point_code.shape + select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) + if self.multires != 0: + input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) + geo_code, tex_code = point_code.split(16, dim=1) + else: + assert False + + base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1) + base_x = self.base_net(base_in) + base_x_act = self.base_activation(self.base_bn(base_x)) + + sigma = self.density_net(base_x_act) + offset = self.offset_net(base_x_act) + color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1) + rgbs_radius_rot = self.color_net(color_in) + + outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1) + output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1) + sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2) + + if self.sigmoid_saturation > 0: + rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation + + radius = (radius - 0.5) * 2 + rot = (rot - 0.5) * np.pi + + return sigma, rgbs, radius, rot, offset + + def gaussian_render(self, pcd, sigmas, rgbs, normals, rot, num_scenes, num_imgs, cameras, use_scale=False, radius=None, \ + return_norm=False, return_viz=False, mask=None): + # add mask or visible points to images or select ind to images + ''' + render the gaussian to images + return_norm: return the normals of the gaussian (haven't been used) + return_viz: return the mask of the gaussian + mask: the mask of the gaussian + ''' + assert num_scenes == 1 + + pcd = pcd.reshape(-1, 3) + if use_scale: + dist2 = distCUDA2(pcd) + dist2 = torch.clamp_min((dist2), 0.0000001) + scales = torch.sqrt(dist2)[...,None].repeat(1, 3).detach() # distence between different points + scale = (radius+1)*scales # scaling_modifier # radius[-1--1], scale of GS + cov3D = get_covariance(scale, rot).reshape(-1, 6) # inputs rot is the rotations + + images_all = [] + viz_masks = [] if return_viz else None + norm_all = [] if return_norm else None + + if mask != None: + pcd = pcd[mask] + rgbs = rgbs[mask] + sigmas = sigmas[mask] + cov3D = cov3D[mask] + normals = normals[mask] + if 1: + for i in range(num_imgs): + self.renderer.prepare(cameras[i]) + + image = self.renderer.render_gaussian(means3D=pcd, colors_precomp=rgbs, + rotations=None, opacities=sigmas, scales=None, cov3D_precomp=cov3D) + images_all.append(image) + if return_viz: + viz_mask = self.renderer.render_gaussian(means3D=pcd, colors_precomp=pcd.clone(), + rotations=None, opacities=sigmas*0+1, scales=None, cov3D_precomp=cov3D) + viz_masks.append(viz_mask) + + + images_all = torch.stack(images_all, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2) + if return_viz: + viz_masks = torch.stack(viz_masks, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2).reshape(1, -1, 3) + dist_sq, idx, neighbors = ops.knn_points(pcd.unsqueeze(0), viz_masks[:, ::10], K=1, return_nn=True) + viz_masks = (dist_sq < 0.0001)[0] + # ===== END the original code for batch size = 1 ===== + if use_scale: + return images_all, norm_all, viz_masks, scale + else: + return images_all, norm_all, viz_masks, None + + def visualize(self, code, scene_name, viz_dir, code_range=[-1, 1]): + 
num_scenes, num_chn, h, w = code.size() + code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy() + if not self.flip_z: + code_viz = code_viz[..., ::-1, :] + code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w) + for code_single, code_viz_single, scene_name_single in zip(code, code_viz, scene_name): + plt.imsave(os.path.join(viz_dir, 'a_scene_' + scene_name_single + '.png'), code_viz_single, + vmin=code_range[0], vmax=code_range[1]) + + def forward(self, code, smpl_params, cameras, num_imgs, + return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False): + """ + Args: + + + density_bitfield: Shape (num_scenes, griz_size**3 // 8) + YY: + grid_size, dt_gamma, perturb, T_thresh are deleted + code: Shape (num_scenes, *code_size) + cameras: Shape (num_scenes, num_imgs, 19(3+16)) + smpl_params: Shape (num_scenes, 189) + + """ + # import ipdb; ipdb.set_trace() + if isinstance(code, list): + num_scenes = len(code[0]) + else: + num_scenes = len(code) + assert num_scenes > 0 + self.iter+=1 + + image = [] + scales = [] + norm = [] if return_norm else None + viz_masks = [] if not self.training else None + + xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off) + + if zeros_hands_off: + main_print('zeros_hands_off is on!') + main_print('zeros_hands_off is on!') + offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 + rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 + rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 + R_delta = batch_rodrigues(rot.reshape(-1, 3)) + R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta) + R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R) + normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3) + R_def_batch = R_def.reshape(num_scenes, -1, 3, 3) + + return_to_bfloat16 = True if xyzs.dtype==torch.bfloat16 else False ####### ============ translate the output to BF16 ================= + # return_to_bfloat16 = False # I don't want to trans it back to bf16 + if return_to_bfloat16: + main_print("changes the return_to_bfloat16") + cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)] + # with torch.amp.autocast(enabled=False, device_type='cuda'): + if 1: + for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius): + image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \ + radius=radius_single, return_norm=return_norm, return_viz=not self.training) + image.append(image_single) + scales.append(scale.unsqueeze(0)) + if return_norm: + norm.append(norm_single) + if not self.training: + viz_masks.append(viz_mask) + image = torch.cat(image, dim=0) + scales = torch.cat(scales, dim=0) + + norm = torch.cat(norm, dim=0) if return_norm else None + viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None + + + main_print("not trans the rendered results to float16") + if False: + image = image.to(torch.bfloat16) + scales = scales.to(torch.bfloat16) + if return_norm: + norm = norm.to(torch.bfloat16) + if viz_masks is not None: + viz_masks = viz_masks.to(torch.bfloat16) + offsets = offsets.to(torch.bfloat16) + + if self.training: + 
offset_dist = offsets ** 2 + weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)]) + else: + weighted_offset = offsets + + + results = dict( + viz_masks=viz_masks, + scales=scales, + norm=norm, + image=image, + offset=weighted_offset) + + if return_loss: + results.update(decoder_reg_loss=self.loss()) + + return results + + def forward_render(self, code, cameras, num_imgs, + return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False): + """ + Args: + + + density_bitfield: Shape (num_scenes, griz_size**3 // 8) + YY: + grid_size, dt_gamma, perturb, T_thresh are deleted + code: Shape (num_scenes, *code_size) + cameras: Shape (num_scenes, num_imgs, 19(3+16)) + smpl_params: Shape (num_scenes, 189) + + """ + image = [] + scales = [] + norm = [] if return_norm else None + viz_masks = [] if not self.training else None + + xyzs, sigmas, rgbs, offsets, radius, tfs, rot = code + num_scenes = xyzs.shape[0] + if zeros_hands_off: + main_print('zeros_hands_off is on!') + main_print('zeros_hands_off is on!') + offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 + rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 + rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 + R_delta = batch_rodrigues(rot.reshape(-1, 3)) + R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta) + R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R) + normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3) + R_def_batch = R_def.reshape(num_scenes, -1, 3, 3) + # import ipdb; ipdb.set_trace() + + return_to_bfloat16 = True if xyzs.dtype==torch.bfloat16 else False ####### ============ translate the output to BF16 ================= + # return_to_bfloat16 = False # I don't want to trans it back to bf16 + if return_to_bfloat16: + main_print("changes the return_to_bfloat16") + cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)] + + if 1: + for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius): + image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \ + radius=radius_single, return_norm=return_norm, return_viz=not self.training) + image.append(image_single) + scales.append(scale.unsqueeze(0)) + if return_norm: + norm.append(norm_single) + if not self.training: + viz_masks.append(viz_mask) + image = torch.cat(image, dim=0) + scales = torch.cat(scales, dim=0) + + norm = torch.cat(norm, dim=0) if return_norm else None + viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None + + + main_print("not trans the rendered results to float16") + if False: + image = image.to(torch.bfloat16) + scales = scales.to(torch.bfloat16) + if return_norm: + norm = norm.to(torch.bfloat16) + if viz_masks is not None: + viz_masks = viz_masks.to(torch.bfloat16) + offsets = offsets.to(torch.bfloat16) + + if self.training: + offset_dist = offsets ** 2 + weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)]) + else: + weighted_offset = offsets + + + results = dict( + viz_masks=viz_masks, + 
scales=scales, + norm=norm, + image=image, + offset=weighted_offset) + if return_loss: + results.update(decoder_reg_loss=self.loss()) + + return results + + + def forward_testing_time(self, code, smpl_params, cameras, num_imgs, + return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False): + """ + Args: + + + density_bitfield: Shape (num_scenes, griz_size**3 // 8) + YY: + grid_size, dt_gamma, perturb, T_thresh are deleted + code: Shape (num_scenes, *code_size) + cameras: Shape (num_scenes, num_imgs, 19(3+16)) + smpl_params: Shape (num_scenes, 189) + + """ + if isinstance(code, list): + num_scenes = len(code[0]) + else: + num_scenes = len(code) + assert num_scenes > 0 + self.iter+=1 + + image = [] + scales = [] + norm = [] if return_norm else None + viz_masks = [] if not self.training else None + start_time = time.time() + xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off) + end_time_to_3D = time.time() + time_code_to_3d = end_time_to_3D- start_time + + R_delta = batch_rodrigues(rot.reshape(-1, 3)) + R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta) + R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R) + normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3) + R_def_batch = R_def.reshape(num_scenes, -1, 3, 3) + if 1: + for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius): + image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \ + radius=radius_single, return_norm=False, return_viz=not self.training) + image.append(image_single) + scales.append(scale.unsqueeze(0)) + if return_norm: + norm.append(norm_single) + if not self.training: + viz_masks.append(viz_mask) + image = torch.cat(image, dim=0) + scales = torch.cat(scales, dim=0) + + norm = torch.cat(norm, dim=0) if return_norm else None + viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None + + time_3D_to_img = time.time() - end_time_to_3D + + + if False: + image = image.to(torch.bfloat16) + scales = scales.to(torch.bfloat16) + if return_norm: + norm = norm.to(torch.bfloat16) + if viz_masks is not None: + viz_masks = viz_masks.to(torch.bfloat16) + offsets = offsets.to(torch.bfloat16) + + results = dict( + image=image) + return results, time_code_to_3d, time_3D_to_img \ No newline at end of file diff --git a/lib/utils/infer_util.py b/lib/utils/infer_util.py index 051ec54..5582664 100644 --- a/lib/utils/infer_util.py +++ b/lib/utils/infer_util.py @@ -1,471 +1,498 @@ -import os -import imageio -import rembg -import torch -import numpy as np -import PIL.Image -from PIL import Image -from typing import Any -import json - -from pathlib import Path -from torchvision.transforms import ToTensor -from rembg import remove # For background removal -from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle -from lib.models.deformers.smplx.lbs import batch_rodrigues -import cv2 -from PIL import Image -import numpy as np - -import json -# import random -import math -# import av - - -def reset_first_frame_rotation(root_orient, trans): - """ - Set the root_orient rotation matrix of the first frame to the identity matrix (no rotation), - keep the relative rotation relationships of other frames, and adjust trans accordingly. 
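The helper documented here, `reset_first_frame_rotation`, re-bases a SMPL-X sequence so that frame 0 carries the identity root rotation and zero translation while every later frame keeps its motion relative to frame 0. A minimal usage sketch (assuming only `pytorch3d` and this function as defined in `lib/utils/infer_util.py`; the toy tensors are illustrative):

```python
import torch
from lib.utils.infer_util import reset_first_frame_rotation  # defined in this diff

# A toy 4-frame sequence of axis-angle root orientations and translations.
root_orient = torch.randn(4, 3) * 0.3
trans = torch.randn(4, 3)

new_root_orient, new_trans = reset_first_frame_rotation(root_orient, trans)

# Frame 0 is now (numerically) un-rotated and sits at the origin;
# frames 1..3 keep their rotation/translation relative to frame 0.
assert new_root_orient[0].abs().max() < 1e-4
assert torch.allclose(new_trans[0], torch.zeros(3))
```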
- - Parameters: - root_orient: Tensor of shape (N, 3), representing the axis-angle parameters for N frames. - trans: Tensor of shape (N, 3), representing the translation parameters for N frames. - - Returns: - new_root_orient: Tensor of shape (N, 3), adjusted axis-angle parameters. - new_trans: Tensor of shape (N, 3), adjusted translation parameters. - """ - # Convert the root_orient of the first frame to a rotation matrix - R_0 = axis_angle_to_matrix(root_orient[0:1]) # Shape: (1, 3, 3) - - # Compute the inverse of the first frame's rotation matrix - R_0_inv = torch.inverse(R_0) # Shape: (1, 3, 3) - - # Initialize lists for new root_orient and trans - new_root_orient = [] - new_trans = [] - - for i in range(root_orient.shape[0]): - # Rotation matrix of the current frame - R_i = axis_angle_to_matrix(root_orient[i:i+1]) # Shape: (1, 3, 3) - R_new = torch.matmul(R_0_inv, R_i) # Shape: (1, 3, 3) - - # Convert the rotation matrix back to axis-angle representation - axis_angle_new = matrix_to_axis_angle(R_new) # Shape: (1, 3) - new_root_orient.append(axis_angle_new) - - # Adjust the translation for the current frame - trans_i = trans[i:i+1] # Shape: (1, 3) - trans_new = torch.matmul(R_0_inv, trans_i.T).T # Shape: (1, 3) - new_trans.append(trans_new) - - # Stack the results of new_root_orient and new_trans - new_root_orient = torch.cat(new_root_orient, dim=0) # Shape: (N, 3) - new_trans = torch.cat(new_trans, dim=0) # Shape: (N, 3) - - # Adjust the new translations relative to the first frame - new_trans = new_trans - new_trans[[0], :] - - return new_root_orient, new_trans - -from scipy.spatial.transform import Rotation -def rotation_matrix_to_rodrigues(rotation_matrices): - # reshape rotation_matrices to (-1, 3, 3) - reshaped_matrices = rotation_matrices.reshape(-1, 3, 3) - rotation = Rotation.from_matrix(reshaped_matrices) - rodrigues_vectors = rotation.as_rotvec() - return rodrigues_vectors - - - -def get_hand_pose_mean(): - import numpy as np - hand_pose_mean= np.array([[ 0.11167871, 0.04289218, -0.41644183, 0.10881133, -0.06598568, - -0.75622 , -0.09639297, -0.09091566, -0.18845929, -0.11809504, - 0.05094385, -0.5295845 , -0.14369841, 0.0552417 , -0.7048571 , - -0.01918292, -0.09233685, -0.3379135 , -0.45703298, -0.19628395, - -0.6254575 , -0.21465237, -0.06599829, -0.50689423, -0.36972436, - -0.06034463, -0.07949023, -0.1418697 , -0.08585263, -0.63552827, - -0.3033416 , -0.05788098, -0.6313892 , -0.17612089, -0.13209307, - -0.37335458, 0.8509643 , 0.27692273, -0.09154807, -0.49983943, - 0.02655647, 0.05288088, 0.5355592 , 0.04596104, -0.27735803, - 0.11167871, -0.04289218, 0.41644183, 0.10881133, 0.06598568, - 0.75622 , -0.09639297, 0.09091566, 0.18845929, -0.11809504, - -0.05094385, 0.5295845 , -0.14369841, -0.0552417 , 0.7048571 , - -0.01918292, 0.09233685, 0.3379135 , -0.45703298, 0.19628395, - 0.6254575 , -0.21465237, 0.06599829, 0.50689423, -0.36972436, - 0.06034463, 0.07949023, -0.1418697 , 0.08585263, 0.63552827, - -0.3033416 , 0.05788098, 0.6313892 , -0.17612089, 0.13209307, - 0.37335458, 0.8509643 , -0.27692273, 0.09154807, -0.49983943, - -0.02655647, -0.05288088, 0.5355592 , -0.04596104, 0.27735803]]) - return hand_pose_mean - - -def load_smplify_json(smplx_smplify_path): - with open(smplx_smplify_path) as f: - data = json.load(f) - - # Prepare camera transformation matrix (R | t) - RT = torch.concatenate([torch.Tensor(data['camera']['R']), torch.Tensor(data['camera']['t']).reshape(3, 1) * 2], dim=1) - RT = torch.cat([RT, torch.Tensor([[0, 0, 0, 1]])], dim=0) - - # Create 
intrinsic parameters tensor - intri = torch.Tensor(data['camera']['focal'] + data['camera']['princpt']) - # intri[[3, 2]] = intri[[2, 3]] - - # # Set default focal length and image resolution - # default_focal = 1120 # Default focal length - # img_res = [640, 896] - # default_fxy_cxy = torch.tensor([default_focal, default_focal, img_res[1] // 2, img_res[0] // 2]).reshape(1, 4) - - # # Adjust intrinsic parameters based on default focal and resolution - # intri = intri * default_fxy_cxy[0, -2] / intri[-2] - # intri[-2:] = default_fxy_cxy[0, -2:] # Force consistent image width and height - - # Extract SMPL parameters from data - smpl_param_data = data - global_orient = np.array(smpl_param_data['root_pose']).reshape(1, -1) - body_pose = np.array(smpl_param_data['body_pose']).reshape(1, -1) - shape = np.array(smpl_param_data['betas_save']).reshape(1, -1)[:, :10] - left_hand_pose = np.array(smpl_param_data['lhand_pose']).reshape(1, -1) - right_hand_pose = np.array(smpl_param_data['rhand_pose']).reshape(1, -1) - - # Concatenate all parameters into a single tensor for SMPL model - smpl_param_ref = np.concatenate([np.array([[1.]]), np.array(smpl_param_data['trans']).reshape(1, 3), - global_orient, body_pose, shape, left_hand_pose, right_hand_pose, - np.array(smpl_param_data['jaw_pose']).reshape(1, -1), - np.zeros_like(np.array(smpl_param_data['leye_pose']).reshape(1, -1)), - np.zeros_like(np.array(smpl_param_data['reye_pose']).reshape(1, -1)), - np.zeros_like(np.array(smpl_param_data['expr']).reshape(1, -1)[:, :10])], axis=1) - - return RT, intri, torch.Tensor(smpl_param_ref).reshape(-1) # Return transformation, intrinsic, and SMPL parameters - -def load_image(input_path, output_folder, image_frame_ratio=None): - input_img_path = Path(input_path) - - vids = [] - save_path = os.path.join(output_folder, f"{input_img_path.name}") - print(f"Processing: {save_path}") - image = Image.open(input_img_path) - - if image.mode == "RGBA": - pass - else: - # remove bg - image = remove(image.convert("RGBA"), alpha_matting=True) - - # resize object in frame - image_arr = np.array(image) - in_w, in_h = image_arr.shape[:2] - ret, mask = cv2.threshold( - np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY - ) - x, y, w, h = cv2.boundingRect(mask) - max_size = max(w, h) - side_len = ( - int(max_size / image_frame_ratio) - if image_frame_ratio is not None - else int(max_size / 0.85) - ) - padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) - center = side_len // 2 - padded_image[ - center - h // 2 : center - h // 2 + h, - center - w // 2 : center - w // 2 + w, - ] = image_arr[y : y + h, x : x + w] - rgba = Image.fromarray(padded_image).resize((896, 896), Image.LANCZOS) - # crop the width into 640 in the center - rgba = rgba.crop([128, 0, 640+128, 896]) - # white bg - rgba_arr = np.array(rgba) / 255.0 - rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) - input_image = Image.fromarray((rgb * 255).astype(np.uint8)) - - image = ToTensor()(input_image) - - return image - - - -def prepare_camera( resolution_x = 640, resolution_y = 640, focal_length = 600,sensor_width = 32, camera_dist = 20, num_views=1, stides=1): - - def look_at(camera_position, target_position, up_vector): # colmap +z forward, +y down - forward = -(camera_position - target_position) / np.linalg.norm(camera_position - target_position) - right = np.cross(up_vector, forward) - up = np.cross(forward, right) - return np.column_stack((right, up, forward)) - - # set the intrisics - focal_length = focal_length * 
(resolution_y/sensor_width) - - K = np.array( - [[focal_length, 0, resolution_x//2], - [0, focal_length, resolution_y//2], - [0, 0, 1]] - ) - - # set the extrisics - camera_pose_list = [] - for frame_idx in range(0, num_views, stides): - - phi = math.radians(90) - theta = (3 / 4) * math.pi * 2 - camera_location = np.array( - [camera_dist * math.sin(phi) * math.cos(theta), - - camera_dist * math.cos(phi), - -camera_dist * math.sin(phi) * math.sin(theta),] - ) - # print(camera_location) - camera_pose = np.eye(4) - camera_pose[:3, 3] = camera_location - - # Set camera position and target position - camera_position = camera_location - target_position = np.array([0.0, 0.0, 0.0]) - - # Compute the camera's rotation matrix to look at the target - up_vector = np.array([0.0, -1.0, 0.0]) # colmap - rotation_matrix = look_at(camera_position, target_position, up_vector) - - # Update camera position and rotation - camera_pose[:3, :3] = rotation_matrix - camera_pose[:3, 3] = camera_position - camera_pose_list.append(camera_pose) - return K, camera_pose_list - - -def construct_camera(K, cam_list, device='cuda'): - num_imgs = len(cam_list) - front_idx = num_imgs//4*3 - cam_list = cam_list[front_idx:] + cam_list[:front_idx] - cam_raw = np.array(cam_list) - cam_raw[:, :3, 3] = cam_raw[:, :3, 3] - cam = np.linalg.inv(cam_raw) - cam = torch.Tensor(cam) - intrics = torch.Tensor([K[0,0],K[1,1], K[0,2], K[1,2]]).reshape(-1) - scale = 0.5 - # diffrent from the synthetic data, the scale is process first - # trans from (3,) to (batch_size, 3,1) - trans = [0, 0.2, 0] #in the center - trans_bt = torch.Tensor(trans).reshape(1, 3, 1).expand(cam.shape[0], 3, 1) - cam[:,:3,3] = cam[:,:3,3] + torch.bmm(cam[:,:3,:3], trans_bt).reshape(-1, 3) # T = Rt+T torch.Size([24, 3, 1]) - cam[:,:3,:3] = cam[:,:3,:3] * scale # R = sR - cam_c2w = torch.inverse(cam) - cam_w2c = cam - poses = [] - for i_cam in range(cam.shape[0]): - poses.append( torch.concat([ - (intrics.reshape(-1)).to(torch.float32), #C ! # C ? T 理论上要给C - (cam_w2c[i_cam]).to(torch.float32).reshape(-1), # RT #Rt|C ? RT 理论上要给RT - ], dim=0)) - cameras = torch.stack(poses).to(device) # [N, 19] - return cameras - -def get_name_str(name): - path_ = os.path.basename(os.path.dirname(name)) + os.path.basename(name) - return path_ - - - -def load_smplx_from_npy(smplx_path, device='cuda'): - hand_mean = get_hand_pose_mean().reshape(-1) - smplx_pose_param = np.load(smplx_path, allow_pickle=True) - # if "person1" in smplx_pose_param: - # smplx_pose_param = smplx_pose_param['person1'] - smplx_pose_param = { - 'root_orient': smplx_pose_param[:, :3], # controls the global root orientation - 'pose_body': smplx_pose_param[:, 3:3+63], # controls the body - 'pose_hand': smplx_pose_param[:, 66:66+90], # controls the finger articulation - 'pose_jaw': smplx_pose_param[:, 66+90:66+93], # controls the yaw pose - 'face_expr': smplx_pose_param[:, 159:159+50], # controls the face expression - 'face_shape': smplx_pose_param[:, 209:209+100], # controls the face shape - 'trans': smplx_pose_param[:, 309:309+3], # controls the global body position - 'betas': smplx_pose_param[:, 312:], # controls the body shape. 
Body shape is static - } - - smplx_param_list = [] - for i in range(1, 1800, 1): - # for i in k.keys(): - # k[i] = np.array(k[i]) - left_hands = np.array([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, - -0.6652, -0.7290, 0.0084, -0.4818]) - betas = torch.zeros((10)) - smplx_param = \ - np.concatenate([np.array([1]), smplx_pose_param['trans'][i], smplx_pose_param['root_orient'][i], \ - smplx_pose_param['pose_body'][i],betas, \ - smplx_pose_param['pose_hand'][i]-hand_mean, smplx_pose_param['pose_jaw'][i], np.zeros(6), smplx_pose_param['face_expr'][i][:10]], axis=0).reshape(1,-1) - smplx_param_list.append(smplx_param) - smplx_params = np.concatenate(smplx_param_list, 0) - smpl_params = torch.Tensor(smplx_params).to(device) - return smpl_params -def add_root_rotate_to_smplx(smpl_tmp, frames_num=180, device='cuda'): - from cv2 import Rodrigues - initial_matrix = batch_rodrigues(smpl_tmp.reshape(1,189)[:, 4:7]).cpu().numpy().copy() - # Rotate a rotation matrix by 360 degrees around the y-axis. - # frames_num = 180 - all_smpl = [] - # Combine the rotations - all_smpl = [] - for idx_f in range(frames_num): - new_smpl = smpl_tmp.clone() - angle = 360//frames_num * idx_f - y_angle = np.radians(angle) - y_rotation_matrix = np.array([ - [ np.cos(y_angle),0, np.sin(y_angle)], - [0, 1, 0], - [-np.sin(y_angle), 0, np.cos(y_angle)], - ]) - final_matrix = y_rotation_matrix[None] @ initial_matrix - - new_smpl[4:7] = torch.Tensor(rotation_matrix_to_rodrigues(torch.Tensor(final_matrix))).to(device) - all_smpl.append(new_smpl) - all_smpl = torch.stack(all_smpl, 0) - smpl_params = all_smpl.to(device) - return smpl_params - -def load_smplx_from_json(smplx_path, device='cuda'): - # format of motion-x - hand_mean = get_hand_pose_mean().reshape(-1) - with open(smplx_path, 'r') as f: - smplx_pose_param = json.load(f) - smplx_param_list = [] - for par in smplx_pose_param['annotations']: - k = par['smplx_params'] - for i in k.keys(): - k[i] = np.array(k[i]) - - betas = torch.zeros((10)) - # ######### wrist pose fix ################ - smplx_param = \ - np.concatenate([np.array([1]), k['trans'], - k['root_orient']*np.array([1, 1, 1]), \ - k['pose_body'],betas, \ - k['pose_hand']-hand_mean, k['pose_jaw'], np.zeros(6), np.zeros_like(k['face_expr'][:10])], axis=0).reshape(1,-1) - smplx_param_list.append(smplx_param) - - - smplx_params = np.concatenate(smplx_param_list, 0) - print(smplx_params.shape) - smpl_params = torch.Tensor(smplx_params).to(device) - return smpl_params - -def get_image_dimensions(input_path): - with Image.open(input_path) as img: - return img.height, img.width - -def construct_camera_from_motionx(smplx_path, device='cuda'): - with open(smplx_path, 'r') as f: - smplx_pose_param = json.load(f) - cam_exts = [] - cam_ints = [] - for par in smplx_pose_param['annotations']: - cam = par['cam_params'] - R = np.array(cam['cam_R']) - K = np.array(cam['intrins']) - T = np.array(cam['cam_T']) - cam['cam_T'][1] = -cam['cam_T'][1] - cam['cam_T'][2] = -cam['cam_T'][2] - extrix = np.eye(4) - extrix[:3, :3] = R - extrix[:3,3] = T - cam_exts.append(extrix) - intrix = K - cam_ints.append(intrix) - - # target N,20 - cam_exts_array = np.array(cam_exts) - - cam_exts_stack = torch.Tensor(cam_exts_array).to(device).reshape(-1, 16) - cam_ints_stack = torch.Tensor(cam_ints).to(device).reshape(-1, 4) - cameras = torch.cat([cam_ints_stack, cam_exts_stack], dim=-1).reshape(-1,1, 20) - return cameras - -def remove_background(image: PIL.Image.Image, - rembg_session: Any = None, - force: bool = False, - 
**rembg_kwargs, -) -> PIL.Image.Image: - do_remove = True - if image.mode == "RGBA" and image.getextrema()[3][0] < 255: - do_remove = False - do_remove = do_remove or force - if do_remove: - image = rembg.remove(image, session=rembg_session, **rembg_kwargs) - return image - - -def resize_foreground( - image: PIL.Image.Image, - ratio: float, -) -> PIL.Image.Image: - image = np.array(image) - assert image.shape[-1] == 4 - alpha = np.where(image[..., 3] > 0) - y1, y2, x1, x2 = ( - alpha[0].min(), - alpha[0].max(), - alpha[1].min(), - alpha[1].max(), - ) - # crop the foreground - fg = image[y1:y2, x1:x2] - # pad to square - size = max(fg.shape[0], fg.shape[1]) - ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 - ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 - new_image = np.pad( - fg, - ((ph0, ph1), (pw0, pw1), (0, 0)), - mode="constant", - constant_values=((0, 0), (0, 0), (0, 0)), - ) - - # compute padding according to the ratio - new_size = int(new_image.shape[0] / ratio) - # pad to size, double side - ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 - ph1, pw1 = new_size - size - ph0, new_size - size - pw0 - new_image = np.pad( - new_image, - ((ph0, ph1), (pw0, pw1), (0, 0)), - mode="constant", - constant_values=((0, 0), (0, 0), (0, 0)), - ) - new_image = PIL.Image.fromarray(new_image) - return new_image - - -def images_to_video( - images: torch.Tensor, - output_path: str, - fps: int = 30, -) -> None: - # images: (N, C, H, W) - video_dir = os.path.dirname(output_path) - video_name = os.path.basename(output_path) - os.makedirs(video_dir, exist_ok=True) - - frames = [] - for i in range(len(images)): - frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) - assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ - f"Frame shape mismatch: {frame.shape} vs {images.shape}" - assert frame.min() >= 0 and frame.max() <= 255, \ - f"Frame value out of range: {frame.min()} ~ {frame.max()}" - frames.append(frame) - imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) - - -def save_video( - frames: torch.Tensor, - output_path: str, - fps: int = 30, -) -> None: - # images: (N, C, H, W) - frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames] - writer = imageio.get_writer(output_path, fps=fps) - for frame in frames: - writer.append_data(frame) - writer.close() \ No newline at end of file +import os +import imageio +import rembg +import torch +import numpy as np +import PIL.Image +from PIL import Image +from typing import Any +import json + +from pathlib import Path +from torchvision.transforms import ToTensor +from rembg import remove # For background removal +from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle +from lib.models.deformers.smplx.lbs import batch_rodrigues +import cv2 +from PIL import Image +import numpy as np + +import json +# import random +import math +# import av + + +def reset_first_frame_rotation(root_orient, trans): + """ + Set the root_orient rotation matrix of the first frame to the identity matrix (no rotation), + keep the relative rotation relationships of other frames, and adjust trans accordingly. + + Parameters: + root_orient: Tensor of shape (N, 3), representing the axis-angle parameters for N frames. + trans: Tensor of shape (N, 3), representing the translation parameters for N frames. + + Returns: + new_root_orient: Tensor of shape (N, 3), adjusted axis-angle parameters. 
+ new_trans: Tensor of shape (N, 3), adjusted translation parameters. + """ + # Convert the root_orient of the first frame to a rotation matrix + R_0 = axis_angle_to_matrix(root_orient[0:1]) # Shape: (1, 3, 3) + + # Compute the inverse of the first frame's rotation matrix + R_0_inv = torch.inverse(R_0) # Shape: (1, 3, 3) + + # Initialize lists for new root_orient and trans + new_root_orient = [] + new_trans = [] + + for i in range(root_orient.shape[0]): + # Rotation matrix of the current frame + R_i = axis_angle_to_matrix(root_orient[i:i+1]) # Shape: (1, 3, 3) + R_new = torch.matmul(R_0_inv, R_i) # Shape: (1, 3, 3) + + # Convert the rotation matrix back to axis-angle representation + axis_angle_new = matrix_to_axis_angle(R_new) # Shape: (1, 3) + new_root_orient.append(axis_angle_new) + + # Adjust the translation for the current frame + trans_i = trans[i:i+1] # Shape: (1, 3) + trans_new = torch.matmul(R_0_inv, trans_i.T).T # Shape: (1, 3) + new_trans.append(trans_new) + + # Stack the results of new_root_orient and new_trans + new_root_orient = torch.cat(new_root_orient, dim=0) # Shape: (N, 3) + new_trans = torch.cat(new_trans, dim=0) # Shape: (N, 3) + + # Adjust the new translations relative to the first frame + new_trans = new_trans - new_trans[[0], :] + + return new_root_orient, new_trans + +from scipy.spatial.transform import Rotation +def rotation_matrix_to_rodrigues(rotation_matrices): + # reshape rotation_matrices to (-1, 3, 3) + reshaped_matrices = rotation_matrices.reshape(-1, 3, 3) + rotation = Rotation.from_matrix(reshaped_matrices) + rodrigues_vectors = rotation.as_rotvec() + return rodrigues_vectors + + + +def get_hand_pose_mean(): + import numpy as np + hand_pose_mean= np.array([[ 0.11167871, 0.04289218, -0.41644183, 0.10881133, -0.06598568, + -0.75622 , -0.09639297, -0.09091566, -0.18845929, -0.11809504, + 0.05094385, -0.5295845 , -0.14369841, 0.0552417 , -0.7048571 , + -0.01918292, -0.09233685, -0.3379135 , -0.45703298, -0.19628395, + -0.6254575 , -0.21465237, -0.06599829, -0.50689423, -0.36972436, + -0.06034463, -0.07949023, -0.1418697 , -0.08585263, -0.63552827, + -0.3033416 , -0.05788098, -0.6313892 , -0.17612089, -0.13209307, + -0.37335458, 0.8509643 , 0.27692273, -0.09154807, -0.49983943, + 0.02655647, 0.05288088, 0.5355592 , 0.04596104, -0.27735803, + 0.11167871, -0.04289218, 0.41644183, 0.10881133, 0.06598568, + 0.75622 , -0.09639297, 0.09091566, 0.18845929, -0.11809504, + -0.05094385, 0.5295845 , -0.14369841, -0.0552417 , 0.7048571 , + -0.01918292, 0.09233685, 0.3379135 , -0.45703298, 0.19628395, + 0.6254575 , -0.21465237, 0.06599829, 0.50689423, -0.36972436, + 0.06034463, 0.07949023, -0.1418697 , 0.08585263, 0.63552827, + -0.3033416 , 0.05788098, 0.6313892 , -0.17612089, 0.13209307, + 0.37335458, 0.8509643 , -0.27692273, 0.09154807, -0.49983943, + -0.02655647, -0.05288088, 0.5355592 , -0.04596104, 0.27735803]]) + return hand_pose_mean + + +def load_smplify_json(smplx_smplify_path): + with open(smplx_smplify_path) as f: + data = json.load(f) + + # Prepare camera transformation matrix (R | t) + RT = torch.concatenate([torch.Tensor(data['camera']['R']), torch.Tensor(data['camera']['t']).reshape(3, 1) * 2], dim=1) + RT = torch.cat([RT, torch.Tensor([[0, 0, 0, 1]])], dim=0) + + # Create intrinsic parameters tensor + intri = torch.Tensor(data['camera']['focal'] + data['camera']['princpt']) + # intri[[3, 2]] = intri[[2, 3]] + + # # Set default focal length and image resolution + # default_focal = 1120 # Default focal length + # img_res = [640, 896] + # default_fxy_cxy = 
torch.tensor([default_focal, default_focal, img_res[1] // 2, img_res[0] // 2]).reshape(1, 4) + + # # Adjust intrinsic parameters based on default focal and resolution + # intri = intri * default_fxy_cxy[0, -2] / intri[-2] + # intri[-2:] = default_fxy_cxy[0, -2:] # Force consistent image width and height + + # Extract SMPL parameters from data + smpl_param_data = data + global_orient = np.array(smpl_param_data['root_pose']).reshape(1, -1) + body_pose = np.array(smpl_param_data['body_pose']).reshape(1, -1) + shape = np.array(smpl_param_data['betas_save']).reshape(1, -1)[:, :10] + left_hand_pose = np.array(smpl_param_data['lhand_pose']).reshape(1, -1) + right_hand_pose = np.array(smpl_param_data['rhand_pose']).reshape(1, -1) + + # Concatenate all parameters into a single tensor for SMPL model + smpl_param_ref = np.concatenate([np.array([[1.]]), np.array(smpl_param_data['trans']).reshape(1, 3), + global_orient, body_pose, shape, left_hand_pose, right_hand_pose, + np.array(smpl_param_data['jaw_pose']).reshape(1, -1), + np.zeros_like(np.array(smpl_param_data['leye_pose']).reshape(1, -1)), + np.zeros_like(np.array(smpl_param_data['reye_pose']).reshape(1, -1)), + np.zeros_like(np.array(smpl_param_data['expr']).reshape(1, -1)[:, :10])], axis=1) + + return RT, intri, torch.Tensor(smpl_param_ref).reshape(-1) # Return transformation, intrinsic, and SMPL parameters + +def load_image(input_path, output_folder, image_frame_ratio=None): + input_img_path = Path(input_path) + + vids = [] + save_path = os.path.join(output_folder, f"{input_img_path.name}") + print(f"Processing: {save_path}") + image = Image.open(input_img_path) + + if image.mode == "RGBA": + pass + else: + # remove bg + image = remove(image.convert("RGBA"), alpha_matting=True) + + # resize object in frame + image_arr = np.array(image) + in_w, in_h = image_arr.shape[:2] + ret, mask = cv2.threshold( + np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY + ) + x, y, w, h = cv2.boundingRect(mask) + max_size = max(w, h) + side_len = ( + int(max_size / image_frame_ratio) + if image_frame_ratio is not None + else int(max_size / 0.85) + ) + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len // 2 + padded_image[ + center - h // 2 : center - h // 2 + h, + center - w // 2 : center - w // 2 + w, + ] = image_arr[y : y + h, x : x + w] + rgba = Image.fromarray(padded_image).resize((896, 896), Image.LANCZOS) + # crop the width into 640 in the center + rgba = rgba.crop([128, 0, 640+128, 896]) + # white bg + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) + input_image = Image.fromarray((rgb * 255).astype(np.uint8)) + + image = ToTensor()(input_image) + + return image + + + +def prepare_camera( resolution_x = 640, resolution_y = 640, focal_length = 600,sensor_width = 32, camera_dist = 20, num_views=1, stides=1): + + def look_at(camera_position, target_position, up_vector): # colmap +z forward, +y down + forward = -(camera_position - target_position) / np.linalg.norm(camera_position - target_position) + right = np.cross(up_vector, forward) + up = np.cross(forward, right) + return np.column_stack((right, up, forward)) + + # set the intrisics + focal_length = focal_length * (resolution_y/sensor_width) + + K = np.array( + [[focal_length, 0, resolution_x//2], + [0, focal_length, resolution_y//2], + [0, 0, 1]] + ) + + # set the extrisics + camera_pose_list = [] + for frame_idx in range(0, num_views, stides): + + phi = math.radians(90) + theta = (3 / 4) * math.pi * 2 + 
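`load_smplify_json` above packs everything into one flat 189-dim SMPL-X parameter vector, and the same layout is assumed wherever the code slices `smpl_params` later in this diff (translation at `[1:4]`, root orientation at `[4:7]`, betas at `[70:80]`). The sketch below spells out that layout as inferred from the `np.concatenate` order; `SMPLX_LAYOUT` and its field names are a descriptive aid, not a structure defined by the repo:

```python
# Offsets into the flat 189-dim SMPL-X parameter vector, inferred from the
# concatenation order in load_smplify_json / load_smplx_from_npy.
SMPLX_LAYOUT = {
    "scale":        slice(0, 1),      # constant 1.0
    "trans":        slice(1, 4),      # global translation
    "root_orient":  slice(4, 7),      # global root orientation (axis-angle)
    "body_pose":    slice(7, 70),     # 21 body joints x 3
    "betas":        slice(70, 80),    # body shape (10 coefficients)
    "left_hand":    slice(80, 125),   # 15 joints x 3
    "right_hand":   slice(125, 170),  # 15 joints x 3
    "jaw_pose":     slice(170, 173),
    "leye_pose":    slice(173, 176),  # zeroed by these loaders
    "reye_pose":    slice(176, 179),  # zeroed by these loaders
    "expression":   slice(179, 189),  # first 10 expression coefficients
}

assert SMPLX_LAYOUT["expression"].stop == 189
```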
camera_location = np.array( + [camera_dist * math.sin(phi) * math.cos(theta), + + camera_dist * math.cos(phi), + -camera_dist * math.sin(phi) * math.sin(theta),] + ) + # print(camera_location) + camera_pose = np.eye(4) + camera_pose[:3, 3] = camera_location + + # Set camera position and target position + camera_position = camera_location + target_position = np.array([0.0, 0.0, 0.0]) + + # Compute the camera's rotation matrix to look at the target + up_vector = np.array([0.0, -1.0, 0.0]) # colmap + rotation_matrix = look_at(camera_position, target_position, up_vector) + + # Update camera position and rotation + camera_pose[:3, :3] = rotation_matrix + camera_pose[:3, 3] = camera_position + camera_pose_list.append(camera_pose) + return K, camera_pose_list + + +def construct_camera(K, cam_list, device='cuda'): + num_imgs = len(cam_list) + front_idx = num_imgs//4*3 + cam_list = cam_list[front_idx:] + cam_list[:front_idx] + cam_raw = np.array(cam_list) + cam_raw[:, :3, 3] = cam_raw[:, :3, 3] + cam = np.linalg.inv(cam_raw) + cam = torch.Tensor(cam) + intrics = torch.Tensor([K[0,0],K[1,1], K[0,2], K[1,2]]).reshape(-1) + scale = 0.5 + # diffrent from the synthetic data, the scale is process first + # trans from (3,) to (batch_size, 3,1) + trans = [0, 0.2, 0] #in the center + trans_bt = torch.Tensor(trans).reshape(1, 3, 1).expand(cam.shape[0], 3, 1) + cam[:,:3,3] = cam[:,:3,3] + torch.bmm(cam[:,:3,:3], trans_bt).reshape(-1, 3) # T = Rt+T torch.Size([24, 3, 1]) + cam[:,:3,:3] = cam[:,:3,:3] * scale # R = sR + cam_c2w = torch.inverse(cam) + cam_w2c = cam + poses = [] + for i_cam in range(cam.shape[0]): + poses.append( torch.concat([ + (intrics.reshape(-1)).to(torch.float32), #C ! # C ? T 理论上要给C + (cam_w2c[i_cam]).to(torch.float32).reshape(-1), # RT #Rt|C ? RT 理论上要给RT + ], dim=0)) + cameras = torch.stack(poses).to(device) # [N, 19] + return cameras + +def get_name_str(name): + path_ = os.path.basename(os.path.dirname(name)) + os.path.basename(name) + return path_ + + + +def load_smplx_from_npy(smplx_path, device='cuda'): + hand_mean = get_hand_pose_mean().reshape(-1) + smplx_pose_param = np.load(smplx_path, allow_pickle=True) + # if "person1" in smplx_pose_param: + # smplx_pose_param = smplx_pose_param['person1'] + smplx_pose_param = { + 'root_orient': smplx_pose_param[:, :3], # controls the global root orientation + 'pose_body': smplx_pose_param[:, 3:3+63], # controls the body + 'pose_hand': smplx_pose_param[:, 66:66+90], # controls the finger articulation + 'pose_jaw': smplx_pose_param[:, 66+90:66+93], # controls the yaw pose + 'face_expr': smplx_pose_param[:, 159:159+50], # controls the face expression + 'face_shape': smplx_pose_param[:, 209:209+100], # controls the face shape + 'trans': smplx_pose_param[:, 309:309+3], # controls the global body position + 'betas': smplx_pose_param[:, 312:], # controls the body shape. 
Body shape is static + } + + smplx_param_list = [] + for i in range(1, 1800, 1): + # for i in k.keys(): + # k[i] = np.array(k[i]) + left_hands = np.array([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, + -0.6652, -0.7290, 0.0084, -0.4818]) + betas = torch.zeros((10)) + smplx_param = \ + np.concatenate([np.array([1]), smplx_pose_param['trans'][i], smplx_pose_param['root_orient'][i], \ + smplx_pose_param['pose_body'][i],betas, \ + smplx_pose_param['pose_hand'][i]-hand_mean, smplx_pose_param['pose_jaw'][i], np.zeros(6), smplx_pose_param['face_expr'][i][:10]], axis=0).reshape(1,-1) + smplx_param_list.append(smplx_param) + smplx_params = np.concatenate(smplx_param_list, 0) + smpl_params = torch.Tensor(smplx_params).to(device) + return smpl_params +def add_root_rotate_to_smplx(smpl_tmp, frames_num=180, device='cuda'): + from cv2 import Rodrigues + initial_matrix = batch_rodrigues(smpl_tmp.reshape(1,189)[:, 4:7]).cpu().numpy().copy() + # Rotate a rotation matrix by 360 degrees around the y-axis. + # frames_num = 180 + all_smpl = [] + # Combine the rotations + all_smpl = [] + for idx_f in range(frames_num): + new_smpl = smpl_tmp.clone() + angle = 360//frames_num * idx_f + y_angle = np.radians(angle) + y_rotation_matrix = np.array([ + [ np.cos(y_angle),0, np.sin(y_angle)], + [0, 1, 0], + [-np.sin(y_angle), 0, np.cos(y_angle)], + ]) + final_matrix = y_rotation_matrix[None] @ initial_matrix + + new_smpl[4:7] = torch.Tensor(rotation_matrix_to_rodrigues(torch.Tensor(final_matrix))).to(device) + all_smpl.append(new_smpl) + all_smpl = torch.stack(all_smpl, 0) + smpl_params = all_smpl.to(device) + return smpl_params + +def load_smplx_from_json(smplx_path, device='cuda'): + # format of motion-x + hand_mean = get_hand_pose_mean().reshape(-1) + with open(smplx_path, 'r') as f: + smplx_pose_param = json.load(f) + smplx_param_list = [] + for par in smplx_pose_param['annotations']: + k = par['smplx_params'] + for i in k.keys(): + k[i] = np.array(k[i]) + + betas = torch.zeros((10)) + # ######### wrist pose fix ################ + smplx_param = \ + np.concatenate([np.array([1]), k['trans'], + k['root_orient']*np.array([1, 1, 1]), \ + k['pose_body'],betas, \ + k['pose_hand']-hand_mean, k['pose_jaw'], np.zeros(6), np.zeros_like(k['face_expr'][:10])], axis=0).reshape(1,-1) + smplx_param_list.append(smplx_param) + + + smplx_params = np.concatenate(smplx_param_list, 0) + print(smplx_params.shape) + smpl_params = torch.Tensor(smplx_params).to(device) + return smpl_params + +def get_image_dimensions(input_path): + with Image.open(input_path) as img: + return img.height, img.width + +def construct_camera_from_motionx(smplx_path, device='cuda'): + with open(smplx_path, 'r') as f: + smplx_pose_param = json.load(f) + cam_exts = [] + cam_ints = [] + for par in smplx_pose_param['annotations']: + cam = par['cam_params'] + R = np.array(cam['cam_R']) + K = np.array(cam['intrins']) + T = np.array(cam['cam_T']) + cam['cam_T'][1] = -cam['cam_T'][1] + cam['cam_T'][2] = -cam['cam_T'][2] + extrix = np.eye(4) + extrix[:3, :3] = R + extrix[:3,3] = T + cam_exts.append(extrix) + intrix = K + cam_ints.append(intrix) + + # target N,20 + cam_exts_array = np.array(cam_exts) + + cam_exts_stack = torch.Tensor(cam_exts_array).to(device).reshape(-1, 16) + cam_ints_stack = torch.Tensor(cam_ints).to(device).reshape(-1, 4) + cameras = torch.cat([cam_ints_stack, cam_exts_stack], dim=-1).reshape(-1,1, 20) + return cameras + +def remove_background(image: PIL.Image.Image, + rembg_session: Any = None, + force: bool = False, + 
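A few functions above, `add_root_rotate_to_smplx` produces the 360-degree turntable by composing a y-axis rotation with the initial root orientation (`R_frame = R_y(angle) @ R_init`) and writing the result back into `smpl_params[4:7]`. Below is a standalone sketch of the same composition using SciPy; note the original uses integer division (`360 // frames_num`) for the step angle, while the sketch uses float division, which also covers frame counts that do not divide 360 evenly:

```python
import numpy as np
import torch
from scipy.spatial.transform import Rotation


def yaw_rotations(initial_axis_angle: torch.Tensor, frames_num: int = 180) -> torch.Tensor:
    """Compose a y-axis turntable rotation with an initial root orientation.

    Standalone illustration of the idea behind add_root_rotate_to_smplx:
    R_frame = R_y(angle) @ R_init, returned as one axis-angle vector per frame.
    """
    r_init = Rotation.from_rotvec(initial_axis_angle.numpy()).as_matrix()  # (3, 3)
    out = []
    for idx in range(frames_num):
        angle = np.radians(360.0 / frames_num * idx)
        r_y = Rotation.from_euler("y", angle).as_matrix()  # rotation about the y axis
        out.append(Rotation.from_matrix(r_y @ r_init).as_rotvec())
    return torch.tensor(np.stack(out), dtype=torch.float32)  # (frames_num, 3)


# e.g. turntable = yaw_rotations(torch.zeros(3), frames_num=60)
```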
**rembg_kwargs,
+) -> PIL.Image.Image:
+    do_remove = True
+    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+        do_remove = False
+    do_remove = do_remove or force
+    if do_remove:
+        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+    return image
+
+
+def resize_foreground(
+    image: PIL.Image.Image,
+    ratio: float,
+) -> PIL.Image.Image:
+    image = np.array(image)
+    assert image.shape[-1] == 4
+    alpha = np.where(image[..., 3] > 0)
+    y1, y2, x1, x2 = (
+        alpha[0].min(),
+        alpha[0].max(),
+        alpha[1].min(),
+        alpha[1].max(),
+    )
+    # crop the foreground
+    fg = image[y1:y2, x1:x2]
+    # pad to square
+    size = max(fg.shape[0], fg.shape[1])
+    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
+    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
+    new_image = np.pad(
+        fg,
+        ((ph0, ph1), (pw0, pw1), (0, 0)),
+        mode="constant",
+        constant_values=((0, 0), (0, 0), (0, 0)),
+    )
+
+    # compute padding according to the ratio
+    new_size = int(new_image.shape[0] / ratio)
+    # pad to size, double side
+    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
+    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
+    new_image = np.pad(
+        new_image,
+        ((ph0, ph1), (pw0, pw1), (0, 0)),
+        mode="constant",
+        constant_values=((0, 0), (0, 0), (0, 0)),
+    )
+    new_image = PIL.Image.fromarray(new_image)
+    return new_image
+
+
+def images_to_video(
+    images: torch.Tensor,
+    output_path: str,
+    fps: int = 30,
+) -> None:
+    # images: (N, C, H, W)
+    video_dir = os.path.dirname(output_path)
+    video_name = os.path.basename(output_path)
+    os.makedirs(video_dir, exist_ok=True)
+
+    frames = []
+    for i in range(len(images)):
+        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
+            f"Frame shape mismatch: {frame.shape} vs {images.shape}"
+        assert frame.min() >= 0 and frame.max() <= 255, \
+            f"Frame value out of range: {frame.min()} ~ {frame.max()}"
+        frames.append(frame)
+    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)
+
+
+def save_video(
+    frames: torch.Tensor,
+    output_path: str,
+    fps: int = 30,
+) -> None:
+    """
+    Save a video with OpenCV (handles the RGB/BGR conversion automatically).
+
+    Args:
+        frames: input frame sequence, a torch.Tensor of shape (N, C, H, W)
+        output_path: output video path
+        fps: frame rate (default 30)
+    """
+    # check the input tensor dimensions
+    if frames.dim() != 4:
+        raise ValueError(f"frames should be a 4D tensor (N, C, H, W), but got {frames.dim()}D")
+
+    # convert to a numpy array and rearrange the layout
+    frames_np = frames.permute(0, 2, 3, 1).cpu().numpy()  # (N, H, W, C)
+    frames_np = (frames_np * 255).astype(np.uint8)  # scale to the 0-255 range
+
+    # get the video size
+    height, width = frames_np.shape[1:3]
+
+    # create the VideoWriter
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # MP4 codec
+    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+    for frame in frames_np:
+        # OpenCV expects BGR, so convert if the input is RGB
+        if frame.shape[2] == 3:  # only convert color frames
+            frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+        else:
+            frame_bgr = frame  # grayscale frames are written as-is
+
+        writer.write(frame_bgr)
+
+    writer.release()
\ No newline at end of file
diff --git a/lib/utils/train_util.py b/lib/utils/train_util.py
index cf27114..c07c65f 100644
--- a/lib/utils/train_util.py
+++ b/lib/utils/train_util.py
@@ -1,35 +1,54 @@
-import importlib
-
-import os
-
-from pytorch_lightning.utilities import rank_zero_only
-@rank_zero_only
-def main_print(*args):
-    print(*args)
-
-
-
-def count_params(model, verbose=False):
-    total_params = sum(p.numel() for p in model.parameters())
-    if verbose:
-        main_print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M 
params.") - return total_params - - -def instantiate_from_config(config): - if not "target" in config: - if config == '__is_first_stage__': - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False): - main_print(string) - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) +import importlib + +import os + +from pytorch_lightning.utilities import rank_zero_only +@rank_zero_only +def main_print(*args): + print(*args) + + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + main_print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + main_print(string) + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + +def get_class_from_config(config): + class_path = config.get("target") + if not class_path: + raise KeyError("Expected key `target` to instantiate.") + + if "." not in class_path: + raise ValueError(f"Invalid class path: '{class_path}'. Expected format 'module.submodule.ClassName'") + + module_path, class_name = class_path.rsplit(".", 1) + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as e: + raise ImportError(f"Module '{module_path}' not found") from e + + if not hasattr(module, class_name): + raise AttributeError(f"Class '{class_name}' not found in module '{module_path}'") + + return getattr(module, class_name) # 返回类对象 diff --git a/run_demo.py b/run_demo.py index 4d3dc49..0d9397d 100644 --- a/run_demo.py +++ b/run_demo.py @@ -1,295 +1,302 @@ -import os -os.environ["CUDA_VISIBLE_DEVICES"] = "0" -# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" -import argparse -import torch -from tqdm import tqdm -from torchvision.transforms import v2 -from pytorch_lightning import seed_everything -from omegaconf import OmegaConf -from tqdm import tqdm -from einops import rearrange -from lib.utils.infer_util import * -from lib.utils.train_util import instantiate_from_config -import torchvision -import json -############################################################################### -# Arguments. 
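`get_class_from_config`, added above in `lib/utils/train_util.py`, resolves the dotted `target` path from a config and returns the class itself rather than an instance, which is what lets `run_demo.py`'s `main()` call `load_from_checkpoint` on the class directly instead of going through `instantiate_from_config`. A quick illustration; `torch.nn.Linear` is only a stand-in target, not one of the repo's configs:

```python
from lib.utils.train_util import get_class_from_config, instantiate_from_config

cfg = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}

cls = get_class_from_config(cfg)    # -> the class torch.nn.Linear, nothing instantiated
obj = instantiate_from_config(cfg)  # -> Linear(in_features=4, out_features=2)

assert obj.__class__ is cls
```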
-############################################################################### - -def parse_args(): - """Parse command line arguments""" - parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, help='Path to config file.', required=False) - parser.add_argument('--input_path', type=str, help='Path to input image or directory.', required=False) - parser.add_argument('--resume_path', type=str, help='Path to saved ckpt.', required=False) - parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.') - parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.') - parser.add_argument('--distance', type=float, default=1.5, help='Render distance.') - parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.') - parser.add_argument('--render_mode', type=str, default='novel_pose', - choices=['novel_pose', 'reconstruct', 'novel_pose_A'], - help='Rendering mode: novel_pose (animation), reconstruct (reconstruction), or novel_pose_A (360-degree view with A-pose)') - - return parser.parse_args() - -############################################################################### -# Stage 0: Configuration. -############################################################################### - -device = torch.device('cuda') - - - -def process_data_on_gpu(args, model, gpu_id, img_paths_list, smplx_ref_path_list, smplx_path_driven_list): - torch.cuda.set_device(gpu_id) - model = model.cuda() - image_plist = [] - - - render_mode = args.render_mode - - - cam_idx = 0 # 12 # fixed cameras and changes pose for novel poses - num_imgs = 60 - if_load_betas = True - - - if_use_video_cam = False # If the SMPLX sequence provides camera parameters, this can be set to true. - if_uniform_coordinates = True # Normalize the SMPL-X sequence for the purpose of driving. 
- - - for input_path, smplx_ref_path, smplx_path in tqdm(zip(img_paths_list, smplx_ref_path_list, smplx_path_driven_list), total = len(img_paths_list)): - print(f"Processing: {input_path}") - - args.input_path = input_path - args.input_path_smpl = smplx_ref_path - - # get a name for results - name = get_name_str(args.input_path) + get_name_str(smplx_path) - - ############################################################################### - # Stage 1: Parameters loading - ############################################################################### - - ''' # Stage 1.1: SMPLX loading (Beta)''' - if args.input_path_smpl is not None: - # smplx = np.load(args.input_path_smpl, allow_pickle=True).item() - smplx = json.load(open(args.input_path_smpl)) - if "shapes" in smplx.keys(): - smplx['betas'] = smplx['shapes'] - else: - smplx['betas'] = smplx['betas_save'] - smpl_params = torch.zeros(1, 189).to(device) - if if_load_betas: - smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) - - ''' # Stage 1.2: SMPLX loading (Pose)''' - # animation - if render_mode in ['novel_pose'] : - - if smplx_path.endswith(".npy"): - smpl_params = load_smplx_from_npy(smplx_path) - smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) - - # ========= Note: If the video camera is not used, center everything at the origin ======== - if if_uniform_coordinates: - print("''' Ending --- Adjusting root orientation angles '''") - - # Extract root orientation and translation from SMPL parameters - root_orient = smpl_params[:, 4:7] - trans = smpl_params[:, 1:4] - - # Reset the first frame's rotation and adjust translations - new_root_orient, new_trans = reset_first_frame_rotation(root_orient, trans) - - # Update the root orientation and translation in the SMPL parameters - smpl_params[:, 4:7] = new_root_orient - smpl_params[:, 1:4] = new_trans.squeeze() # Apply the new translation - - - elif smplx_path.endswith(".json"): - ''' for motion-x input ''' - smpl_params = load_smplx_from_json(smplx_path) - smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) - if_use_video_cam = True - - - elif render_mode in ['reconstruct']: - RT_rec, intri_rec, smpl_rec = load_smplify_json(smplx_ref_path) - - H_rec, W_rec = get_image_dimensions(input_path) - - '''Apply root rotation for a full 360-degree view of the object''' - if_add_root_rotate = True - if if_add_root_rotate == True: - - smpl_params = add_root_rotate_to_smplx(smpl_rec, num_imgs) - print(" '''ending --- invert the root angles'''") - else: - smpl_params = smpl_params.to(device) - num_imgs = 1 - - elif render_mode in ['novel_pose_A']: - smpl_params = model.get_default_smplx_params().squeeze() - smpl_params = smpl_params.to(device) - smpl_params = add_root_rotate_to_smplx(smpl_params.clone(), num_imgs) - smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) - - else: - raise NotImplementedError(f"Render mode '{render_mode}' is not supported.") - - '''# Stage 1.3: Image loading ''' - image = load_image(args.input_path, args.output_folders['ref']) - H,W = 896,640 - image_bs = image.unsqueeze(0).to(device) - num_imgs = 180 - - ''' # Stage 1.4 Camera loading''' - if not if_use_video_cam: - # prepare cameras - K, cam_list = prepare_camera(resolution_x=H, resolution_y=W, num_views=num_imgs, stides=1) - cameras = construct_camera(K, cam_list) - - if render_mode == 'novel_pose': # if poses are changed, cameras will be fixed - intrics = torch.Tensor([K[0,0],K[1,1], 256, 256]).reshape(-1) - model.decoder.renderer.image_size = [512, 512] - - assert 
cameras.shape[-1] == 20 - cameras[:, :4] = intrics - cameras = cameras[cam_idx:cam_idx+1] - num_imgs = smpl_params.shape[0] - cameras = cameras.repeat(num_imgs, 1) - cameras = cameras[:, None, :] # length of the pose sequences - print("modify the render images's resolution into 512x512 ") - - elif render_mode in ['reconstruct']: # using reference smplify's smplx and camera - cameras = torch.concat([intri_rec.reshape(-1,4), RT_rec.reshape(-1, 16)], dim=1) - # H, W = int(intri_rec[2] * 2), int(intri_rec[3] * 2) - model.decoder.renderer.image_size = [W_rec, H_rec]; print(f"modify the render images's resolution into {H_rec}x{W_rec}") - cameras = cameras.reshape(1,1,20).expand(num_imgs,1,-1) - cameras = cameras.cuda() - - elif render_mode == 'novel_pose_A': - model.decoder.renderer.image_size = [W, H] - cameras = cameras[0].reshape(1,1,20).expand(num_imgs,1,-1) - - elif if_use_video_cam: # for the animation with motion-x - cameras = construct_camera_from_motionx(smplx_path) - H, W = 2*cameras[0, 0, [3]].int().item(), 2*cameras[0,0, [2]].int().item() - model.decoder.renderer.image_size = [W, H]; print(f"modify the render images's resolution into {H}x{W}") - # model.decoder.renderer = - - ############################################################################### - # Stage 2: Reconstruction. - ############################################################################### - - sample = image_bs[[0]] # N, 3, H, W, - # if if_use_dataset: - # sample = rearrange(sample, 'b h w c -> b c h w') # N, 3, H, W, - - image_path_idx = os.path.join(args.output_folders['ref'], f'{name}_ref.jpg') - torchvision.utils.save_image( sample[0], image_path_idx) - - with torch.no_grad(): - # get latents - code = model.forward_image_to_uv(sample, is_training=False) - - with torch.no_grad(): - output_list = [] - num_imgs_batch = 5 - total_frames = min(smpl_params.shape[0],300) - res_uv = None - for i in tqdm(range(0, total_frames, num_imgs_batch)): - if i+num_imgs_batch > total_frames: - num_imgs_batch = total_frames - i - code_bt = code.expand(num_imgs_batch, -1, -1, -1) - # cameras_bt = cameras.expand(num_imgs_batch, -1, -1) - cameras_bt = cameras[i:i+num_imgs_batch] - - if render_mode in ['reconstruct', 'novel_pose_A'] and res_uv is not None: - pass - else: - res_uv = model.decoder._decode_feature(code_bt) # Decouple UV attributes - res_points = model.decoder._sample_feature(res_uv) # Sampling - # Animate - res_def_points = model.decoder.deform_pcd(res_points, smpl_params[i:i+num_imgs_batch].to(code_bt.dtype), zeros_hands_off=True, value=0.02) - output = model.decoder.forward_render(res_def_points, cameras_bt.to(code_bt.dtype), num_imgs=1) - image = output["image"][:, 0].cpu().to(torch.float32) - - print("output shape ", output["image"][:, 0].shape) - output_list.append(image) # [:, 0] stands to get the all scenes (poses) - del output - - output = torch.concatenate(output_list, 0) - frames = rearrange(output, "b h w c -> b c h w")#.cpu().numpy() - - video_path_idx = os.path.join(args.output_folders['video'], f'{name}.mp4') - - save_video( - frames[:,:4,...].to(torch.float32), - video_path_idx, - ) - image_plist.append(frames) - print("saving into ", video_path_idx) - return image_plist - -def setup_directories(base_path, config_name): - """Create output directories for results""" - dirs = { - 'image': os.path.join(base_path, config_name, 'images'), - 'video': os.path.join(base_path, config_name), - 'ref': os.path.join(base_path, config_name) - } - for path in dirs.values(): - os.makedirs(path, exist_ok=True) - 
return dirs - -def main(): - """Main execution function""" - # Parse arguments and set random seed - args = parse_args() - - args.config = "configs/idol_v0.yaml" - args.resume_path = "work_dirs/ckpt/model.ckpt" - - config = OmegaConf.load(args.config) - config_name = os.path.basename(args.config).replace('.yaml', '') - model_config = config.model - - resume_path = args.resume_path - # Initialize model - model = instantiate_from_config(model_config) - model.encoder = model.encoder.to(torch.bfloat16) ; print("moving encoder to bf16") - model = model.__class__.load_from_checkpoint(resume_path, **config.model.params) - model = model.to(device) - model = model.eval() - - # Setup input paths - img_paths_list = ['work_dirs/demo_data/4.jpg'] - smplx_ref_path_list = ['work_dirs/demo_data/4.json'] - smplx_path_driven_list = ['work_dirs/demo_data/Ways_to_Catch_360_clip1.json'] - # smplx_path_driven_list = ['work_dirs/demo_data/finedance-5-144.npy.npy'] - - # Setup output directories - # args.output_path = "./test/" - # args.render_mode = 'reconstruct' # 'novel_pose_A' #'reconstruct' #'novel_pose' - - # make output directories - args.output_folders = setup_directories(args.output_path, config_name) - - # Process data - image_plist = process_data_on_gpu( - args, - model, 0, - img_paths_list, - smplx_ref_path_list, - smplx_path_driven_list - ) - - return image_plist - -if __name__ == "__main__": - main() - - +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" +import argparse +import torch +from tqdm import tqdm +from torchvision.transforms import v2 +from pytorch_lightning import seed_everything +from omegaconf import OmegaConf +from tqdm import tqdm +from einops import rearrange +from lib.utils.infer_util import * +from lib.utils.train_util import * +import torchvision +import json +############################################################################### +# Arguments. +############################################################################### + +def parse_args(): + """Parse command line arguments""" + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, help='Path to config file.', required=False) + parser.add_argument('--input_path', type=str, help='Path to input image or directory.', required=False) + parser.add_argument('--resume_path', type=str, help='Path to saved ckpt.', required=False) + parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.') + parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.') + parser.add_argument('--distance', type=float, default=1.5, help='Render distance.') + parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.') + parser.add_argument('--render_mode', type=str, default='novel_pose', + choices=['novel_pose', 'reconstruct', 'novel_pose_A'], + help='Rendering mode: novel_pose (animation), reconstruct (reconstruction), or novel_pose_A (360-degree view with A-pose)') + parser.add_argument('--low_ram', action="store_true", default=False, help='Enabling this option reduces the inference VRAM requirement to 11GB.') + + return parser.parse_args() + +############################################################################### +# Stage 0: Configuration. 
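The new `--low_ram` flag (its help text promises roughly an 11 GB inference footprint) is threaded through the heavy calls below: `forward_image_to_uv(..., low_ram=args.low_ram)`, `_decode_feature(..., low_ram=args.low_ram)`, a frame batch size dropped from 5 to 3, and explicit `torch.cuda.empty_cache()` calls between batches. The internals of the low-RAM path are not part of this diff; the sketch below only illustrates the general chunk-and-free pattern the caller applies around it, and the helper name and chunk size are illustrative, not the repo's API:

```python
import torch


@torch.no_grad()
def run_in_chunks(fn, x: torch.Tensor, chunk: int = 3) -> torch.Tensor:
    """Apply `fn` to `x` in small batches along dim 0, freeing cached VRAM in between."""
    outs = []
    for start in range(0, x.shape[0], chunk):
        outs.append(fn(x[start:start + chunk]).cpu())  # move results off-GPU immediately
        torch.cuda.empty_cache()                       # return cached blocks to the allocator
    return torch.cat(outs, dim=0)
```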
+############################################################################### + +device = torch.device('cuda') + + + +def process_data_on_gpu(args, model, gpu_id, img_paths_list, smplx_ref_path_list, smplx_path_driven_list): + torch.cuda.set_device(gpu_id) + model = model.cuda() + image_plist = [] + + + render_mode = args.render_mode + + + cam_idx = 0 # 12 # fixed cameras and changes pose for novel poses + num_imgs = 60 + if_load_betas = True + + + if_use_video_cam = False # If the SMPLX sequence provides camera parameters, this can be set to true. + if_uniform_coordinates = True # Normalize the SMPL-X sequence for the purpose of driving. + + + for input_path, smplx_ref_path, smplx_path in tqdm(zip(img_paths_list, smplx_ref_path_list, smplx_path_driven_list), total = len(img_paths_list)): + print(f"Processing: {input_path}") + + args.input_path = input_path + args.input_path_smpl = smplx_ref_path + + # get a name for results + name = get_name_str(args.input_path) + get_name_str(smplx_path) + + ############################################################################### + # Stage 1: Parameters loading + ############################################################################### + + ''' # Stage 1.1: SMPLX loading (Beta)''' + if args.input_path_smpl is not None: + # smplx = np.load(args.input_path_smpl, allow_pickle=True).item() + smplx = json.load(open(args.input_path_smpl)) + if "shapes" in smplx.keys(): + smplx['betas'] = smplx['shapes'] + else: + smplx['betas'] = smplx['betas_save'] + smpl_params = torch.zeros(1, 189).to(device) + if if_load_betas: + smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) + + ''' # Stage 1.2: SMPLX loading (Pose)''' + # animation + if render_mode in ['novel_pose'] : + + if smplx_path.endswith(".npy"): + smpl_params = load_smplx_from_npy(smplx_path) + smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) + + # ========= Note: If the video camera is not used, center everything at the origin ======== + if if_uniform_coordinates: + print("''' Ending --- Adjusting root orientation angles '''") + + # Extract root orientation and translation from SMPL parameters + root_orient = smpl_params[:, 4:7] + trans = smpl_params[:, 1:4] + + # Reset the first frame's rotation and adjust translations + new_root_orient, new_trans = reset_first_frame_rotation(root_orient, trans) + + # Update the root orientation and translation in the SMPL parameters + smpl_params[:, 4:7] = new_root_orient + smpl_params[:, 1:4] = new_trans.squeeze() # Apply the new translation + + + elif smplx_path.endswith(".json"): + ''' for motion-x input ''' + smpl_params = load_smplx_from_json(smplx_path) + smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) + if_use_video_cam = True + + + elif render_mode in ['reconstruct']: + RT_rec, intri_rec, smpl_rec = load_smplify_json(smplx_ref_path) + + H_rec, W_rec = get_image_dimensions(input_path) + + '''Apply root rotation for a full 360-degree view of the object''' + if_add_root_rotate = True + if if_add_root_rotate == True: + + smpl_params = add_root_rotate_to_smplx(smpl_rec, num_imgs) + print(" '''ending --- invert the root angles'''") + else: + smpl_params = smpl_params.to(device) + num_imgs = 1 + + elif render_mode in ['novel_pose_A']: + smpl_params = model.get_default_smplx_params().squeeze() + smpl_params = smpl_params.to(device) + smpl_params = add_root_rotate_to_smplx(smpl_params.clone(), num_imgs) + smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) + + else: + raise NotImplementedError(f"Render 
mode '{render_mode}' is not supported.") + + '''# Stage 1.3: Image loading ''' + image = load_image(args.input_path, args.output_folders['ref']) + H,W = 896,640 + image_bs = image.unsqueeze(0).to(device) + num_imgs = 180 + + ''' # Stage 1.4 Camera loading''' + if not if_use_video_cam: + # prepare cameras + K, cam_list = prepare_camera(resolution_x=H, resolution_y=W, num_views=num_imgs, stides=1) + cameras = construct_camera(K, cam_list) + + if render_mode == 'novel_pose': # if poses are changed, cameras will be fixed + intrics = torch.Tensor([K[0,0],K[1,1], 256, 256]).reshape(-1) + model.decoder.renderer.image_size = [512, 512] + + assert cameras.shape[-1] == 20 + cameras[:, :4] = intrics + cameras = cameras[cam_idx:cam_idx+1] + num_imgs = smpl_params.shape[0] + cameras = cameras.repeat(num_imgs, 1) + cameras = cameras[:, None, :] # length of the pose sequences + print("modify the render images's resolution into 512x512 ") + + elif render_mode in ['reconstruct']: # using reference smplify's smplx and camera + cameras = torch.concat([intri_rec.reshape(-1,4), RT_rec.reshape(-1, 16)], dim=1) + # H, W = int(intri_rec[2] * 2), int(intri_rec[3] * 2) + model.decoder.renderer.image_size = [W_rec, H_rec]; print(f"modify the render images's resolution into {H_rec}x{W_rec}") + cameras = cameras.reshape(1,1,20).expand(num_imgs,1,-1) + cameras = cameras.cuda() + + elif render_mode == 'novel_pose_A': + model.decoder.renderer.image_size = [W, H] + cameras = cameras[0].reshape(1,1,20).expand(num_imgs,1,-1) + + elif if_use_video_cam: # for the animation with motion-x + cameras = construct_camera_from_motionx(smplx_path) + H, W = 2*cameras[0, 0, [3]].int().item(), 2*cameras[0,0, [2]].int().item() + model.decoder.renderer.image_size = [W, H]; print(f"modify the render images's resolution into {H}x{W}") + # model.decoder.renderer = + + ############################################################################### + # Stage 2: Reconstruction. 
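Every camera handed to the renderer in Stage 1.4 is a flat 20-vector: the four pinhole intrinsics `[fx, fy, cx, cy]` followed by the row-major 4x4 world-to-camera matrix. That is what both `construct_camera` and `construct_camera_from_motionx` assemble and what the `assert cameras.shape[-1] == 20` above checks (the `# [N, 19]` comment in `construct_camera` and the `19(3+16)` docstrings appear to be stale). A minimal packing sketch; `pack_camera` is an illustrative helper, not a function from the repo:

```python
import numpy as np
import torch


def pack_camera(K: np.ndarray, w2c: np.ndarray) -> torch.Tensor:
    """Pack 3x3 intrinsics and a 4x4 world-to-camera matrix into the 20-dim layout used here."""
    intrinsics = torch.tensor([K[0, 0], K[1, 1], K[0, 2], K[1, 2]], dtype=torch.float32)
    extrinsics = torch.tensor(w2c, dtype=torch.float32).reshape(16)
    return torch.cat([intrinsics, extrinsics])  # shape (20,)


# Cameras for the renderer are then stacked as (num_frames, num_views, 20).
```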
+
+        ###############################################################################
+        # Stage 2: Reconstruction
+        ###############################################################################
+
+        sample = image_bs[[0]]  # N, 3, H, W
+        # if if_use_dataset:
+        #     sample = rearrange(sample, 'b h w c -> b c h w')  # N, 3, H, W
+
+        image_path_idx = os.path.join(args.output_folders['ref'], f'{name}_ref.jpg')
+        torchvision.utils.save_image(sample[0], image_path_idx)
+
+        with torch.no_grad():
+            # Encode the reference image into the UV latent code
+            code = model.forward_image_to_uv(sample, is_training=False, low_ram=args.low_ram)
+            torch.cuda.empty_cache()
+
+        with torch.no_grad():
+            output_list = []
+            num_imgs_batch = 3
+            total_frames = min(smpl_params.shape[0], 300)
+            res_uv, res_points = None, None
+            for i in tqdm(range(0, total_frames, num_imgs_batch)):
+                if i + num_imgs_batch > total_frames:
+                    num_imgs_batch = total_frames - i
+                code_bt = code.expand(num_imgs_batch, -1, -1, -1)
+                # cameras_bt = cameras.expand(num_imgs_batch, -1, -1)
+                cameras_bt = cameras[i:i + num_imgs_batch]
+
+                if render_mode == "novel_pose" or res_uv is None:
+                    del res_uv
+                    del res_points
+                    res_uv = model.decoder._decode_feature(code_bt, low_ram=args.low_ram)  # decode the UV attribute maps
+                    res_points = model.decoder._sample_feature(res_uv)  # sample per-point features from the UV maps
+                # Animate: deform the canonical points with the SMPL-X parameters
+                res_def_points = model.decoder.deform_pcd(res_points, smpl_params[i:i + num_imgs_batch].to(code_bt.dtype), zeros_hands_off=True, value=0.02)
+                output = model.decoder.forward_render(res_def_points, cameras_bt.to(code_bt.dtype), num_imgs=1)
+                image = output["image"][:, 0].cpu().to(torch.float32)  # [:, 0] selects the single rendered view per pose
+
+                print("output shape", output["image"][:, 0].shape)
+                output_list.append(image)
+
+                del res_def_points
+                del output
+                del image
+                torch.cuda.empty_cache()
+
+        output = torch.concatenate(output_list, 0)
+        frames = rearrange(output, "b h w c -> b c h w")  # .cpu().numpy()
+
+        video_path_idx = os.path.join(args.output_folders['video'], f'{name}.mp4')
+
+        save_video(
+            frames[:, :4, ...].to(torch.float32),
+            video_path_idx,
+        )
+        image_plist.append(frames)
+        print("Saving video to", video_path_idx)
+    return image_plist
+
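+# Memory note on the rendering loop in process_data_on_gpu above: frames are rendered
+# in chunks of `num_imgs_batch`, the decoded UV features are reused across chunks
+# except in 'novel_pose' mode, and torch.cuda.empty_cache() is called between chunks.
+# If GPU memory is still tight, lowering `num_imgs_batch` (or keeping --low_ram enabled)
+# should reduce peak usage at the cost of speed.
+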
+def setup_directories(base_path, config_name):
+    """Create the output directories for results."""
+    dirs = {
+        'image': os.path.join(base_path, config_name, 'images'),
+        'video': os.path.join(base_path, config_name),
+        'ref': os.path.join(base_path, config_name)
+    }
+    for path in dirs.values():
+        os.makedirs(path, exist_ok=True)
+    return dirs
+
+
+def main():
+    """Main execution function."""
+    # Parse arguments and set random seed
+    args = parse_args()
+
+    # Default config and checkpoint paths (these override the parsed arguments)
+    args.config = "configs/idol_v0.yaml"
+    args.resume_path = "work_dirs/ckpt/model.ckpt"
+
+    config = OmegaConf.load(args.config)
+    config_name = os.path.basename(args.config).replace('.yaml', '')
+    model_config = config.model
+
+    resume_path = args.resume_path
+    # Initialize the model from the checkpoint
+    # model = instantiate_from_config(model_config)
+    model_class = get_class_from_config(model_config)
+    model = model_class.load_from_checkpoint(resume_path, **config.model.params)
+    # model.encoder = model.encoder.to(torch.bfloat16); print("moving encoder to bf16")
+    model = model.to(device)
+    model = model.eval()
+
+    # Set up the input paths
+    img_paths_list = ['work_dirs/demo_data/4.jpg']
+    smplx_ref_path_list = ['work_dirs/demo_data/4.json']
+    smplx_path_driven_list = ['work_dirs/demo_data/Ways_to_Catch_360_clip1.json']
+    # smplx_path_driven_list = ['work_dirs/demo_data/finedance-5-144.npy.npy']
+
+    # Set up the output directories
+    # args.output_path = "./test/"
+    # args.render_mode = 'reconstruct'  # 'novel_pose_A' #'reconstruct' #'novel_pose'
+    args.output_folders = setup_directories(args.output_path, config_name)
+
+    # Process the inputs
+    image_plist = process_data_on_gpu(
+        args,
+        model, 0,
+        img_paths_list,
+        smplx_ref_path_list,
+        smplx_path_driven_list
+    )
+
+    return image_plist
+
+
+if __name__ == "__main__":
+    main()
+
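+# Output layout (as written by the code above): the reference crop is saved as
+# '<output_path>/<config_name>/<name>_ref.jpg' and the rendered video as
+# '<output_path>/<config_name>/<name>.mp4', where <name> combines the input image
+# name and the driving-motion name via get_name_str.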