lataon committed
Commit 5adf0a6 · 1 Parent(s): 9bf2487

add: phoneme eval code

Files changed (3)
  1. Makefile +5 -0
  2. requirements.txt +5 -1
  3. src/phoneme_eval.py +186 -0
Makefile CHANGED
@@ -11,3 +11,8 @@ quality:
 	python -m black --check --line-length 119 .
 	python -m isort --check-only .
 	ruff check .
+
+
+.PHONY: eval
+eval:
+	python -m src.phoneme_eval
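
With this target in place, the evaluation can be launched from the repository root; a minimal usage sketch, assuming the dependencies below are installed:

    make eval    # equivalent to: python -m src.phoneme_eval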
requirements.txt CHANGED
@@ -13,4 +13,8 @@ python-dateutil
 tqdm
 transformers
 tokenizers>=0.15.0
-sentencepiece
+sentencepiece
+torchaudio
+torch
+nltk
+g2p-en
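
Among the new pins, torch and torchaudio back the CTC models in src/phoneme_eval.py, while nltk and g2p-en are not imported by the new script; presumably they are intended for turning plain text into the CMU/ARPAbet symbols its helpers consume. A hypothetical sketch of that use (not part of this commit; g2p_en is the import name of the g2p-en package, and nltk is one of its dependencies):

    from g2p_en import G2p  # hypothetical usage, not exercised by src/phoneme_eval.py

    g2p = G2p()
    print(g2p("cat"))  # e.g. ['K', 'AE1', 'T'] -- stress digits as expected by clean_cmu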
src/phoneme_eval.py ADDED
@@ -0,0 +1,186 @@
+import os
+import json
+import time
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+import torchaudio
+from datasets import load_dataset
+from transformers import (
+    Wav2Vec2Processor,
+    HubertForCTC,
+    Wav2Vec2ForCTC,
+)
+
+
+@dataclass
+class EvalConfig:
+    dataset_name: str = "mirfan899/phoneme_asr"
+    split: str = "train"
+    max_examples: int = 100
+    results_dir: str = "eval-results"  # relative to CWD
+    model_sha: str = ""
+    model_dtype: str = "float16"  # metadata written to results; models load in their default dtype
+
+
+def load_audio_array(example):
+    return example["audio"]["array"]
+
+
+def load_models(device: torch.device):
+    base_proc = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
+    base_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device).eval()
+
+    timit_proc = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
+    timit_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme").to(device).eval()
+
+    return (base_proc, base_model), (timit_proc, timit_model)
+
+
+def clean_cmu(text: str) -> str:  # strip CMU stress digits and hyphens, lowercase
+    res = text.replace("0", "").replace("1", "").replace("2", "").replace("-", "").strip()
+    return res.lower()
+
+
+def cmu_to_ipa(cmu_sentence: str) -> str:
+    cmu_map = {  # ARPAbet -> IPA ("AH0"/"ER0" are unreachable: only 2- and 1-char keys are matched)
+        "AA": "ɑ", "AE": "æ", "AH": "ʌ", "AH0": "ə", "AO": "ɔ", "AW": "aʊ", "AY": "aɪ",
+        "EH": "ɛ", "ER": "ɝ", "ER0": "ɚ", "EY": "eɪ", "IH": "ɪ", "IY": "i", "OW": "oʊ",
+        "OY": "ɔɪ", "UH": "ʊ", "UW": "u", "B": "b", "CH": "tʃ", "D": "d", "DH": "ð",
+        "F": "f", "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k", "L": "l", "M": "m",
+        "N": "n", "NG": "ŋ", "P": "p", "R": "r", "S": "s", "SH": "ʃ", "T": "t",
+        "TH": "θ", "V": "v", "W": "w", "Y": "j", "Z": "z", "ZH": "ʒ",
+    }
+    ipa_tokens = []
+    for word in cmu_sentence.strip().split():
+        i = 0
+        while i < len(word):
+            if i + 2 <= len(word) and word[i:i+2].upper() in cmu_map:  # prefer two-char symbols
+                ipa_tokens.append(cmu_map[word[i:i+2].upper()]); i += 2
+            elif word[i].upper() in cmu_map:
+                ipa_tokens.append(cmu_map[word[i].upper()]); i += 1
+            else:
+                ipa_tokens.append(word[i].lower()); i += 1  # pass unknown characters through
+        ipa_tokens.append(" ")  # word boundary; calculate_per strips spaces
+    return "".join(ipa_tokens)
+
+
+def align_sequences(seq1: str, seq2: str):  # Levenshtein alignment with backtrace
+    n, m = len(seq1), len(seq2)
+    dp = np.zeros((n + 1, m + 1), dtype=np.float32)
+    back = np.empty((n + 1, m + 1), dtype="U1")  # "D"=delete, "I"=insert, "M"=match/substitute
+    dp[:, 0] = np.arange(n + 1)
+    dp[0, :] = np.arange(m + 1)
+    back[:, 0] = "D"; back[0, :] = "I"; back[0, 0] = ""
+    for i in range(1, n + 1):
+        for j in range(1, m + 1):
+            cost = 0.0 if seq1[i - 1] == seq2[j - 1] else 1.0
+            opts = [(dp[i - 1][j] + 1, "D"), (dp[i][j - 1] + 1, "I"), (dp[i - 1][j - 1] + cost, "M")]
+            dp[i][j], back[i][j] = min(opts, key=lambda x: x[0])
+    i, j = n, m; a1, a2 = [], []
+    while i > 0 or j > 0:  # walk the backtrace to recover the aligned pair
+        mv = back[i][j]
+        if mv == "M": a1.append(seq1[i - 1]); a2.append(seq2[j - 1]); i -= 1; j -= 1
+        elif mv == "D": a1.append(seq1[i - 1]); a2.append("-"); i -= 1
+        elif mv == "I": a1.append("-"); a2.append(seq2[j - 1]); j -= 1
+        else: break
+    a1.reverse(); a2.reverse(); return a1, a2
+
+
+def calculate_per(ref_seq: str, hyp_seq: str) -> float:  # phoneme error rate, in percent
+    ref_seq = ref_seq.replace(" ", ""); hyp_seq = hyp_seq.replace(" ", "")
+    aligned_ref, aligned_hyp = align_sequences(ref_seq, hyp_seq)
+    s = d = i = 0  # substitutions, deletions, insertions
+    for r, h in zip(aligned_ref, aligned_hyp):
+        if r == h: continue
+        if r == "-": i += 1
+        elif h == "-": d += 1
+        else: s += 1
+    n = len(ref_seq)
+    return ((s + d + i) / n) * 100.0 if n > 0 else 0.0
+
+
+def run_hubert_base(proc, model, wav, device):
+    inputs = proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).input_values.to(device)
+    with torch.no_grad():
+        logits = model(inputs).logits
+    ids = torch.argmax(logits, dim=-1)
+    text = proc.batch_decode(ids)[0]  # character transcription; this CTC model emits letters, not phones
+    return text
+
+
+def run_timit(proc, model, wav, device):
+    inputs = proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).to(device)
+    with torch.no_grad():
+        logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
+    ids = torch.argmax(logits, dim=-1)
+    ph = proc.batch_decode(ids)
+    return "".join(ph)
+
+
+def evaluate(config: EvalConfig):
+    os.makedirs(config.results_dir, exist_ok=True)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    (base_proc, base_model), (timit_proc, timit_model) = load_models(device)
+
+    ds = load_dataset(config.dataset_name, split=config.split)
+    uniq = set(ds.unique("phonetic"))
+    ds = ds.filter(lambda x: x["phonetic"] in uniq)  # no-op as written: uniq contains every value
+    ds = ds.filter(lambda x: len(x["phonetic"].split()) >= 10)  # keep transcripts of >= 10 symbols
+    ds = ds.shuffle(seed=42).select(range(min(config.max_examples, len(ds))))
+
+    results = {
+        "results": {
+            "phoneme_dev": {},
+            "phoneme_test": {},
+        },
+        "config": {
+            "model_name": "phoneme/baselines",
+            "model_sha": config.model_sha,
+            "model_dtype": config.model_dtype,
+        },
+    }
+
+    # Simple split into dev/test halves
+    mid = len(ds) // 2
+    halves = [("phoneme_dev", ds.select(range(0, mid))), ("phoneme_test", ds.select(range(mid, len(ds))))]
+
+    for split_key, subset in halves:
+        per_scores_hubert = []
+        per_scores_timit = []
+        for ex in subset:
+            wav = ex["audio"]["array"]
+            ref = cmu_to_ipa(clean_cmu(ex["phonetic"]))
+
+            # HuBERT base: character output mapped through the CMU->IPA table
+            base_pred_cmu = run_hubert_base(base_proc, base_model, wav, device)
+            base_pred_ipa = cmu_to_ipa(base_pred_cmu)
+            per_scores_hubert.append(calculate_per(ref, base_pred_ipa))
+
+            # TIMIT phoneme model (already phoneme-like)
+            timit_pred = run_timit(timit_proc, timit_model, wav, device)
+            timit_pred_ipa = timit_pred  # leave as-is
+            per_scores_timit.append(calculate_per(ref, timit_pred_ipa))
+
+        # record mean PER per model under this split
+        results["results"][split_key] = {
+            "hubert_base": {"per": float(np.mean(per_scores_hubert)) if per_scores_hubert else None},
+            "timit_model": {"per": float(np.mean(per_scores_timit)) if per_scores_timit else None},
+        }
+
+    # Save a single combined result file
+    ts = int(time.time())
+    out_path = os.path.join(config.results_dir, f"results_{ts}.json")
+    with open(out_path, "w", encoding="utf-8") as f:
+        json.dump(results, f, ensure_ascii=False, indent=2)
+    return out_path
+
+
+
+if __name__ == "__main__":
+    cfg = EvalConfig()
+    path = evaluate(cfg)
+    print(f"Saved results to {path}")
+
+
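
For reference, a minimal sketch of how the new helpers compose, assuming the module is importable as src.phoneme_eval from the repository root:

    from src.phoneme_eval import EvalConfig, calculate_per, clean_cmu, cmu_to_ipa, evaluate

    # Two CMU transcriptions that differ by one vowel
    ref = cmu_to_ipa(clean_cmu("K AE1 T"))   # "K AE1 T" -> "k ae t" -> "k æ t "
    hyp = cmu_to_ipa(clean_cmu("K AA1 T"))   # -> "k ɑ t "
    print(calculate_per(ref, hyp))           # 1 substitution over 3 phones: 33.33...

    # Full run on a smaller sample; downloads both models and the dataset,
    # then writes eval-results/results_<timestamp>.json
    print(evaluate(EvalConfig(max_examples=10)))

calculate_per strips spaces before aligning, so the word boundaries emitted by cmu_to_ipa do not affect the score.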