Zhizhou Zhong committed on
Commit 6425819 · unverified · 1 Parent(s): b2750c1

feat: real-time infer (#286)


* feat: realtime infer

* chore: infer script

README.md CHANGED
@@ -130,9 +130,8 @@ https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
130
  - [x] codes for real-time inference.
131
  - [x] [technical report](https://arxiv.org/abs/2410.10122v2).
132
  - [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
 
133
  - [ ] training and dataloader code (Expected completion on 04/04/2025).
134
- - [ ] realtime inference code for 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference. The code implementation will be released soon).
135
-
136
 
137
 
138
  # Getting Started
@@ -220,21 +219,52 @@ We provide inference scripts for both versions of MuseTalk:
220
 
221
  #### MuseTalk 1.5 (Recommended)
222
  ```bash
223
- sh inference.sh v1.5
224
  ```
225
- This inference script supports both MuseTalk 1.5 and 1.0 models:
 
226
  - For MuseTalk 1.5: Use the command above with the V1.5 model path
227
  - For MuseTalk 1.0: Use the same script but point to the V1.0 model path
228
 
229
- configs/inference/test.yaml is the path to the inference configuration file, including video_path and audio_path.
230
- The video_path should be either a video file, an image file or a directory of images.
 
231
 
232
- #### MuseTalk 1.0
233
  ```bash
234
- sh inference.sh v1.0
235
  ```
236
- You are recommended to input video with `25fps`, the same fps used when training the model. If your video is far less than 25fps, you are recommended to apply frame interpolation or directly convert the video to 25fps using ffmpeg.
237
- <details close>
238
  ## TestCases For 1.0
239
  <table class="center">
240
  <tr style="font-weight: bolder;text-align:center;">
@@ -332,39 +362,11 @@ python -m scripts.inference --inference_config configs/inference/test.yaml --bbo
332
  ```
333
  :pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
334
 
335
- </details>
336
-
337
 
338
  #### Combining MuseV and MuseTalk
339
 
340
  As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
341
 
342
- #### Real-time inference
343
-
344
- <details close>
345
- Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
346
-
347
- ```
348
- python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4
349
- ```
350
- configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path` , `bbox_shift` and `audio_clips`.
351
-
352
- 1. Set `preparation` to `True` in `realtime.yaml` to prepare the materials for a new `avatar`. (If the `bbox_shift` has changed, you also need to re-prepare the materials.)
353
- 1. After that, the `avatar` will use an audio clip selected from `audio_clips` to generate video.
354
- ```
355
- Inferring using: data/audio/yongen.wav
356
- ```
357
- 1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve 30fps+ on an NVIDIA Tesla V100.
358
- 1. Set `preparation` to `False` and run this script if you want to genrate more videos using the same avatar.
359
-
360
- ##### Note for Real-time inference
361
- 1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
362
- 1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run
363
- ```
364
- python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
365
- ```
366
- </details>
367
-
368
  # Acknowledgement
369
  1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
370
  1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
 
130
  - [x] codes for real-time inference.
131
  - [x] [technical report](https://arxiv.org/abs/2410.10122v2).
132
  - [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
133
+ - [x] real-time inference code for the 1.5 version (Note: MuseTalk 1.5 has the same computation time as 1.0 and supports real-time inference).
134
  - [ ] training and dataloader code (Expected completion on 04/04/2025).
 
 
135
 
136
 
137
  # Getting Started
 
219
 
220
  #### MuseTalk 1.5 (Recommended)
221
  ```bash
222
+ # Run MuseTalk 1.5 inference
223
+ sh inference.sh v1.5 normal
224
+ ```
225
+
226
+ #### MuseTalk 1.0
227
+ ```bash
228
+ # Run MuseTalk 1.0 inference
229
+ sh inference.sh v1.0 normal
230
  ```
231
+
232
+ The inference script supports both MuseTalk 1.5 and 1.0 models:
233
  - For MuseTalk 1.5: Use the command above with the V1.5 model path
234
  - For MuseTalk 1.0: Use the same script but point to the V1.0 model path
235
 
236
+ The configuration file `configs/inference/test.yaml` contains the inference settings (an illustrative entry is shown after this list), including:
237
+ - `video_path`: Path to the input video, image file, or directory of images
238
+ - `audio_path`: Path to the input audio file
239
 
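For illustration, a minimal `configs/inference/test.yaml` entry might look like the following sketch (paths mirror the sample assets in this repo; `bbox_shift` only affects v1.0, since the v1.5 path in `scripts/inference.py` fixes it to 0):

```yaml
task_0:
  video_path: "data/video/yongen.mp4"   # video file, image file, or directory of images
  audio_path: "data/audio/yongen.wav"
  bbox_shift: -7                        # optional, used by v1.0 only
```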
240
+ Note: For best results, use input video at 25 fps, the same frame rate used during model training. If your video has a lower frame rate, apply frame interpolation or convert it to 25 fps with ffmpeg (an example command is shown below).
241
+
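If your source clip is not 25 fps, a standard ffmpeg re-encode is usually enough; the command below is a generic example (input/output names are placeholders):

```bash
# Re-encode to 25 fps; copy the audio stream unchanged
ffmpeg -i input.mp4 -r 25 -c:v libx264 -crf 18 -c:a copy input_25fps.mp4
```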
242
+ #### Real-time Inference
243
+ For real-time inference, use the following command:
244
+ ```bash
245
+ # Run real-time inference
246
+ sh inference.sh v1.5 realtime # For MuseTalk 1.5
247
+ # or
248
+ sh inference.sh v1.0 realtime # For MuseTalk 1.0
249
+ ```
250
+
251
+ The real-time inference configuration is in `configs/inference/realtime.yaml`, which includes the following keys (the sample config added in this commit is shown after this list):
252
+ - `preparation`: Set to `True` for new avatar preparation
253
+ - `video_path`: Path to the input video
254
+ - `bbox_shift`: Adjustable parameter for mouth region control
255
+ - `audio_clips`: List of audio clips for generation
256
+
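For reference, the sample `configs/inference/realtime.yaml` added in this commit defines a single avatar roughly as follows:

```yaml
avator_1:                           # avatar identifier (spelling follows the repo)
  preparation: True                 # set to False to reuse a previously prepared avatar
  bbox_shift: 5
  video_path: "data/video/yongen.mp4"
  audio_clips:
    audio_0: "data/audio/yongen.wav"
    audio_1: "data/audio/eng.wav"
```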
257
+ Important notes for real-time inference:
258
+ 1. Set `preparation` to `True` when processing a new avatar
259
+ 2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
260
+ 3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
261
+ 4. Set `preparation` to `False` for generating more videos with the same avatar
262
+
263
+ For faster generation without saving images, you can use:
264
  ```bash
265
+ python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
266
  ```
267
+
 
268
  ## TestCases For 1.0
269
  <table class="center">
270
  <tr style="font-weight: bolder;text-align:center;">
 
362
  ```
363
  :pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
364
 
 
 
365
 
366
  #### Combining MuseV and MuseTalk
367
 
368
  As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
369
 
370
  # Acknowledgement
371
  1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
372
  1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
configs/inference/realtime.yaml CHANGED
@@ -1,10 +1,10 @@
1
  avator_1:
2
- preparation: False
3
  bbox_shift: 5
4
- video_path: "data/video/sun.mp4"
5
  audio_clips:
6
  audio_0: "data/audio/yongen.wav"
7
- audio_1: "data/audio/sun.wav"
8
 
9
 
10
 
 
1
  avator_1:
2
+ preparation: True # you can set this to False to reuse an existing avatar and save preparation time
3
  bbox_shift: 5
4
+ video_path: "data/video/yongen.mp4"
5
  audio_clips:
6
  audio_0: "data/audio/yongen.wav"
7
+ audio_1: "data/audio/eng.wav"
8
 
9
 
10
 
configs/inference/test.yaml CHANGED
@@ -3,8 +3,8 @@ task_0:
3
  audio_path: "data/audio/yongen.wav"
4
 
5
  task_1:
6
- video_path: "data/video/sun.mp4"
7
- audio_path: "data/audio/sun.wav"
8
  bbox_shift: -7
9
 
10
 
 
3
  audio_path: "data/audio/yongen.wav"
4
 
5
  task_1:
6
+ video_path: "data/video/yongen.mp4"
7
+ audio_path: "data/audio/eng.wav"
8
  bbox_shift: -7
9
 
10
 
data/audio/eng.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:654dcbce843d70451d1123f7649a58bee11bb9dec9a7e835c05b1e367efb2078
3
+ size 1920078
inference.sh CHANGED
@@ -1,46 +1,72 @@
1
  #!/bin/bash
2
 
3
- # This script runs inference based on the version specified by the user.
4
  # Usage:
5
- # To run v1.0 inference: sh inference.sh v1.0
6
- # To run v1.5 inference: sh inference.sh v1.5
7
 
8
  # Check if the correct number of arguments is provided
9
- if [ "$#" -ne 1 ]; then
10
- echo "Usage: $0 <version>"
11
- echo "Example: $0 v1.0 or $0 v1.5"
12
  exit 1
13
  fi
14
 
15
- # Get the version from the user input
16
  version=$1
17
- config_path="./configs/inference/test.yaml"
18
 
19
  # Define the model paths based on the version
20
  if [ "$version" = "v1.0" ]; then
21
  model_dir="./models/musetalk"
22
  unet_model_path="$model_dir/pytorch_model.bin"
23
  unet_config="$model_dir/musetalk.json"
 
24
  elif [ "$version" = "v1.5" ]; then
25
  model_dir="./models/musetalkV15"
26
  unet_model_path="$model_dir/unet.pth"
27
  unet_config="$model_dir/musetalk.json"
 
28
  else
29
  echo "Invalid version specified. Please use v1.0 or v1.5."
30
  exit 1
31
  fi
32
 
33
- # Run inference based on the version
34
- if [ "$version" = "v1.0" ]; then
35
- python3 -m scripts.inference \
36
- --inference_config "$config_path" \
37
- --result_dir "./results/test" \
38
- --unet_model_path "$unet_model_path" \
39
- --unet_config "$unet_config"
40
- elif [ "$version" = "v1.5" ]; then
41
- python3 -m scripts.inference_alpha \
42
- --inference_config "$config_path" \
43
- --result_dir "./results/test" \
44
- --unet_model_path "$unet_model_path" \
45
- --unet_config "$unet_config"
46
- fi
1
  #!/bin/bash
2
 
3
+ # This script runs inference based on the version and mode specified by the user.
4
  # Usage:
5
+ # To run v1.0 inference: sh inference.sh v1.0 [normal|realtime]
6
+ # To run v1.5 inference: sh inference.sh v1.5 [normal|realtime]
7
 
8
  # Check if the correct number of arguments is provided
9
+ if [ "$#" -ne 2 ]; then
10
+ echo "Usage: $0 <version> <mode>"
11
+ echo "Example: $0 v1.0 normal or $0 v1.5 realtime"
12
  exit 1
13
  fi
14
 
15
+ # Get the version and mode from the user input
16
  version=$1
17
+ mode=$2
18
+
19
+ # Validate mode
20
+ if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then
21
+ echo "Invalid mode specified. Please use 'normal' or 'realtime'."
22
+ exit 1
23
+ fi
24
+
25
+ # Set config path based on mode
26
+ if [ "$mode" = "normal" ]; then
27
+ config_path="./configs/inference/test.yaml"
28
+ result_dir="./results/test"
29
+ else
30
+ config_path="./configs/inference/realtime.yaml"
31
+ result_dir="./results/realtime"
32
+ fi
33
 
34
  # Define the model paths based on the version
35
  if [ "$version" = "v1.0" ]; then
36
  model_dir="./models/musetalk"
37
  unet_model_path="$model_dir/pytorch_model.bin"
38
  unet_config="$model_dir/musetalk.json"
39
+ version_arg="v1"
40
  elif [ "$version" = "v1.5" ]; then
41
  model_dir="./models/musetalkV15"
42
  unet_model_path="$model_dir/unet.pth"
43
  unet_config="$model_dir/musetalk.json"
44
+ version_arg="v15"
45
  else
46
  echo "Invalid version specified. Please use v1.0 or v1.5."
47
  exit 1
48
  fi
49
 
50
+ # Set script name based on mode
51
+ if [ "$mode" = "normal" ]; then
52
+ script_name="scripts.inference"
53
+ else
54
+ script_name="scripts.realtime_inference"
55
+ fi
56
+
57
+ # Base command arguments
58
+ cmd_args="--inference_config $config_path \
59
+ --result_dir $result_dir \
60
+ --unet_model_path $unet_model_path \
61
+ --unet_config $unet_config \
62
+ --version $version_arg \
63
+
64
+ # Add realtime-specific arguments if in realtime mode
65
+ if [ "$mode" = "realtime" ]; then
66
+ cmd_args="$cmd_args \
67
+ --fps 25 \
68
+ --version $version_arg \
69
+ fi
70
+
71
+ # Run inference
72
+ python3 -m $script_name $cmd_args
musetalk/utils/audio_processor.py CHANGED
@@ -11,7 +11,7 @@ class AudioProcessor:
11
  def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
12
  self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
13
 
14
- def get_audio_feature(self, wav_path, start_index=0):
15
  if not os.path.exists(wav_path):
16
  return None
17
  librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
@@ -27,6 +27,8 @@ class AudioProcessor:
27
  return_tensors="pt",
28
  sampling_rate=sampling_rate
29
  ).input_features
 
 
30
  features.append(audio_feature)
31
 
32
  return features, len(librosa_output)
 
11
  def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
12
  self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
13
 
14
+ def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
15
  if not os.path.exists(wav_path):
16
  return None
17
  librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
 
27
  return_tensors="pt",
28
  sampling_rate=sampling_rate
29
  ).input_features
30
+ if weight_dtype is not None:
31
+ audio_feature = audio_feature.to(dtype=weight_dtype)
32
  features.append(audio_feature)
33
 
34
  return features, len(librosa_output)
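The new optional `weight_dtype` argument lets callers cast the extracted Whisper features before batching. A minimal usage sketch, assuming a local Whisper directory as used by the inference scripts:

```python
import torch
from musetalk.utils.audio_processor import AudioProcessor

processor = AudioProcessor(feature_extractor_path="./models/whisper")

# Cast features to fp16 up front so they match a half-precision UNet/Whisper
features, audio_length = processor.get_audio_feature(
    "data/audio/yongen.wav", weight_dtype=torch.float16
)
```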
musetalk/utils/blending.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
  import cv2
4
  import copy
5
 
 
6
  def get_crop_box(box, expand):
7
  x, y, x1, y1 = box
8
  x_c, y_c = (x+x1)//2, (y+y1)//2
@@ -11,7 +12,8 @@ def get_crop_box(box, expand):
11
  crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
12
  return crop_box, s
13
 
14
- def face_seg(image, mode="jaw", fp=None):
 
15
  """
16
  对图像进行面部解析,生成面部区域的掩码。
17
 
@@ -86,14 +88,12 @@ def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode=
86
 
87
  body.paste(face_large, crop_box[:2], mask_image)
88
 
89
- # 不用掩码,完全用infer
90
- #face_large.save("debug/checkpoint_6_face_large.png")
91
-
92
  body = np.array(body) # 将 PIL 图像转换回 numpy 数组
93
 
94
  return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB)
95
 
96
- def get_image_blending(image,face,face_box,mask_array,crop_box):
 
97
  body = Image.fromarray(image[:,:,::-1])
98
  face = Image.fromarray(face[:,:,::-1])
99
 
@@ -108,7 +108,8 @@ def get_image_blending(image,face,face_box,mask_array,crop_box):
108
  body = np.array(body)
109
  return body[:,:,::-1]
110
 
111
- def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=1.2):
 
112
  body = Image.fromarray(image[:,:,::-1])
113
 
114
  x, y, x1, y1 = face_box
@@ -119,7 +120,7 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=
119
  face_large = body.crop(crop_box)
120
  ori_shape = face_large.size
121
 
122
- mask_image = face_seg(face_large)
123
  mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
124
  mask_image = Image.new('L', ori_shape, 0)
125
  mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
@@ -132,4 +133,4 @@ def get_image_prepare_material(image,face_box,upper_boundary_ratio = 0.5,expand=
132
 
133
  blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
134
  mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
135
- return mask_array,crop_box
 
3
  import cv2
4
  import copy
5
 
6
+
7
  def get_crop_box(box, expand):
8
  x, y, x1, y1 = box
9
  x_c, y_c = (x+x1)//2, (y+y1)//2
 
12
  crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
13
  return crop_box, s
14
 
15
+
16
+ def face_seg(image, mode="raw", fp=None):
17
  """
18
  对图像进行面部解析,生成面部区域的掩码。
19
 
 
88
 
89
  body.paste(face_large, crop_box[:2], mask_image)
90
 
 
 
 
91
  body = np.array(body) # 将 PIL 图像转换回 numpy 数组
92
 
93
  return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB)
94
 
95
+
96
+ def get_image_blending(image, face, face_box, mask_array, crop_box):
97
  body = Image.fromarray(image[:,:,::-1])
98
  face = Image.fromarray(face[:,:,::-1])
99
 
 
108
  body = np.array(body)
109
  return body[:,:,::-1]
110
 
111
+
112
+ def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
113
  body = Image.fromarray(image[:,:,::-1])
114
 
115
  x, y, x1, y1 = face_box
 
120
  face_large = body.crop(crop_box)
121
  ori_shape = face_large.size
122
 
123
+ mask_image = face_seg(face_large, mode=mode, fp=fp)
124
  mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
125
  mask_image = Image.new('L', ori_shape, 0)
126
  mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
 
133
 
134
  blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
135
  mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
136
+ return mask_array, crop_box
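With `fp` and `mode` now passed explicitly, a caller is expected to prepare the blending mask once per frame and reuse it at generation time. A rough sketch (the frame path and face box below are placeholders):

```python
import cv2
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.blending import get_image_prepare_material

fp = FaceParsing()
frame = cv2.imread("data/images/00000000.png")  # placeholder frame path
face_box = (120, 80, 380, 360)                  # placeholder (x, y, x1, y1) box

# One-time preparation: compute the blending mask and its crop box for this frame
mask_array, crop_box = get_image_prepare_material(frame, face_box, fp=fp, mode="raw")
```

Note that the default `mode` changes from `"jaw"` to `"raw"` in both `face_seg` and `FaceParsing.__call__`, so callers relying on the old default should now pass `mode="jaw"` explicitly.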
musetalk/utils/face_parsing/__init__.py CHANGED
@@ -74,7 +74,7 @@ class FaceParsing():
74
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
75
  ])
76
 
77
- def __call__(self, image, size=(512, 512), mode="jaw"):
78
  if isinstance(image, str):
79
  image = Image.open(image)
80
 
 
74
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
75
  ])
76
 
77
+ def __call__(self, image, size=(512, 512), mode="raw"):
78
  if isinstance(image, str):
79
  image = Image.open(image)
80
 
scripts/inference.py CHANGED
@@ -1,8 +1,9 @@
1
  import os
2
  import cv2
 
3
  import copy
4
- import glob
5
  import torch
 
6
  import shutil
7
  import pickle
8
  import argparse
@@ -17,18 +18,16 @@ from musetalk.utils.audio_processor import AudioProcessor
17
  from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
18
  from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
19
 
20
-
21
-
22
  @torch.no_grad()
23
  def main(args):
24
  # Configure ffmpeg path
25
  if args.ffmpeg_path not in os.getenv('PATH'):
26
  print("Adding ffmpeg to PATH")
27
  os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
28
-
29
  # Set computing device
30
  device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
31
-
32
  # Load model weights
33
  vae, unet, pe = load_all_model(
34
  unet_model_path=args.unet_model_path,
@@ -37,164 +36,229 @@ def main(args):
37
  device=device
38
  )
39
  timesteps = torch.tensor([0], device=device)
40
-
41
-
42
- if args.use_float16 is True:
43
  pe = pe.half()
44
  vae.vae = vae.vae.half()
45
  unet.model = unet.model.half()
 
46
 
47
- # Initialize audio processor and Whisper model
48
  audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
49
  weight_dtype = unet.model.dtype
50
  whisper = WhisperModel.from_pretrained(args.whisper_dir)
51
  whisper = whisper.to(device=device, dtype=weight_dtype).eval()
52
  whisper.requires_grad_(False)
53
 
54
- # Initialize face parser
55
- fp = FaceParsing()
56
 
 
57
  inference_config = OmegaConf.load(args.inference_config)
58
- print(inference_config)
 
 
59
  for task_id in inference_config:
60
- video_path = inference_config[task_id]["video_path"]
61
- audio_path = inference_config[task_id]["audio_path"]
62
- bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
63
-
64
- input_basename = os.path.basename(video_path).split('.')[0]
65
- audio_basename = os.path.basename(audio_path).split('.')[0]
66
- output_basename = f"{input_basename}_{audio_basename}"
67
- result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
68
- crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
69
- os.makedirs(result_img_save_path,exist_ok =True)
70
-
71
- if args.output_vid_name is None:
72
- output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
73
- else:
74
- output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
75
- ############################################## extract frames from source video ##############################################
76
- if get_file_type(video_path)=="video":
77
- save_dir_full = os.path.join(args.result_dir, input_basename)
78
- os.makedirs(save_dir_full,exist_ok = True)
79
- cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
80
- os.system(cmd)
81
- input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
82
- fps = get_video_fps(video_path)
83
- elif get_file_type(video_path)=="image":
84
- input_img_list = [video_path, ]
85
- fps = args.fps
86
- elif os.path.isdir(video_path): # input img folder
87
- input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
88
- input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
89
- fps = args.fps
90
- else:
91
- raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
92
 
93
- ############################################## extract audio feature ##############################################
94
- # Extract audio features
95
- whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
96
- whisper_chunks = audio_processor.get_whisper_chunk(
97
- whisper_input_features,
98
- device,
99
- weight_dtype,
100
- whisper,
101
- librosa_length,
102
- fps=fps,
103
- audio_padding_length_left=args.audio_padding_length_left,
104
- audio_padding_length_right=args.audio_padding_length_right,
105
- )
106
 
107
- ############################################## preprocess input image ##############################################
108
- if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
109
- print("using extracted coordinates")
110
- with open(crop_coord_save_path,'rb') as f:
111
- coord_list = pickle.load(f)
112
- frame_list = read_imgs(input_img_list)
113
- else:
114
- print("extracting landmarks...time consuming")
115
- coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
116
- with open(crop_coord_save_path, 'wb') as f:
117
- pickle.dump(coord_list, f)
118
 
119
- i = 0
120
- input_latent_list = []
121
- for bbox, frame in zip(coord_list, frame_list):
122
- if bbox == coord_placeholder:
123
- continue
124
- x1, y1, x2, y2 = bbox
125
- crop_frame = frame[y1:y2, x1:x2]
126
- crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
127
- latents = vae.get_latents_for_unet(crop_frame)
128
- input_latent_list.append(latents)
129
-
130
- # to smooth the first and the last frame
131
- frame_list_cycle = frame_list + frame_list[::-1]
132
- coord_list_cycle = coord_list + coord_list[::-1]
133
- input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
134
- ############################################## inference batch by batch ##############################################
135
- print("start inference")
136
- video_num = len(whisper_chunks)
137
- batch_size = args.batch_size
138
- gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
139
- res_frame_list = []
140
- for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
141
- audio_feature_batch = pe(whisper_batch)
142
- latent_batch = latent_batch.to(dtype=unet.model.dtype)
143
-
144
- pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
145
- recon = vae.decode_latents(pred_latents)
146
- for res_frame in recon:
147
- res_frame_list.append(res_frame)
148
 
149
- ############################################## pad to full image ##############################################
150
- print("pad talking image to original video")
151
- for i, res_frame in enumerate(tqdm(res_frame_list)):
152
- bbox = coord_list_cycle[i%(len(coord_list_cycle))]
153
- ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
154
- x1, y1, x2, y2 = bbox
155
- try:
156
- res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
157
- except:
158
- continue
159
-
160
- # Merge results
161
- combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
162
- cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
163
 
164
- cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
165
- print(cmd_img2video)
166
- os.system(cmd_img2video)
167
-
168
- cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
169
- print(cmd_combine_audio)
170
- os.system(cmd_combine_audio)
171
-
172
- os.remove("temp.mp4")
173
- shutil.rmtree(result_img_save_path)
174
- print(f"result is save to {output_vid_name}")
175
 
176
  if __name__ == "__main__":
177
  parser = argparse.ArgumentParser()
178
  parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
179
- parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
180
- parser.add_argument("--bbox_shift", type=int, default=0)
181
- parser.add_argument("--result_dir", default='./results', help="path to output")
182
  parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
183
- parser.add_argument("--batch_size", type=int, default=8)
184
- parser.add_argument("--output_vid_name", type=str, default=None)
185
- parser.add_argument("--use_saved_coord",
186
- action="store_true",
187
- help='use saved coordinate to save time')
188
- parser.add_argument("--use_float16",
189
- action="store_true",
190
- help="Whether use float16 to speed up inference",
191
- )
192
- parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
193
- parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
194
  parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
195
  parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
 
196
  parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
197
  parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
198
  parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
199
  args = parser.parse_args()
200
  main(args)
 
1
  import os
2
  import cv2
3
+ import math
4
  import copy
 
5
  import torch
6
+ import glob
7
  import shutil
8
  import pickle
9
  import argparse
 
18
  from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
19
  from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
20
 
 
 
21
  @torch.no_grad()
22
  def main(args):
23
  # Configure ffmpeg path
24
  if args.ffmpeg_path not in os.getenv('PATH'):
25
  print("Adding ffmpeg to PATH")
26
  os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
27
+
28
  # Set computing device
29
  device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
30
+
31
  # Load model weights
32
  vae, unet, pe = load_all_model(
33
  unet_model_path=args.unet_model_path,
 
36
  device=device
37
  )
38
  timesteps = torch.tensor([0], device=device)
39
+
40
+ # Convert models to half precision if float16 is enabled
41
+ if args.use_float16:
42
  pe = pe.half()
43
  vae.vae = vae.vae.half()
44
  unet.model = unet.model.half()
45
+
46
+ # Move models to specified device
47
+ pe = pe.to(device)
48
+ vae.vae = vae.vae.to(device)
49
+ unet.model = unet.model.to(device)
50
 
51
+ # Initialize audio processor and Whisper model
52
  audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
53
  weight_dtype = unet.model.dtype
54
  whisper = WhisperModel.from_pretrained(args.whisper_dir)
55
  whisper = whisper.to(device=device, dtype=weight_dtype).eval()
56
  whisper.requires_grad_(False)
57
 
58
+ # Initialize face parser with configurable parameters based on version
59
+ if args.version == "v15":
60
+ fp = FaceParsing(
61
+ left_cheek_width=args.left_cheek_width,
62
+ right_cheek_width=args.right_cheek_width
63
+ )
64
+ else: # v1
65
+ fp = FaceParsing()
66
 
67
+ # Load inference configuration
68
  inference_config = OmegaConf.load(args.inference_config)
69
+ print("Loaded inference config:", inference_config)
70
+
71
+ # Process each task
72
  for task_id in inference_config:
73
+ try:
74
+ # Get task configuration
75
+ video_path = inference_config[task_id]["video_path"]
76
+ audio_path = inference_config[task_id]["audio_path"]
77
+ if "result_name" in inference_config[task_id]:
78
+ args.output_vid_name = inference_config[task_id]["result_name"]
79
+
80
+ # Set bbox_shift based on version
81
+ if args.version == "v15":
82
+ bbox_shift = 0 # v15 uses fixed bbox_shift
83
+ else:
84
+ bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) # v1 uses config or default
85
+
86
+ # Set output paths
87
+ input_basename = os.path.basename(video_path).split('.')[0]
88
+ audio_basename = os.path.basename(audio_path).split('.')[0]
89
+ output_basename = f"{input_basename}_{audio_basename}"
90
+
91
+ # Create temporary directories
92
+ temp_dir = os.path.join(args.result_dir, f"{args.version}")
93
+ os.makedirs(temp_dir, exist_ok=True)
94
+
95
+ # Set result save paths
96
+ result_img_save_path = os.path.join(temp_dir, output_basename)
97
+ crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
98
+ os.makedirs(result_img_save_path, exist_ok=True)
99
+
100
+ # Set output video paths
101
+ if args.output_vid_name is None:
102
+ output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
103
+ else:
104
+ output_vid_name = os.path.join(temp_dir, args.output_vid_name)
105
+ output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
106
+
107
+ # Extract frames from source video
108
+ if get_file_type(video_path) == "video":
109
+ save_dir_full = os.path.join(temp_dir, input_basename)
110
+ os.makedirs(save_dir_full, exist_ok=True)
111
+ cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
112
+ os.system(cmd)
113
+ input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
114
+ fps = get_video_fps(video_path)
115
+ elif get_file_type(video_path) == "image":
116
+ input_img_list = [video_path]
117
+ fps = args.fps
118
+ elif os.path.isdir(video_path):
119
+ input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
120
+ input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
121
+ fps = args.fps
122
+ else:
123
+ raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
124
 
125
+ # Extract audio features
126
+ whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
127
+ whisper_chunks = audio_processor.get_whisper_chunk(
128
+ whisper_input_features,
129
+ device,
130
+ weight_dtype,
131
+ whisper,
132
+ librosa_length,
133
+ fps=fps,
134
+ audio_padding_length_left=args.audio_padding_length_left,
135
+ audio_padding_length_right=args.audio_padding_length_right,
136
+ )
137
+
138
+ # Preprocess input images
139
+ if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
140
+ print("Using saved coordinates")
141
+ with open(crop_coord_save_path, 'rb') as f:
142
+ coord_list = pickle.load(f)
143
+ frame_list = read_imgs(input_img_list)
144
+ else:
145
+ print("Extracting landmarks... time-consuming operation")
146
+ coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
147
+ with open(crop_coord_save_path, 'wb') as f:
148
+ pickle.dump(coord_list, f)
149
+
150
+ print(f"Number of frames: {len(frame_list)}")
151
+
152
+ # Process each frame
153
+ input_latent_list = []
154
+ for bbox, frame in zip(coord_list, frame_list):
155
+ if bbox == coord_placeholder:
156
+ continue
157
+ x1, y1, x2, y2 = bbox
158
+ if args.version == "v15":
159
+ y2 = y2 + args.extra_margin
160
+ y2 = min(y2, frame.shape[0])
161
+ crop_frame = frame[y1:y2, x1:x2]
162
+ crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
163
+ latents = vae.get_latents_for_unet(crop_frame)
164
+ input_latent_list.append(latents)
165
 
166
+ # Smooth first and last frames
167
+ frame_list_cycle = frame_list + frame_list[::-1]
168
+ coord_list_cycle = coord_list + coord_list[::-1]
169
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
170
+
171
+ # Batch inference
172
+ print("Starting inference")
173
+ video_num = len(whisper_chunks)
174
+ batch_size = args.batch_size
175
+ gen = datagen(
176
+ whisper_chunks=whisper_chunks,
177
+ vae_encode_latents=input_latent_list_cycle,
178
+ batch_size=batch_size,
179
+ delay_frame=0,
180
+ device=device,
181
+ )
182
+
183
+ res_frame_list = []
184
+ total = int(np.ceil(float(video_num) / batch_size))
185
+
186
+ # Execute inference
187
+ for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
188
+ audio_feature_batch = pe(whisper_batch)
189
+ latent_batch = latent_batch.to(dtype=unet.model.dtype)
190
 
191
+ pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
192
+ recon = vae.decode_latents(pred_latents)
193
+ for res_frame in recon:
194
+ res_frame_list.append(res_frame)
195
+
196
+ # Pad generated images to original video size
197
+ print("Padding generated images to original video size")
198
+ for i, res_frame in enumerate(tqdm(res_frame_list)):
199
+ bbox = coord_list_cycle[i%(len(coord_list_cycle))]
200
+ ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
201
+ x1, y1, x2, y2 = bbox
202
+ if args.version == "v15":
203
+ y2 = y2 + args.extra_margin
204
+ y2 = min(y2, frame.shape[0])
205
+ try:
206
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
207
+ except:
208
+ continue
209
 
210
+ # Merge results with version-specific parameters
211
+ if args.version == "v15":
212
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
213
+ else:
214
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
215
+ cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
216
 
217
+ # Save prediction results
218
+ temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
219
+ cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
220
+ print("Video generation command:", cmd_img2video)
221
+ os.system(cmd_img2video)
222
+
223
+ cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
224
+ print("Audio combination command:", cmd_combine_audio)
225
+ os.system(cmd_combine_audio)
226
+
227
+ # Clean up temporary files
228
+ shutil.rmtree(result_img_save_path)
229
+ os.remove(temp_vid_path)
230
+
231
+ shutil.rmtree(save_dir_full)
232
+ if not args.saved_coord:
233
+ os.remove(crop_coord_save_path)
234
+
235
+ print(f"Results saved to {output_vid_name}")
236
+ except Exception as e:
237
+ print("Error occurred during processing:", e)
238
 
239
  if __name__ == "__main__":
240
  parser = argparse.ArgumentParser()
241
  parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
 
 
 
242
  parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
 
243
  parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
244
  parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
245
+ parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
246
  parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
247
+ parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
248
+ parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
249
+ parser.add_argument("--result_dir", default='./results', help="Directory for output results")
250
+ parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
251
+ parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
252
  parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
253
  parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
254
+ parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
255
+ parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
256
+ parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
257
+ parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
258
+ parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
259
+ parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
260
+ parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
261
+ parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
262
+ parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
263
  args = parser.parse_args()
264
  main(args)
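For reference, a direct invocation of the unified script with the v1.5 defaults (roughly what `inference.sh v1.5 normal` assembles) would look like:

```bash
python3 -m scripts.inference \
    --inference_config configs/inference/test.yaml \
    --result_dir ./results/test \
    --unet_model_path ./models/musetalkV15/unet.pth \
    --unet_config ./models/musetalkV15/musetalk.json \
    --version v15
```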
scripts/inference_alpha.py DELETED
@@ -1,252 +0,0 @@
1
- import os
2
- import cv2
3
- import math
4
- import copy
5
- import torch
6
- import glob
7
- import shutil
8
- import pickle
9
- import argparse
10
- import subprocess
11
- import numpy as np
12
- from tqdm import tqdm
13
- from omegaconf import OmegaConf
14
- from transformers import WhisperModel
15
-
16
- from musetalk.utils.blending import get_image
17
- from musetalk.utils.face_parsing import FaceParsing
18
- from musetalk.utils.audio_processor import AudioProcessor
19
- from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
20
- from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
21
-
22
-
23
- @torch.no_grad()
24
- def main(args):
25
- # Configure ffmpeg path
26
- if args.ffmpeg_path not in os.getenv('PATH'):
27
- print("Adding ffmpeg to PATH")
28
- os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
29
-
30
- # Set computing device
31
- device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
32
-
33
- # Load model weights
34
- vae, unet, pe = load_all_model(
35
- unet_model_path=args.unet_model_path,
36
- vae_type=args.vae_type,
37
- unet_config=args.unet_config,
38
- device=device
39
- )
40
- timesteps = torch.tensor([0], device=device)
41
-
42
- # Convert models to half precision if float16 is enabled
43
- if args.use_float16:
44
- pe = pe.half()
45
- vae.vae = vae.vae.half()
46
- unet.model = unet.model.half()
47
-
48
- # Move models to specified device
49
- pe = pe.to(device)
50
- vae.vae = vae.vae.to(device)
51
- unet.model = unet.model.to(device)
52
-
53
- # Initialize audio processor and Whisper model
54
- audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
55
- weight_dtype = unet.model.dtype
56
- whisper = WhisperModel.from_pretrained(args.whisper_dir)
57
- whisper = whisper.to(device=device, dtype=weight_dtype).eval()
58
- whisper.requires_grad_(False)
59
-
60
- # Initialize face parser
61
- fp = FaceParsing(left_cheek_width=args.left_cheek_width, right_cheek_width=args.right_cheek_width)
62
-
63
- # Load inference configuration
64
- inference_config = OmegaConf.load(args.inference_config)
65
- print("Loaded inference config:", inference_config)
66
-
67
- # Process each task
68
- for task_id in inference_config:
69
- try:
70
- # Get task configuration
71
- video_path = inference_config[task_id]["video_path"]
72
- audio_path = inference_config[task_id]["audio_path"]
73
- if "result_name" in inference_config[task_id]:
74
- args.output_vid_name = inference_config[task_id]["result_name"]
75
- bbox_shift = args.bbox_shift
76
- # Set output paths
77
- input_basename = os.path.basename(video_path).split('.')[0]
78
- audio_basename = os.path.basename(audio_path).split('.')[0]
79
- output_basename = f"{input_basename}_{audio_basename}"
80
-
81
- # Create temporary directories
82
- temp_dir = os.path.join(args.result_dir, "frames_result")
83
- os.makedirs(temp_dir, exist_ok=True)
84
-
85
- # Set result save paths
86
- result_img_save_path = os.path.join(temp_dir, output_basename) # related to video & audio inputs
87
- crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl") # only related to video input
88
- os.makedirs(result_img_save_path, exist_ok=True)
89
- # Set output video paths
90
- if args.output_vid_name is None:
91
- output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
92
- else:
93
- output_vid_name = os.path.join(temp_dir, args.output_vid_name)
94
- output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
95
-
96
- # Skip if output file already exists
97
- if os.path.exists(output_vid_name):
98
- print(f"{output_vid_name} already exists, skipping!")
99
- continue
100
-
101
- # Extract frames from source video
102
- if get_file_type(video_path) == "video":
103
- save_dir_full = os.path.join(temp_dir, input_basename)
104
- os.makedirs(save_dir_full, exist_ok=True)
105
- cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
106
- os.system(cmd)
107
- input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
108
- fps = get_video_fps(video_path)
109
- elif get_file_type(video_path) == "image":
110
- input_img_list = [video_path]
111
- fps = args.fps
112
- elif os.path.isdir(video_path):
113
- input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
114
- input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
115
- fps = args.fps
116
- else:
117
- raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
118
-
119
- # Extract audio features
120
- whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
121
- whisper_chunks = audio_processor.get_whisper_chunk(
122
- whisper_input_features,
123
- device,
124
- weight_dtype,
125
- whisper,
126
- librosa_length,
127
- fps=fps,
128
- audio_padding_length_left=args.audio_padding_length_left,
129
- audio_padding_length_right=args.audio_padding_length_right,
130
- )
131
-
132
- # Preprocess input images
133
- if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
134
- print("Using saved coordinates")
135
- with open(crop_coord_save_path, 'rb') as f:
136
- coord_list = pickle.load(f)
137
- frame_list = read_imgs(input_img_list)
138
- else:
139
- print("Extracting landmarks... time-consuming operation")
140
- coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
141
- with open(crop_coord_save_path, 'wb') as f:
142
- pickle.dump(coord_list, f)
143
-
144
- print(f"Number of frames: {len(frame_list)}")
145
-
146
- # Process each frame
147
- input_latent_list = []
148
- for bbox, frame in zip(coord_list, frame_list):
149
- if bbox == coord_placeholder:
150
- continue
151
- x1, y1, x2, y2 = bbox
152
- y2 = y2 + args.extra_margin
153
- y2 = min(y2, frame.shape[0])
154
- crop_frame = frame[y1:y2, x1:x2]
155
- crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
156
- latents = vae.get_latents_for_unet(crop_frame)
157
- input_latent_list.append(latents)
158
-
159
- # Smooth first and last frames
160
- frame_list_cycle = frame_list + frame_list[::-1]
161
- coord_list_cycle = coord_list + coord_list[::-1]
162
- input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
163
-
164
- # Batch inference
165
- print("Starting inference")
166
- video_num = len(whisper_chunks)
167
- batch_size = args.batch_size
168
- gen = datagen(
169
- whisper_chunks=whisper_chunks,
170
- vae_encode_latents=input_latent_list_cycle,
171
- batch_size=batch_size,
172
- delay_frame=0,
173
- device=device,
174
- )
175
-
176
- res_frame_list = []
177
- total = int(np.ceil(float(video_num) / batch_size))
178
-
179
- # Execute inference
180
- for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
181
- audio_feature_batch = pe(whisper_batch)
182
- latent_batch = latent_batch.to(dtype=unet.model.dtype)
183
-
184
- pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
185
- recon = vae.decode_latents(pred_latents)
186
- for res_frame in recon:
187
- res_frame_list.append(res_frame)
188
-
189
- # Pad generated images to original video size
190
- print("Padding generated images to original video size")
191
- for i, res_frame in enumerate(tqdm(res_frame_list)):
192
- bbox = coord_list_cycle[i%(len(coord_list_cycle))]
193
- ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
194
- x1, y1, x2, y2 = bbox
195
- y2 = y2 + args.extra_margin
196
- y2 = min(y2, frame.shape[0])
197
- try:
198
- res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
199
- except:
200
- continue
201
-
202
- # Merge results
203
- combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
204
- cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
205
-
206
- # Save prediction results
207
- temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
208
- cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
209
- print("Video generation command:", cmd_img2video)
210
- os.system(cmd_img2video)
211
-
212
- cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
213
- print("Audio combination command:", cmd_combine_audio)
214
- os.system(cmd_combine_audio)
215
-
216
- # Clean up temporary files
217
- shutil.rmtree(result_img_save_path)
218
- os.remove(temp_vid_path)
219
-
220
- shutil.rmtree(save_dir_full)
221
- if not args.saved_coord:
222
- os.remove(crop_coord_save_path)
223
-
224
- print(f"Results saved to {output_vid_name}")
225
- except Exception as e:
226
- print("Error occurred during processing:", e)
227
-
228
- if __name__ == "__main__":
229
- parser = argparse.ArgumentParser()
230
- parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
231
- parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
232
- parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
233
- parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
234
- parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
235
- parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
236
- parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
237
- parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
238
- parser.add_argument("--result_dir", default='./results', help="Directory for output results")
239
- parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
240
- parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
241
- parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
242
- parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
243
- parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
244
- parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
245
- parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
246
- parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
247
- parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
248
- parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
249
- parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
250
- parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
251
- args = parser.parse_args()
252
- main(args)
scripts/realtime_inference.py CHANGED
@@ -10,26 +10,22 @@ import sys
10
  from tqdm import tqdm
11
  import copy
12
  import json
13
- from musetalk.utils.utils import get_file_type,get_video_fps,datagen
14
- from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
15
- from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
 
 
 
16
  from musetalk.utils.utils import load_all_model
17
- import shutil
18
 
 
19
  import threading
20
  import queue
21
-
22
  import time
23
 
24
- # load model weights
25
- audio_processor, vae, unet, pe = load_all_model()
26
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
- timesteps = torch.tensor([0], device=device)
28
- pe = pe.half()
29
- vae.vae = vae.vae.half()
30
- unet.model = unet.model.half()
31
 
32
- def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
33
  cap = cv2.VideoCapture(vid_path)
34
  count = 0
35
  while True:
@@ -42,35 +38,43 @@ def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
42
  else:
43
  break
44
 
 
45
  def osmakedirs(path_list):
46
  for path in path_list:
47
  os.makedirs(path) if not os.path.exists(path) else None
48
-
49
 
50
- @torch.no_grad()
 
51
  class Avatar:
52
  def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
53
  self.avatar_id = avatar_id
54
  self.video_path = video_path
55
  self.bbox_shift = bbox_shift
56
- self.avatar_path = f"./results/avatars/{avatar_id}"
57
- self.full_imgs_path = f"{self.avatar_path}/full_imgs"
58
  self.coords_path = f"{self.avatar_path}/coords.pkl"
59
- self.latents_out_path= f"{self.avatar_path}/latents.pt"
60
  self.video_out_path = f"{self.avatar_path}/vid_output/"
61
- self.mask_out_path =f"{self.avatar_path}/mask"
62
- self.mask_coords_path =f"{self.avatar_path}/mask_coords.pkl"
63
  self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
64
  self.avatar_info = {
65
- "avatar_id":avatar_id,
66
- "video_path":video_path,
67
- "bbox_shift":bbox_shift
 
68
  }
69
  self.preparation = preparation
70
  self.batch_size = batch_size
71
  self.idx = 0
72
  self.init()
73
-
74
  def init(self):
75
  if self.preparation:
76
  if os.path.exists(self.avatar_path):
@@ -80,7 +84,7 @@ class Avatar:
80
  print("*********************************")
81
  print(f" creating avator: {self.avatar_id}")
82
  print("*********************************")
83
- osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
84
  self.prepare_material()
85
  else:
86
  self.input_latent_list_cycle = torch.load(self.latents_out_path)
@@ -98,16 +102,16 @@ class Avatar:
98
  print("*********************************")
99
  print(f" creating avator: {self.avatar_id}")
100
  print("*********************************")
101
- osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
102
  self.prepare_material()
103
- else:
104
  if not os.path.exists(self.avatar_path):
105
  print(f"{self.avatar_id} does not exist, you should set preparation to True")
106
  sys.exit()
107
 
108
  with open(self.avatar_info_path, "r") as f:
109
  avatar_info = json.load(f)
110
-
111
  if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']:
112
  response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)")
113
  if response.lower() == "c":
@@ -115,11 +119,11 @@ class Avatar:
115
  print("*********************************")
116
  print(f" creating avator: {self.avatar_id}")
117
  print("*********************************")
118
- osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
119
  self.prepare_material()
120
  else:
121
  sys.exit()
122
- else:
123
  self.input_latent_list_cycle = torch.load(self.latents_out_path)
124
  with open(self.coords_path, 'rb') as f:
125
  self.coord_list_cycle = pickle.load(f)
@@ -131,36 +135,40 @@ class Avatar:
  input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
  input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
  self.mask_list_cycle = read_imgs(input_mask_list)
-
  def prepare_material(self):
  print("preparing data materials ... ...")
  with open(self.avatar_info_path, "w") as f:
  json.dump(self.avatar_info, f)
-
  if os.path.isfile(self.video_path):
- video2imgs(self.video_path, self.full_imgs_path, ext = 'png')
  else:
  print(f"copy files in {self.video_path}")
  files = os.listdir(self.video_path)
  files.sort()
- files = [file for file in files if file.split(".")[-1]=="png"]
  for filename in files:
  shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}")
  input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
-
  print("extracting landmarks...")
  coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift)
  input_latent_list = []
  idx = -1
- # maker if the bbox is not sufficient
- coord_placeholder = (0.0,0.0,0.0,0.0)
  for bbox, frame in zip(coord_list, frame_list):
  idx = idx + 1
  if bbox == coord_placeholder:
  continue
  x1, y1, x2, y2 = bbox
  crop_frame = frame[y1:y2, x1:x2]
- resized_crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
  latents = vae.get_latents_for_unet(resized_crop_frame)
  input_latent_list.append(latents)

@@ -170,112 +178,116 @@ class Avatar:
  self.mask_coords_list_cycle = []
  self.mask_list_cycle = []

- for i,frame in enumerate(tqdm(self.frame_list_cycle)):
- cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png",frame)
-
- face_box = self.coord_list_cycle[i]
- mask,crop_box = get_image_prepare_material(frame,face_box)
- cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png",mask)
  self.mask_coords_list_cycle += [crop_box]
  self.mask_list_cycle.append(mask)
-
  with open(self.mask_coords_path, 'wb') as f:
  pickle.dump(self.mask_coords_list_cycle, f)

  with open(self.coords_path, 'wb') as f:
  pickle.dump(self.coord_list_cycle, f)
-
- torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
- #
-
- def process_frames(self,
- res_frame_queue,
- video_len,
- skip_save_images):
  print(video_len)
  while True:
- if self.idx>=video_len-1:
  break
  try:
  start = time.time()
  res_frame = res_frame_queue.get(block=True, timeout=1)
  except queue.Empty:
  continue
-
- bbox = self.coord_list_cycle[self.idx%(len(self.coord_list_cycle))]
- ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx%(len(self.frame_list_cycle))])
  x1, y1, x2, y2 = bbox
  try:
- res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
  except:
  continue
- mask = self.mask_list_cycle[self.idx%(len(self.mask_list_cycle))]
- mask_crop_box = self.mask_coords_list_cycle[self.idx%(len(self.mask_coords_list_cycle))]
- #combine_frame = get_image(ori_frame,res_frame,bbox)
  combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)

  if skip_save_images is False:
- cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
  self.idx = self.idx + 1

- def inference(self,
- audio_path,
- out_vid_name,
- fps,
- skip_save_images):
- os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
  print("start inference")
  ############################################## extract audio feature ##############################################
  start_time = time.time()
- whisper_feature = audio_processor.audio2feat(audio_path)
- whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
  print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
  ############################################## inference batch by batch ##############################################
- video_num = len(whisper_chunks)
  res_frame_queue = queue.Queue()
  self.idx = 0
- # # Create a sub-thread and start it
  process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
  process_thread.start()

  gen = datagen(whisper_chunks,
- self.input_latent_list_cycle,
- self.batch_size)
  start_time = time.time()
  res_frame_list = []
-
- for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
- audio_feature_batch = torch.from_numpy(whisper_batch)
- audio_feature_batch = audio_feature_batch.to(device=unet.device,
- dtype=unet.model.dtype)
- audio_feature_batch = pe(audio_feature_batch)
- latent_batch = latent_batch.to(dtype=unet.model.dtype)
-
- pred_latents = unet.model(latent_batch,
- timesteps,
- encoder_hidden_states=audio_feature_batch).sample
  recon = vae.decode_latents(pred_latents)
  for res_frame in recon:
  res_frame_queue.put(res_frame)
  # Close the queue and sub-thread after all tasks are completed
  process_thread.join()
-
  if args.skip_save_images is True:
  print('Total process time of {} frames without saving images = {}s'.format(
- video_num,
- time.time()-start_time))
  else:
  print('Total process time of {} frames including saving images = {}s'.format(
- video_num,
- time.time()-start_time))

- if out_vid_name is not None and args.skip_save_images is False:
  # optional
- cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
  print(cmd_img2video)
  os.system(cmd_img2video)

- output_vid = os.path.join(self.video_out_path, out_vid_name+".mp4") # on
  cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {self.avatar_path}/temp.mp4 {output_vid}"
  print(cmd_combine_audio)
  os.system(cmd_combine_audio)
@@ -284,52 +296,95 @@ class Avatar:
  shutil.rmtree(f"{self.avatar_path}/tmp")
  print(f"result is save to {output_vid}")
  print("\n")
-

  if __name__ == "__main__":
  '''
  This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
  '''
-
  parser = argparse.ArgumentParser()
- parser.add_argument("--inference_config",
- type=str,
- default="configs/inference/realtime.yaml",
- )
- parser.add_argument("--fps",
- type=int,
- default=25,
- )
- parser.add_argument("--batch_size",
- type=int,
- default=4,
- )
  parser.add_argument("--skip_save_images",
- action="store_true",
- help="Whether skip saving images for better generation speed calculation",
- )

  args = parser.parse_args()
-
  inference_config = OmegaConf.load(args.inference_config)
  print(inference_config)
-
-
  for avatar_id in inference_config:
  data_preparation = inference_config[avatar_id]["preparation"]
  video_path = inference_config[avatar_id]["video_path"]
- bbox_shift = inference_config[avatar_id]["bbox_shift"]
  avatar = Avatar(
- avatar_id = avatar_id,
- video_path = video_path,
- bbox_shift = bbox_shift,
- batch_size = args.batch_size,
- preparation= data_preparation)
-
  audio_clips = inference_config[avatar_id]["audio_clips"]
  for audio_num, audio_path in audio_clips.items():
- print("Inferring using:",audio_path)
- avatar.inference(audio_path,
- audio_num,
- args.fps,
- args.skip_save_images)

  from tqdm import tqdm
  import copy
  import json
+ from transformers import WhisperModel
+
+ from musetalk.utils.face_parsing import FaceParsing
+ from musetalk.utils.utils import datagen
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
+ from musetalk.utils.blending import get_image_prepare_material, get_image_blending
  from musetalk.utils.utils import load_all_model
+ from musetalk.utils.audio_processor import AudioProcessor

+ import shutil
  import threading
  import queue
  import time

+ def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
  cap = cv2.VideoCapture(vid_path)
  count = 0
  while True:

  else:
  break

+
  def osmakedirs(path_list):
  for path in path_list:
  os.makedirs(path) if not os.path.exists(path) else None

+
+ @torch.no_grad()
  class Avatar:
  def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
  self.avatar_id = avatar_id
  self.video_path = video_path
  self.bbox_shift = bbox_shift
+ # Set the base path according to the model version
+ if args.version == "v15":
+ self.base_path = f"./results/{args.version}/avatars/{avatar_id}"
+ else: # v1
+ self.base_path = f"./results/avatars/{avatar_id}"
+
+ self.avatar_path = self.base_path
+ self.full_imgs_path = f"{self.avatar_path}/full_imgs"
  self.coords_path = f"{self.avatar_path}/coords.pkl"
+ self.latents_out_path = f"{self.avatar_path}/latents.pt"
  self.video_out_path = f"{self.avatar_path}/vid_output/"
+ self.mask_out_path = f"{self.avatar_path}/mask"
+ self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl"
  self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
  self.avatar_info = {
+ "avatar_id": avatar_id,
+ "video_path": video_path,
+ "bbox_shift": bbox_shift,
+ "version": args.version
  }
  self.preparation = preparation
  self.batch_size = batch_size
  self.idx = 0
  self.init()
+
  def init(self):
  if self.preparation:
  if os.path.exists(self.avatar_path):

  print("*********************************")
  print(f" creating avator: {self.avatar_id}")
  print("*********************************")
+ osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
  self.prepare_material()
  else:
  self.input_latent_list_cycle = torch.load(self.latents_out_path)

  print("*********************************")
  print(f" creating avator: {self.avatar_id}")
  print("*********************************")
+ osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
  self.prepare_material()
+ else:
  if not os.path.exists(self.avatar_path):
  print(f"{self.avatar_id} does not exist, you should set preparation to True")
  sys.exit()

  with open(self.avatar_info_path, "r") as f:
  avatar_info = json.load(f)
+
  if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']:
  response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)")
  if response.lower() == "c":

  print("*********************************")
  print(f" creating avator: {self.avatar_id}")
  print("*********************************")
+ osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
  self.prepare_material()
  else:
  sys.exit()
+ else:
  self.input_latent_list_cycle = torch.load(self.latents_out_path)
  with open(self.coords_path, 'rb') as f:
  self.coord_list_cycle = pickle.load(f)

  input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
  input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
  self.mask_list_cycle = read_imgs(input_mask_list)
+
  def prepare_material(self):
  print("preparing data materials ... ...")
  with open(self.avatar_info_path, "w") as f:
  json.dump(self.avatar_info, f)
+
  if os.path.isfile(self.video_path):
+ video2imgs(self.video_path, self.full_imgs_path, ext='png')
  else:
  print(f"copy files in {self.video_path}")
  files = os.listdir(self.video_path)
  files.sort()
+ files = [file for file in files if file.split(".")[-1] == "png"]
  for filename in files:
  shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}")
  input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
+
  print("extracting landmarks...")
  coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift)
  input_latent_list = []
  idx = -1
+ # marker if the bbox is not sufficient
+ coord_placeholder = (0.0, 0.0, 0.0, 0.0)
  for bbox, frame in zip(coord_list, frame_list):
  idx = idx + 1
  if bbox == coord_placeholder:
  continue
  x1, y1, x2, y2 = bbox
+ if args.version == "v15":
+ y2 = y2 + args.extra_margin
+ y2 = min(y2, frame.shape[0])
+ coord_list[idx] = [x1, y1, x2, y2]  # update the bbox in coord_list
  crop_frame = frame[y1:y2, x1:x2]
+ resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
  latents = vae.get_latents_for_unet(resized_crop_frame)
  input_latent_list.append(latents)

  self.mask_coords_list_cycle = []
  self.mask_list_cycle = []

+ for i, frame in enumerate(tqdm(self.frame_list_cycle)):
+ cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png", frame)
+
+ x1, y1, x2, y2 = self.coord_list_cycle[i]
+ if args.version == "v15":
+ mode = args.parsing_mode
+ else:
+ mode = "raw"
+ mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp, mode=mode)
+
+ cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png", mask)
  self.mask_coords_list_cycle += [crop_box]
  self.mask_list_cycle.append(mask)
+
  with open(self.mask_coords_path, 'wb') as f:
  pickle.dump(self.mask_coords_list_cycle, f)

  with open(self.coords_path, 'wb') as f:
  pickle.dump(self.coord_list_cycle, f)
+
+ torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
+
+ def process_frames(self, res_frame_queue, video_len, skip_save_images):
  print(video_len)
  while True:
+ if self.idx >= video_len - 1:
  break
  try:
  start = time.time()
  res_frame = res_frame_queue.get(block=True, timeout=1)
  except queue.Empty:
  continue
+
+ bbox = self.coord_list_cycle[self.idx % (len(self.coord_list_cycle))]
+ ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx % (len(self.frame_list_cycle))])
  x1, y1, x2, y2 = bbox
  try:
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
  except:
  continue
+ mask = self.mask_list_cycle[self.idx % (len(self.mask_list_cycle))]
+ mask_crop_box = self.mask_coords_list_cycle[self.idx % (len(self.mask_coords_list_cycle))]
  combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)

  if skip_save_images is False:
+ cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png", combine_frame)
  self.idx = self.idx + 1

+ def inference(self, audio_path, out_vid_name, fps, skip_save_images):
+ os.makedirs(self.avatar_path + '/tmp', exist_ok=True)
  print("start inference")
  ############################################## extract audio feature ##############################################
  start_time = time.time()
+ # Extract audio features
+ whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path, weight_dtype=weight_dtype)
+ whisper_chunks = audio_processor.get_whisper_chunk(
+ whisper_input_features,
+ device,
+ weight_dtype,
+ whisper,
+ librosa_length,
+ fps=fps,
+ audio_padding_length_left=args.audio_padding_length_left,
+ audio_padding_length_right=args.audio_padding_length_right,
+ )
  print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
  ############################################## inference batch by batch ##############################################
+ video_num = len(whisper_chunks)
  res_frame_queue = queue.Queue()
  self.idx = 0
+ # Create a sub-thread and start it
  process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
  process_thread.start()

  gen = datagen(whisper_chunks,
+ self.input_latent_list_cycle,
+ self.batch_size)
  start_time = time.time()
  res_frame_list = []
+
+ for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size)))):
+ audio_feature_batch = pe(whisper_batch.to(device))
+ latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
+
+ pred_latents = unet.model(latent_batch,
+ timesteps,
+ encoder_hidden_states=audio_feature_batch).sample
+ pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
  recon = vae.decode_latents(pred_latents)
  for res_frame in recon:
  res_frame_queue.put(res_frame)
  # Close the queue and sub-thread after all tasks are completed
  process_thread.join()
+
  if args.skip_save_images is True:
  print('Total process time of {} frames without saving images = {}s'.format(
+ video_num,
+ time.time() - start_time))
  else:
  print('Total process time of {} frames including saving images = {}s'.format(
+ video_num,
+ time.time() - start_time))

+ if out_vid_name is not None and args.skip_save_images is False:
  # optional
+ cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
  print(cmd_img2video)
  os.system(cmd_img2video)

+ output_vid = os.path.join(self.video_out_path, out_vid_name + ".mp4") # on
  cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {self.avatar_path}/temp.mp4 {output_vid}"
  print(cmd_combine_audio)
  os.system(cmd_combine_audio)

  shutil.rmtree(f"{self.avatar_path}/tmp")
  print(f"result is save to {output_vid}")
  print("\n")
+

  if __name__ == "__main__":
  '''
  This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
  '''
+
  parser = argparse.ArgumentParser()
+ parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Version of MuseTalk: v1 or v15")
+ parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
+ parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
+ parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
+ parser.add_argument("--unet_config", type=str, default="./models/musetalk/musetalk.json", help="Path to UNet configuration file")
+ parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
+ parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
+ parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
+ parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
+ parser.add_argument("--result_dir", default='./results', help="Directory for output results")
+ parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
+ parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
+ parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
+ parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
+ parser.add_argument("--batch_size", type=int, default=25, help="Batch size for inference")
+ parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
+ parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
+ parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
+ parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
+ parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
+ parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
  parser.add_argument("--skip_save_images",
+ action="store_true",
+ help="Whether skip saving images for better generation speed calculation",
+ )

  args = parser.parse_args()
+
+ # Set computing device
+ device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
+
+ # Load model weights
+ vae, unet, pe = load_all_model(
+ unet_model_path=args.unet_model_path,
+ vae_type=args.vae_type,
+ unet_config=args.unet_config,
+ device=device
+ )
+ timesteps = torch.tensor([0], device=device)
+
+ pe = pe.half().to(device)
+ vae.vae = vae.vae.half().to(device)
+ unet.model = unet.model.half().to(device)
+
+ # Initialize audio processor and Whisper model
+ audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
+ weight_dtype = unet.model.dtype
+ whisper = WhisperModel.from_pretrained(args.whisper_dir)
+ whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+ whisper.requires_grad_(False)
+
+ # Initialize face parser with configurable parameters based on version
+ if args.version == "v15":
+ fp = FaceParsing(
+ left_cheek_width=args.left_cheek_width,
+ right_cheek_width=args.right_cheek_width
+ )
+ else: # v1
+ fp = FaceParsing()
+
  inference_config = OmegaConf.load(args.inference_config)
  print(inference_config)
+
  for avatar_id in inference_config:
  data_preparation = inference_config[avatar_id]["preparation"]
  video_path = inference_config[avatar_id]["video_path"]
+ if args.version == "v15":
+ bbox_shift = 0
+ else:
+ bbox_shift = inference_config[avatar_id]["bbox_shift"]
  avatar = Avatar(
+ avatar_id=avatar_id,
+ video_path=video_path,
+ bbox_shift=bbox_shift,
+ batch_size=args.batch_size,
+ preparation=data_preparation)
+
  audio_clips = inference_config[avatar_id]["audio_clips"]
  for audio_num, audio_path in audio_clips.items():
+ print("Inferring using:", audio_path)
+ avatar.inference(audio_path,
+ audio_num,
+ args.fps,
+ args.skip_save_images)
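
The `__main__` loop above drives everything from the file passed as `--inference_config`: each top-level key is an avatar id whose entry provides `preparation`, `video_path`, `bbox_shift` (read only on the v1 path; the v15 branch forces it to 0) and an `audio_clips` mapping of output names to audio files. Below is a minimal sketch of that structure, loaded with `OmegaConf` the same way the script does; the avatar id and the data paths are illustrative placeholders rather than files guaranteed to ship with the repo.

```python
# Sketch of the realtime config structure consumed by the loop above.
# "my_avatar" and the video/audio paths are hypothetical placeholders;
# only the keys mirror what the realtime inference script actually reads.
from omegaconf import OmegaConf

yaml_str = """\
my_avatar:
  preparation: True                    # build full_imgs/, mask/, coords.pkl, latents.pt on first run
  video_path: data/video/my_clip.mp4   # a video file or a directory of png frames
  bbox_shift: 5                        # used only with --version v1; v15 overrides it with 0
  audio_clips:
    audio_0: data/audio/hello.wav      # key becomes out_vid_name, value is the driving audio
"""

realtime_cfg = OmegaConf.create(yaml_str)
for avatar_id in realtime_cfg:
    entry = realtime_cfg[avatar_id]
    print(avatar_id, entry["preparation"], entry["video_path"])
    for out_name, audio_path in entry["audio_clips"].items():
        print("  ", out_name, "->", audio_path)
```

With `preparation: True` the `init`/`prepare_material` path above runs the full face-detection, parsing and VAE-encoding pass once; on later runs it can be set to `False` so the cached `latents.pt`, `coords.pkl` and masks are reloaded, as long as the stored `bbox_shift` still matches.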