import argparse
import json
import os
import torch
import numpy as np

from tqdm import tqdm
from omegaconf import OmegaConf
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import AttnProcessor2_0
from packaging import version as pver

from cameractrl.pipelines.pipeline_animation import StableVideoDiffusionPipelinePoseCond
from cameractrl.models.unet import UNetSpatioTemporalConditionModelPoseCond
from cameractrl.models.pose_adaptor import CameraPoseEncoder
from cameractrl.utils.util import save_videos_grid
class Camera(object):
    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[7:]).reshape(3, 4)
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
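

# Note on the expected trajectory format: this is an assumption inferred from the
# indexing above (entry[1:5], entry[7:]), matching the RealEstate10K-style layout
# commonly used with CameraCtrl. Each pose line would look like
#   <frame id> fx fy cx cy <two unused values> r11 r12 r13 tx r21 r22 r23 ty r31 r32 r33 tz
# where fx, fy, cx, cy are normalized by image width/height (they are rescaled to
# pixels in main()) and the last 12 values form the row-major 3x4 world-to-camera matrix.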
def setup_for_distributed(is_master):
    """Disable printing in non-master processes; pass force=True to print anyway."""
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')
def get_relative_pose(cam_params, zero_first_frame_scale):
    """Re-express all camera-to-world poses relative to the first camera.

    The first camera is mapped to target_cam_c2w (the identity when
    zero_first_frame_scale is True); every other pose is transformed by the
    same rigid transform.
    """
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    source_cam_c2w = abs_c2ws[0]
    if zero_first_frame_scale:
        cam_to_origin = 0
    else:
        cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    ret_poses = np.array(ret_poses, dtype=np.float32)
    return ret_poses
def ray_condition(K, c2w, H, W, device):
    # c2w: [B, V, 4, 4]
    # K:   [B, V, 4]
    B = K.shape[0]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, 1, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # [B, V, 1]

    zs = torch.ones_like(i)  # [B, 1, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # [B, V, HW, 3]
    directions = directions / directions.norm(dim=-1, keepdim=True)  # [B, V, HW, 3]

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # [B, V, HW, 3]
    rays_o = c2w[..., :3, 3]  # [B, V, 3]
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # [B, V, HW, 3]
    # Plücker coordinates per pixel: (o x d, d)
    rays_dxo = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # [B, V, H, W, 6]
    return plucker
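

# Illustrative usage (shapes only, assuming the script defaults of 14 frames at 320x576):
#   K:   [1, 14, 4]     per-frame (fx, fy, cx, cy) in pixels
#   c2w: [1, 14, 4, 4]  camera-to-world matrices
#   ray_condition(K, c2w, H=320, W=576, device='cpu') -> [1, 14, 320, 576, 6]
# i.e. a per-pixel Plücker embedding (o x d, d) for every frame.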
def get_pipeline(ori_model_path, unet_subfolder, down_block_types, up_block_types, pose_encoder_kwargs,
                 attention_processor_kwargs, pose_adaptor_ckpt, enable_xformers, device):
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(ori_model_path, subfolder="scheduler")
    feature_extractor = CLIPImageProcessor.from_pretrained(ori_model_path, subfolder="feature_extractor")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(ori_model_path, subfolder="image_encoder")
    vae = AutoencoderKLTemporalDecoder.from_pretrained(ori_model_path, subfolder="vae")
    unet = UNetSpatioTemporalConditionModelPoseCond.from_pretrained(ori_model_path,
                                                                    subfolder=unet_subfolder,
                                                                    down_block_types=down_block_types,
                                                                    up_block_types=up_block_types)
    pose_encoder = CameraPoseEncoder(**pose_encoder_kwargs)
    print("Setting the attention processors")
    unet.set_pose_cond_attn_processor(enable_xformers=(enable_xformers and is_xformers_available()),
                                      **attention_processor_kwargs)
    print(f"Loading weights of camera encoder and attention processor from {pose_adaptor_ckpt}")
    ckpt_dict = torch.load(pose_adaptor_ckpt, map_location=unet.device)
    pose_encoder_state_dict = ckpt_dict['pose_encoder_state_dict']
    pose_encoder_m, pose_encoder_u = pose_encoder.load_state_dict(pose_encoder_state_dict)
    assert len(pose_encoder_m) == 0 and len(pose_encoder_u) == 0
    attention_processor_state_dict = ckpt_dict['attention_processor_state_dict']
    _, attention_processor_u = unet.load_state_dict(attention_processor_state_dict, strict=False)
    assert len(attention_processor_u) == 0
    print("Loading done")
    vae.set_attn_processor(AttnProcessor2_0())
    vae.to(device)
    image_encoder.to(device)
    unet.to(device)
    pipeline = StableVideoDiffusionPipelinePoseCond(
        vae=vae,
        image_encoder=image_encoder,
        unet=unet,
        scheduler=noise_scheduler,
        feature_extractor=feature_extractor,
        pose_encoder=pose_encoder
    )
    pipeline = pipeline.to(device)
    return pipeline
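

# Minimal sketch of calling get_pipeline outside of main(); the paths below are
# placeholders, not files shipped with this Space, and the config keys mirror the
# ones read from --model_config in main():
#   cfg = OmegaConf.load('<model_config>.yaml')
#   pipe = get_pipeline('<svd_model_dir>', cfg['unet_subfolder'], cfg['down_block_types'],
#                       cfg['up_block_types'], cfg['pose_encoder_kwargs'],
#                       cfg['attention_processor_kwargs'], '<pose_adaptor>.ckpt',
#                       enable_xformers=False, device='cuda:0')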
def main(args):
    os.makedirs(os.path.join(args.out_root, 'generated_videos'), exist_ok=True)
    os.makedirs(os.path.join(args.out_root, 'reference_images'), exist_ok=True)
    rank = args.local_rank
    setup_for_distributed(rank == 0)
    gpu_id = rank % torch.cuda.device_count()
    model_configs = OmegaConf.load(args.model_config)
    device = f"cuda:{gpu_id}"

    print('Constructing pipeline')
    pipeline = get_pipeline(args.ori_model_path, model_configs['unet_subfolder'], model_configs['down_block_types'],
                            model_configs['up_block_types'], model_configs['pose_encoder_kwargs'],
                            model_configs['attention_processor_kwargs'], args.pose_adaptor_ckpt,
                            args.enable_xformers, device)
    print('Done')

    print('Loading K, R, t matrix')
    with open(args.trajectory_file, 'r') as f:
        poses = f.readlines()
    poses = [pose.strip().split(' ') for pose in poses[1:]]
    cam_params = [[float(x) for x in pose] for pose in poses]
    cam_params = [Camera(cam_param) for cam_param in cam_params]

    # Rescale the normalized intrinsics so the original pose aspect ratio matches
    # the sampled video resolution.
    sample_wh_ratio = args.image_width / args.image_height
    pose_wh_ratio = args.original_pose_width / args.original_pose_height
    if pose_wh_ratio > sample_wh_ratio:
        resized_ori_w = args.image_height * pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fx = resized_ori_w * cam_param.fx / args.image_width
    else:
        resized_ori_h = args.image_width / pose_wh_ratio
        for cam_param in cam_params:
            cam_param.fy = resized_ori_h * cam_param.fy / args.image_height

    intrinsic = np.asarray([[cam_param.fx * args.image_width,
                             cam_param.fy * args.image_height,
                             cam_param.cx * args.image_width,
                             cam_param.cy * args.image_height]
                            for cam_param in cam_params], dtype=np.float32)
    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
    c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
    plucker_embedding = ray_condition(K, c2ws, args.image_height, args.image_width, device='cpu')  # [1, n_frame, H, W, 6]
    plucker_embedding = plucker_embedding.permute(0, 1, 4, 2, 3).contiguous().to(device=device)

    # The prompt file is a json dict with two parallel lists: 'image_paths' and 'captions'.
    prompt_dict = json.load(open(args.prompt_file, 'r'))
    prompt_images = prompt_dict['image_paths']
    prompt_captions = prompt_dict['captions']

    # Split the prompts evenly across the n_procs workers; this rank handles [low_idx, high_idx).
    N = int(len(prompt_images) // args.n_procs)
    remainder = int(len(prompt_images) % args.n_procs)
    prompts_per_gpu = [N + 1 if gpu_id < remainder else N for gpu_id in range(args.n_procs)]
    low_idx = sum(prompts_per_gpu[:gpu_id])
    high_idx = low_idx + prompts_per_gpu[gpu_id]
    prompt_images = prompt_images[low_idx: high_idx]
    prompt_captions = prompt_captions[low_idx: high_idx]
    print(f"rank {rank} / {torch.cuda.device_count()}, number of prompts: {len(prompt_images)}")

    generator = torch.Generator(device=device)
    generator.manual_seed(42)

    for prompt_image, prompt_caption in tqdm(zip(prompt_images, prompt_captions)):
        save_name = "_".join(prompt_caption.split(" "))
        condition_image = Image.open(prompt_image)
        with torch.no_grad():
            sample = pipeline(
                image=condition_image,
                pose_embedding=plucker_embedding,
                height=args.image_height,
                width=args.image_width,
                num_frames=args.num_frames,
                num_inference_steps=args.num_inference_steps,
                min_guidance_scale=args.min_guidance_scale,
                max_guidance_scale=args.max_guidance_scale,
                do_image_process=True,
                generator=generator,
                output_type='pt'
            ).frames[0].transpose(0, 1).cpu()  # [3, f, h, w], values in [0, 1]
        resized_condition_image = condition_image.resize((args.image_width, args.image_height))
        save_videos_grid(sample[None], f"{os.path.join(args.out_root, 'generated_videos')}/{save_name}.mp4", rescale=False)
        resized_condition_image.save(os.path.join(args.out_root, 'reference_images', f'{save_name}.png'))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_root", type=str)
    parser.add_argument("--image_height", type=int, default=320)
    parser.add_argument("--image_width", type=int, default=576)
    parser.add_argument("--num_frames", type=int, default=14)
    parser.add_argument("--ori_model_path", type=str)
    parser.add_argument("--unet_subfolder", type=str, default='unet')
    parser.add_argument("--enable_xformers", action='store_true')
    parser.add_argument("--pose_adaptor_ckpt", default=None)
    parser.add_argument("--num_inference_steps", type=int, default=25)
    parser.add_argument("--min_guidance_scale", type=float, default=1.0)
    parser.add_argument("--max_guidance_scale", type=float, default=3.0)
    parser.add_argument("--prompt_file", required=True,
                        help='json file with parallel "image_paths" and "captions" lists')
    parser.add_argument("--trajectory_file", required=True)
    parser.add_argument("--original_pose_width", type=int, default=1280)
    parser.add_argument("--original_pose_height", type=int, default=720)
    parser.add_argument("--model_config", required=True)
    parser.add_argument("--n_procs", type=int, default=8)

    # DDP args
    parser.add_argument("--world_size", default=1, type=int,
                        help="number of the distributed processes.")
    parser.add_argument('--local-rank', type=int, default=-1,
                        help='Replica rank on the current node. This field is required '
                             'by `torch.distributed.launch`.')
    args = parser.parse_args()
    main(args)
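

# Illustrative launch command (a sketch only; the script filename and all paths are
# placeholders, and the exact entry point in this Space may differ):
#   python -m torch.distributed.launch --nproc_per_node=8 inference.py \
#       --out_root ./outputs \
#       --ori_model_path <stable-video-diffusion model dir> \
#       --pose_adaptor_ckpt <CameraCtrl pose adaptor checkpoint> \
#       --model_config <model config .yaml> \
#       --prompt_file <prompts .json> \
#       --trajectory_file <camera trajectory .txt> \
#       --n_procs 8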