zkangchen committed on
Commit 3b3eed4 · 1 Parent(s): 8e2f242

<enhance>: modified inference codes


1. bbox_shift can now be set per task in configs/inference/test.yaml
2. pip installing whisper is no longer required

.gitignore CHANGED
@@ -4,7 +4,7 @@
  .vscode/
  *.pyc
  .ipynb_checkpoints
- models/
+ models
  results/
- data/audio/*.WAV
+ data/audio/*.wav
  data/video/*.mp4
README.md CHANGED
@@ -175,11 +175,6 @@ We recommend a python version >=3.10 and cuda version =11.7. Then build environm
  ```shell
  pip install -r requirements.txt
  ```
- ### whisper
- install whisper to extract audio feature (only encoder)
- ```
- pip install --editable ./musetalk/whisper
- ```
  
  ### mmlab packages
  ```bash
@@ -256,13 +251,13 @@ As a complete solution to virtual human generation, you are suggested to first a
  
  # Note
  
- If you want to launch online video chats, you are suggested to generate videos using MuseV and apply necessary pre-processing such as face detection in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
+ If you want to launch online video chats, you are suggested to generate videos using MuseV and apply necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
  
  
  # Acknowledgement
- 1. We thank open-source components like [whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
- 1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers).
- 1. MuseTalk has been built on `HDTF` datasets.
+ 1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
+ 1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
+ 1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.
  
  Thanks for open-sourcing!
  
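
With the editable install of `./musetalk/whisper` removed, the whisper encoder loader now ships inside the repository (see the `audio2feature.py` change below). A minimal, hedged sanity check, assuming the bundled `musetalk/whisper/whisper` package exposes `load_model` as the new relative import implies, and that it is run from the repository root:

```python
# Hedged sanity check, not part of the commit. Assumes the bundled
# musetalk/whisper/whisper package exposes load_model (as implied by
# `from .whisper import load_model` in audio2feature.py) and that this
# is run from the repository root.
import importlib.util

# A pip-installed whisper package is no longer required; None here is fine.
print("pip whisper installed:", importlib.util.find_spec("whisper") is not None)

# The loader used by Audio2Feature now comes from the repository itself.
from musetalk.whisper.whisper import load_model
print(load_model.__module__)
```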
 
configs/inference/test.yaml CHANGED
@@ -1,9 +1,10 @@
  task_0:
- video_path: "data/video/monalisa.mp4"
- audio_path: "data/audio/monalisa.wav"
+ video_path: "data/video/yongen.mp4"
+ audio_path: "data/audio/yongen.wav"
  
  task_1:
  video_path: "data/video/sun.mp4"
  audio_path: "data/audio/sun.wav"
+ bbox_shift: -7
  
  
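
Per the first note in the commit message, `bbox_shift` can now be set per task in this file; `scripts/inference.py` resolves it with `inference_config[task_id].get("bbox_shift", args.bbox_shift)`, so a task without the key falls back to the `--bbox_shift` CLI default. A minimal sketch of that resolution, assuming PyYAML for loading (the actual script may use a different config loader):

```python
# Minimal sketch (not from the commit) of the per-task bbox_shift override,
# mirroring inference.py's `inference_config[task_id].get("bbox_shift", args.bbox_shift)`.
# Assumes PyYAML; the real script's config loader may differ.
import yaml

cli_default_bbox_shift = 0  # the argparse default in scripts/inference.py

with open("configs/inference/test.yaml") as f:
    inference_config = yaml.safe_load(f)

for task_id, task in inference_config.items():
    bbox_shift = task.get("bbox_shift", cli_default_bbox_shift)
    print(task_id, task["video_path"], task["audio_path"], bbox_shift)
# task_0 falls back to 0; task_1 uses its own bbox_shift of -7.
```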
 
data/audio/{monalisa.wav → yongen.wav} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:843ab9b94cbbf67072aa2e8d3c6397a2fc7537ff47402922a9004c18d2222ae2
- size 6971436
+ oid sha256:2b775c363c968428d1d6df4456495e4c11f00e3204d3082e51caff415ec0e2ba
+ size 1536078
data/video/{monalisa.mp4 → yongen.mp4} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:eb6c07fb0aa57cf287a54232b1962e4de689fb98b431a502fb1504350ba441c6
- size 6906049
+ oid sha256:1effa976d410571cd185554779d6d43a6ba636e0e3401385db1d607daa46441f
+ size 1870923
musetalk/utils/blending.py CHANGED
@@ -52,7 +52,6 @@ def get_image(image,face,face_box,upper_boundary_ratio = 0.5,expand=1.2):
      blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
      mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
      mask_image = Image.fromarray(mask_array)
-     mask_image.save("./debug_mask.png")
  
      face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
      body.paste(face_large, crop_box[:2], mask_image)
musetalk/whisper/audio2feature.py CHANGED
@@ -1,7 +1,5 @@
  import os
- #import whisper
- from whisper import load_model
- #import whisper.whispher as whiisper
+ from .whisper import load_model
  import soundfile as sf
  import numpy as np
  import time
@@ -9,11 +7,12 @@ import sys
  sys.path.append("..")
  
  class Audio2Feature():
-     def __init__(self, whisper_model_type="tiny",model_path="./checkpoints/wisper_tiny.pt"):
+     def __init__(self,
+                  whisper_model_type="tiny",
+                  model_path="./models/whisper/tiny.pt"):
          self.whisper_model_type = whisper_model_type
          self.model = load_model(model_path) #
  
- 
      def get_sliced_feature(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
          """
          Get sliced features based on a given index
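
With the constructor now defaulting to `./models/whisper/tiny.pt` and importing `load_model` from the bundled package, usage looks roughly as follows. This is a hedged sketch, not code from the commit; it assumes the whisper "tiny" checkpoint has been placed at `./models/whisper/tiny.pt` and that it runs from the repository root:

```python
# Hedged usage sketch (assumption: ./models/whisper/tiny.pt exists; run from repo root).
from musetalk.whisper.audio2feature import Audio2Feature

# No separately installed whisper package is needed; load_model is resolved
# through the new relative import `from .whisper import load_model`.
audio_processor = Audio2Feature(whisper_model_type="tiny",
                                model_path="./models/whisper/tiny.pt")
print(audio_processor.whisper_model_type)  # "tiny"
```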
musetalk/whisper/requirements.txt DELETED
@@ -1,6 +0,0 @@
- numpy
- torch
- tqdm
- more-itertools
- transformers>=4.19.0
- ffmpeg-python==0.2.0
musetalk/whisper/setup.py DELETED
@@ -1,24 +0,0 @@
- import os
- 
- import pkg_resources
- from setuptools import setup, find_packages
- 
- setup(
-     name="whisper",
-     py_modules=["whisper"],
-     version="1.0",
-     description="",
-     author="OpenAI",
-     packages=find_packages(exclude=["tests*"]),
-     install_requires=[
-         str(r)
-         for r in pkg_resources.parse_requirements(
-             open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
-         )
-     ],
-     entry_points = {
-         'console_scripts': ['whisper=whisper.transcribe:cli'],
-     },
-     include_package_data=True,
-     extras_require={'dev': ['pytest']},
- )
musetalk/whisper/whisper.egg-info/PKG-INFO DELETED
@@ -1,5 +0,0 @@
- Metadata-Version: 2.1
- Name: whisper
- Version: 1.0
- Author: OpenAI
- Provides-Extra: dev
musetalk/whisper/whisper.egg-info/SOURCES.txt DELETED
@@ -1,18 +0,0 @@
- setup.py
- whisper/__init__.py
- whisper/__main__.py
- whisper/audio.py
- whisper/decoding.py
- whisper/model.py
- whisper/tokenizer.py
- whisper/transcribe.py
- whisper/utils.py
- whisper.egg-info/PKG-INFO
- whisper.egg-info/SOURCES.txt
- whisper.egg-info/dependency_links.txt
- whisper.egg-info/entry_points.txt
- whisper.egg-info/requires.txt
- whisper.egg-info/top_level.txt
- whisper/normalizers/__init__.py
- whisper/normalizers/basic.py
- whisper/normalizers/english.py
musetalk/whisper/whisper.egg-info/dependency_links.txt DELETED
@@ -1 +0,0 @@
-
musetalk/whisper/whisper.egg-info/entry_points.txt DELETED
@@ -1,2 +0,0 @@
- [console_scripts]
- whisper = whisper.transcribe:cli
musetalk/whisper/whisper.egg-info/requires.txt DELETED
@@ -1,9 +0,0 @@
- numpy
- torch
- tqdm
- more-itertools
- transformers>=4.19.0
- ffmpeg-python==0.2.0
- 
- [dev]
- pytest
musetalk/whisper/whisper.egg-info/top_level.txt DELETED
@@ -1 +0,0 @@
- whisper
scripts/inference.py CHANGED
@@ -13,6 +13,7 @@ from musetalk.utils.utils import get_file_type,get_video_fps,datagen
  from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
  from musetalk.utils.blending import get_image
  from musetalk.utils.utils import load_all_model
+ import shutil
  
  # load model weights
  audio_processor,vae,unet,pe = load_all_model()
@@ -26,6 +27,7 @@ def main(args):
      for task_id in inference_config:
          video_path = inference_config[task_id]["video_path"]
          audio_path = inference_config[task_id]["audio_path"]
+         bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
  
          input_basename = os.path.basename(video_path).split('.')[0]
          audio_basename = os.path.basename(audio_path).split('.')[0]
@@ -42,7 +44,7 @@
          if get_file_type(video_path)=="video":
              save_dir_full = os.path.join(args.result_dir, input_basename)
              os.makedirs(save_dir_full,exist_ok = True)
-             cmd = f"ffmpeg -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
+             cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
              os.system(cmd)
              input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
              fps = get_video_fps(video_path)
@@ -62,7 +64,7 @@
              frame_list = read_imgs(input_img_list)
          else:
              print("extracting landmarks...time consuming")
-             coord_list, frame_list = get_landmark_and_bbox(input_img_list,args.bbox_shift)
+             coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
              with open(crop_coord_save_path, 'wb') as f:
                  pickle.dump(coord_list, f)
  
@@ -117,24 +119,26 @@
          print(cmd_img2video)
          os.system(cmd_img2video)
  
-         cmd_combine_audio = f"ffmpeg -i {audio_path} -i temp.mp4 {output_vid_name} -y"
+         cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
          print(cmd_combine_audio)
          os.system(cmd_combine_audio)
  
-         os.system("rm temp.mp4")
-         os.system(f"rm -rf {result_img_save_path}")
+         os.remove("temp.mp4")
+         shutil.rmtree(result_img_save_path)
          print(f"result is save to {output_vid_name}")
  
  if __name__ == "__main__":
      parser = argparse.ArgumentParser()
-     parser.add_argument("--inference_config",type=str, default="configs/inference/test_img.yaml")
-     parser.add_argument("--bbox_shift",type=int, default=0)
+     parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
+     parser.add_argument("--bbox_shift", type=int, default=0)
      parser.add_argument("--result_dir", default='./results', help="path to output")
  
-     parser.add_argument("--fps",type=int, default=25)
-     parser.add_argument("--batch_size",type=int, default=8)
-     parser.add_argument("--output_vid_name",type=str,default='')
-     parser.add_argument("--use_saved_coord",action="store_true", help='use saved coordinate to save time')
+     parser.add_argument("--fps", type=int, default=25)
+     parser.add_argument("--batch_size", type=int, default=8)
+     parser.add_argument("--output_vid_name", type=str,default='')
+     parser.add_argument("--use_saved_coord",
+                         action="store_true",
+                         help='use saved coordinate to save time')
  
  
      args = parser.parse_args()
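
Two small patterns are worth noting in this file: ffmpeg is now silenced with `-v fatal`, and the intermediates are removed with `os.remove`/`shutil.rmtree` instead of shelling out to `rm`, which also works on Windows. A hedged sketch isolating that pattern (the function name and paths are illustrative, not part of the repository API):

```python
# Hedged sketch of the cleanup pattern adopted in this commit: quiet ffmpeg via
# -v fatal, and portable cleanup via os.remove / shutil.rmtree instead of `rm`.
# combine_and_clean and its arguments are illustrative names, not repo API.
import os
import shutil

def combine_and_clean(audio_path, output_vid_name, result_img_save_path,
                      temp_video="temp.mp4"):
    # -y overwrites an existing output; -v fatal prints only fatal ffmpeg errors.
    cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i {temp_video} {output_vid_name}"
    os.system(cmd_combine_audio)

    # Remove intermediates without relying on POSIX `rm`, so this also runs on Windows.
    os.remove(temp_video)
    shutil.rmtree(result_img_save_path)
```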