gz412 committed
Commit bd84a81 · 1 Parent(s): 1a04b2a

test app.py

Files changed (1)
  1. app.py +78 -34
app.py CHANGED
@@ -1,47 +1,91 @@
- import os
  import sys
  import torch
- import spaces

- print("===== Application Startup =====")

- # Do not force-disable CUDA; keep the line below commented out
- # os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

- print("gz start")

- print("Python version:", sys.version)
- print("Torch version:", torch.__version__)
- print("CUDA available:", torch.cuda.is_available())

- # Try allocating a tensor on the GPU
- try:
-     if torch.cuda.is_available():
-         device = torch.device("cuda")
-         x = torch.rand((2, 3), device=device)
-         y = torch.mm(x, x.T)
-         print("Tensor allocated on GPU successfully:")
-         print(y)
      else:
-         print("CUDA not available, fallback to CPU")
-         x = torch.rand((2, 3))
-         y = torch.mm(x, x.T)
-         print("Tensor allocated on CPU successfully:")
-         print(y)
- except Exception as e:
-     print("ERROR during CUDA tensor allocation:", str(e))


- # ---- Gradio test interface ----
- import gradio as gr

- @spaces.GPU
- def gpu_test():
-     if torch.cuda.is_available():
-         x = torch.rand((2, 3), device="cuda")
-         return f"GPU OK, tensor sum={x.sum().item()}"
-     else:
-         return "No GPU detected, using CPU"

- demo = gr.Interface(fn=gpu_test, inputs=[], outputs="text")
  demo.launch()
 
+ import spaces
  import sys
  import torch

+ import gradio as gr
+ import opencc
+
+ # Add the third-party library path
+ sys.path.append('third_party/Matcha-TTS')
+
+ from cosyvoice.cli.cosyvoice import CosyVoice2
+ from cosyvoice.utils.file_utils import load_wav

+ from huggingface_hub import hf_hub_download

+ # Simplified-to-Traditional Chinese conversion
+ converter = opencc.OpenCC('s2t.json')

+ # Load the models
+ cosyvoice_base = CosyVoice2(
+     'ASLP-lab/WSYue-TTS-Cosyvoice2',
+     load_jit=False, load_trt=False, load_vllm=False, fp16=False
+ )
+ print('load model 1')
+ cosyvoice_zjg = CosyVoice2(
+     'ASLP-lab/WSYue-TTS-Cosyvoice2-zjg',
+     load_jit=False, load_trt=False, load_vllm=False, fp16=False
+ )
+ print('load model 2')
+ # cosyvoice_biaobei = CosyVoice2(
+ #     'pretrained_models/CosyVoice2-yue-biaobei',
+ #     load_jit=False, load_trt=False, load_vllm=False, fp16=False
+ # )

+ @spaces.GPU
+ def tts_inference(model_choice, text, prompt_audio):
+     # Select the model and its default prompt audio
+     if model_choice == "CosyVoice2-张悦楷粤语评书":
+         model = cosyvoice_zjg
+         prompt_audio = "asset/sg_017_090.wav"
+     elif model_choice == "CosyVoice2-精品女音":
+         model = cosyvoice_base
+         prompt_audio = "asset/F01_中立_20054.wav"
+     elif model_choice == "CosyVoice2-base":
+         model = cosyvoice_base
+         if prompt_audio is None:
+             return None, "请上传参考音频"
      else:
+         return None, "未知模型"
+
+     model.eval().cuda()

+     # Simplified-to-Traditional Chinese conversion
+     text = converter.convert(text)
+     prompt_speech_16k = load_wav(prompt_audio, 16000)

+     all_speech = []
+     for _, j in enumerate(
+         model.inference_instruct2(
+             text, "用粤语说这句话", prompt_speech_16k, stream=False
+         )
+     ):
+         all_speech.append(j['tts_speech'])

+     concatenated_speech = torch.cat(all_speech, dim=1)
+     audio_numpy = concatenated_speech.squeeze(0).cpu().numpy()
+     sample_rate = model.sample_rate
+
+     return (sample_rate, audio_numpy), f"生成成功:{text}"
+
+
+ # ---- Gradio Interface ----
+ demo = gr.Interface(
+     fn=tts_inference,
+     inputs=[
+         gr.Dropdown(
+             ["CosyVoice2-base", "CosyVoice2-张悦楷粤语评书"],
+             # ["CosyVoice2-base", "CosyVoice2-张悦楷粤语评书", "CosyVoice2-精品女音"],
+             label="选择模型", value="CosyVoice2-base"
+         ),
+         gr.Textbox(lines=2, label="输入文本"),
+         # gr.Audio(source="upload", type="filepath", label="上传参考音频(仅 CosyVoice2-base 必需)")
+         gr.Audio(sources=["upload"], type="filepath", label="上传参考音频(仅 CosyVoice2-base 必需)")
+     ],
+     outputs=[
+         gr.Audio(type="numpy", label="生成的语音"),
+         gr.Textbox(label="状态信息")
+     ]
+ )

  demo.launch()
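
A quick way to sanity-check the new tts_inference path outside the Gradio UI is to call it directly. The sketch below is illustrative only: it assumes the module-level setup above has already run (model repos downloaded, asset WAVs present) and that it replaces the trailing demo.launch(); the Cantonese sample sentence is not part of the commit.

# Illustrative smoke test: the storytelling voice supplies its own prompt audio
# internally, so prompt_audio can be None here.
result, status = tts_inference(
    "CosyVoice2-张悦楷粤语评书",
    "今日天氣唔錯,我哋出去行下啦。",  # sample Cantonese input (assumption, not from the commit)
    None,
)
print(status)
if result is not None:
    sample_rate, waveform = result  # (int, 1-D numpy array), matching the return value above
    print("sample rate:", sample_rate, "samples:", waveform.shape[0])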