fisherman611 committed on
Commit fec5dda · verified
1 Parent(s): c6f9ba7

Upload 3 files

Files changed (3)
  1. models/mt5.py +122 -0
  2. models/rule_based_mt.py +470 -0
  3. models/statistical_mt.py +884 -0
models/mt5.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
5
+
6
+ import torch
7
+ from transformers import MT5TokenizerFast, MT5ForConditionalGeneration # type: ignore
8
+ from datasets import load_dataset
9
+ from peft import LoraConfig, get_peft_model, TaskType
10
+ from dotenv import load_dotenv
11
+ import wandb
12
+ import json
13
+ from utils.helper import TextPreprocessor
14
+ from utils.trainer import train_model
15
+
16
+ load_dotenv()
17
+
18
+
19
+ class MT5Finetuner:
20
+ """Class to handle fine-tuning of mT5 model for translation tasks."""
21
+
22
+ def __init__(self, config_path="config.json"):
23
+ """Initialize with configuration file."""
24
+ with open(config_path, "r") as json_file:
25
+ cfg = json.load(json_file)
26
+
27
+ self.args = cfg["mt5"]["args"]
28
+ self.lora_config = cfg["mt5"]["lora_config"]
29
+
30
+ # Constants
31
+ self.max_len = self.args["max_len"]
32
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ self.id = self.args["id"]
34
+ self.initial_learning_rate = self.args["initial_learning_rate"]
35
+ self.model_name = self.args["model_name"]
36
+ self.wandb_project = self.args["wandb_project"]
37
+ self.output_dir = self.args["output_dir"]
38
+ self.name = "mt5"
39
+
40
+ self.model = None
41
+ self.tokenizer = None
42
+ self.train_dataset = None
43
+ self.val_dataset = None
44
+ self.test_dataset = None
45
+
46
+ def setup_wandb(self):
47
+ """Initialize Weights & Biases for experiment tracking."""
48
+ wandb.login(key=os.environ.get("WANDB_API"), relogin=True)
49
+ wandb.init(project=self.wandb_project, name="mt5-finetune-lora")
50
+
51
+ def load_model_and_tokenizer(self):
52
+ """Load the mT5 model and tokenizer."""
53
+ self.tokenizer = MT5TokenizerFast.from_pretrained(self.model_name, legacy=False)
54
+ self.model = MT5ForConditionalGeneration.from_pretrained(self.model_name)
55
+ self.model.config.use_cache = False # Disable cache for training
56
+
57
+ def load_datasets(self):
58
+ """Load training, validation, and test datasets."""
59
+ data_files = {
60
+ "train": "data/train_cleaned_dataset.csv",
61
+ "test": "data/test_cleaned_dataset.csv",
62
+ "val": "data/val_cleaned_dataset.csv",
63
+ }
64
+
65
+ if self.id is not None:
66
+ training_parts = [
67
+ f"[{i * 200000 if i > 0 else ''}:{(i + 1) * 200000 if i < 10 else ''}]"  # contiguous 200k-row shards; the last shard runs to the end of the split
68
+ for i in range(11)
69
+ ]
70
+ self.train_dataset = load_dataset(
71
+ "csv", data_files=data_files, split=f"train{training_parts[self.id]}"
72
+ )
73
+ self.test_dataset = load_dataset("csv", data_files=data_files, split="test")
74
+ self.val_dataset = load_dataset(
75
+ "csv", data_files=data_files, split="val[:20000]"
76
+ )
77
+ else:
78
+ self.train_dataset = load_dataset(
79
+ "csv", data_files=data_files, split="train[:1000000]"
80
+ )
81
+ self.test_dataset = load_dataset("csv", data_files=data_files, split="test[:100000]")
82
+ self.val_dataset = load_dataset("csv", data_files=data_files, split="val[:100000]")
83
+
84
+ def configure_lora(self):
85
+ """Apply LoRA configuration to the model."""
86
+ lora_config = LoraConfig(
87
+ task_type=TaskType.SEQ_2_SEQ_LM,
88
+ r=self.lora_config["r"],
89
+ lora_alpha=self.lora_config["lora_alpha"],
90
+ target_modules=self.lora_config["target_modules"],
91
+ lora_dropout=self.lora_config["lora_dropout"],
92
+ )
93
+ self.model = get_peft_model(self.model, lora_config) # type: ignore
94
+
95
+ def finetune(self):
96
+ """Orchestrate the fine-tuning process."""
97
+ self.setup_wandb()
98
+ self.load_model_and_tokenizer()
99
+ self.load_datasets()
100
+
101
+ preprocessor = TextPreprocessor(self.tokenizer, self.max_len, name="mt5")
102
+ tokenized_train_dataset = preprocessor.preprocess_dataset(self.train_dataset)
103
+ tokenized_eval_dataset = preprocessor.preprocess_dataset(self.val_dataset)
104
+
105
+ self.configure_lora()
106
+ self.model.print_trainable_parameters() # type: ignore
107
+
108
+ train_model(
109
+ model=self.model,
110
+ tokenizer=self.tokenizer,
111
+ train_dataset=tokenized_train_dataset,
112
+ eval_dataset=tokenized_eval_dataset,
113
+ output_dir=self.output_dir,
114
+ initial_learning_rate=self.initial_learning_rate,
115
+ name=self.name,
116
+ val_dataset=self.val_dataset,
117
+ )
118
+
119
+
120
+ if __name__ == "__main__":
121
+ finetuner = MT5Finetuner()
122
+ finetuner.finetune()
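For reference, MT5Finetuner reads all of its hyperparameters from config.json under the "mt5" key. Below is a minimal sketch of a compatible file, assuming illustrative placeholder values; the model name, LoRA settings, and paths are assumptions, not the project's actual configuration.

import json

# Hypothetical config.json writer matching the keys MT5Finetuner reads.
# Every value here is a placeholder assumption.
config = {
    "mt5": {
        "args": {
            "max_len": 128,                    # max tokenized sequence length
            "id": None,                        # shard id (0-10) or None for the default split
            "initial_learning_rate": 1e-4,     # passed through to train_model
            "model_name": "google/mt5-small",  # any mT5 checkpoint
            "wandb_project": "mt5-en-vi",      # Weights & Biases project name
            "output_dir": "checkpoints/mt5",   # where checkpoints are written
        },
        "lora_config": {
            "r": 16,
            "lora_alpha": 32,
            "target_modules": ["q", "v"],      # attention projections to adapt
            "lora_dropout": 0.05,
        },
    }
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=2)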
models/rule_based_mt.py ADDED
@@ -0,0 +1,470 @@
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
5
+
6
+ import re
7
+ import nltk
8
+ from nltk.tokenize import word_tokenize
9
+ from nltk.tag import pos_tag
10
+ from nltk.parse import ChartParser, ViterbiParser
11
+ from nltk.grammar import CFG, PCFG, Nonterminal, ProbabilisticProduction
12
+ from nltk.tree import Tree
13
+ import contractions
14
+ import string
15
+ from collections import defaultdict
16
+ import spacy
17
+
18
+ nlp = spacy.load("en_core_web_sm")
19
+
20
+ import json
21
+
22
+ with open("data/en_vi_dictionary.json", "r", encoding='utf-8') as json_file:
23
+ dictionary = json.load(json_file)
24
+
25
+ with open('grammar.txt', 'r', encoding='utf-8') as text_file:
26
+ grammar = text_file.read()
27
+
28
+
29
+ class TransferBasedMT:
30
+
31
+ def __init__(self) -> None:
32
+ # English - Vietnamese dictionary
33
+ self.dictionary = dictionary
34
+
35
+ # Define the CFG grammar for English sentence structure
36
+ self.grammar = grammar
37
+
38
+
39
+ ################################################ STAGE 1: PREPROCESSING SOURCE SENTENCE ###################################################
40
+
41
+ def preprocessing(self, sentence: str) -> str:
42
+ """Preprocess the input sentence: handle named entities, lowercase, expand contractions, and tokenize and regroup."""
43
+ # Handle named entities, e.g. New York -> New_York
44
+ doc = nlp(sentence)
45
+ entities = {ent.text: ent.label_ for ent in doc.ents}
46
+ for ent_text in sorted(entities.keys(), key=len, reverse=True):
47
+ ent_joined = ent_text.replace(" ", "_")
48
+ sentence = sentence.replace(ent_text, ent_joined)
49
+
50
+ # Lowercase and strip redundant space
51
+ sentence = sentence.lower().strip()
52
+
53
+ # Expand contractions, e.g. don't -> do not
54
+ sentence = contractions.fix(sentence)  # type: ignore
55
+
56
+ # Tokenize and regroup tokens
57
+ sentence = " ".join(word_tokenize(sentence))
58
+
59
+ return sentence
60
+
61
+
62
+ def safe_tag(self, tag):
63
+ """Convert tags with special characters to safe nonterminal symbols."""
64
+ return tag.replace("$", "S")
65
+
66
+
67
+ ################################################ STAGE 2: ANALYZE SOURCE SENTENCE #########################################################
68
+
69
+ def analyze_source(self, sentence: str):
70
+ """Analyze the source sentence: tokenize, POS tag, and parse into a syntax tree."""
71
+ doc = nlp(sentence)
72
+ filtered_pos_tagged = []
73
+ punctuation_marks = []
74
+
75
+ for i, token in enumerate(doc):
76
+ word = token.text
77
+ tag = token.tag_
78
+ if all(char in string.punctuation for char in word):
79
+ punctuation_marks.append((i, word, tag))
80
+ else:
81
+ filtered_pos_tagged.append((token.lemma_.lower(), tag))
82
+
83
+ grammar_str = self.grammar
84
+
85
+ # Add terminal rule grammars
86
+ for word, tag in filtered_pos_tagged:
87
+ safe_tag = self.safe_tag(tag)
88
+ escaped_word = word.replace('"', '\\"')
89
+ grammar_str += f'\n{safe_tag} -> "{escaped_word}"'
90
+
91
+ try:
92
+ grammar = CFG.fromstring(grammar_str)
93
+ parser = ChartParser(grammar)
94
+ tagged_tokens_only = [word for word, _ in filtered_pos_tagged]
95
+
96
+ parses = list(parser.parse(tagged_tokens_only)) # Generate parse trees
97
+
98
+ tree = (parses[0] if parses else self._create_fallback_tree(filtered_pos_tagged)) # Use first parse or fallback
99
+ tree = self._add_punctuation_to_tree(tree, punctuation_marks) # Reattach punctuation
100
+
101
+ return tree
102
+
103
+ except Exception as e:
104
+ print(f"Grammar creation error: {e}")
105
+ return self._create_fallback_tree(filtered_pos_tagged) # Fallback on error
106
+
107
+
108
+ def _create_fallback_tree(self, pos_tagged):
109
+ """Create a simple fallback tree when parsing fails."""
110
+ children = [Tree(self.safe_tag(tag), [word]) for word, tag in pos_tagged] # Create leaf nodes for each token
111
+ return Tree("S", children) # Wrap in a sentence node
112
+
113
+
114
+ def _add_punctuation_to_tree(self, tree, punctuation_marks):
115
+ """Add punctuation marks back to the syntax tree."""
116
+ if not punctuation_marks:
117
+ return tree
118
+ if tree.label() == "S": # Only add to root sentence node
119
+ for _, word, tag in sorted(punctuation_marks):
120
+ tree.append(Tree(self.safe_tag(tag), [word]))
121
+ return tree
122
+
123
+
124
+ #################################################### STAGE 3: TRANSFER GRAMMAR ############################################################
125
+
126
+ def transfer_grammar(self, tree):
127
+ """Transfer the English parse tree to Vietnamese structure."""
128
+ if not isinstance(tree, nltk.Tree):
129
+ return tree
130
+
131
+ # Sentence level: recurse through children
132
+ if tree.label() == "S":
133
+ return Tree("S", [self.transfer_grammar(child) for child in tree])
134
+
135
+ # Verb Phrase: adjust word order
136
+ elif tree.label() == "VP":
137
+ children = [self.transfer_grammar(child) for child in tree]
138
+ child_labels = [child.label() if isinstance(child, Tree) else child for child in children]
139
+
140
+ if (len(children) >= 3 and "V" in child_labels and "To" in child_labels and "VP" in child_labels): # Remove TO from V TO VP
141
+ return Tree("VP", [children[0], children[2]])
142
+
143
+ return Tree("VP", children) # Default: preserve order
144
+
145
+ # Noun Phrase: adjust word order
146
+ elif tree.label() == "NP":
147
+ children = [self.transfer_grammar(child) for child in tree]
148
+ child_labels = [child.label() if isinstance(child, Tree) else child for child in children]
149
+
150
+ if (len(children) >= 3 and 'Det' in child_labels and 'AdjP' in child_labels and 'N' in child_labels): # Reorder Det Adj N -> Det N Adj
151
+ return Tree("NP", [children[0], children[2], children[1]])
152
+
153
+ elif (len(children) >= 2 and 'PRPS' in child_labels and 'N' in child_labels): # Reorder PRPS N -> N PRPS
154
+ return Tree("NP", [children[1], children[0]])
155
+
156
+ elif (len(children) >= 2 and 'Det' in child_labels and 'N' in child_labels): # Remove Det from Det N
157
+ return Tree("NP", [children[1]])
158
+
159
+ return Tree("NP", children) # Default: preserve order
160
+
161
+ # Prepositional Phrase: adjust word order
162
+ elif tree.label() == "PP":
163
+ children = [self.transfer_grammar(child) for child in tree]
164
+ return Tree("PP", children) # Default: preserve order
165
+
166
+ # Adverbial Phrase: adjust word order
167
+ elif tree.label() == 'AdvP':
168
+ children = [self.transfer_grammar(child) for child in tree]
169
+ return Tree("AdvP", children) # Default: preserve order
170
+
171
+ # Adjective Phrase: adjust word order
172
+ elif tree.label() == 'AdjP':
173
+ children = [self.transfer_grammar(child) for child in tree]
174
+ return Tree("AdjP", children) # Default: preserve order
175
+
176
+ # Wh-Question: adjust word order
177
+ elif tree.label() == "WhQ":
178
+ children = [self.transfer_grammar(child) for child in tree]
179
+ child_labels = [child.label() if isinstance(child, Tree) else child for child in children]
180
+
181
+ if len(children) >= 4 and "WH_Word" in child_labels and "AUX" in child_labels and "NP" in child_labels and "VP" in child_labels:
182
+ return Tree("WhQ", [children[2], children[3], children[0]]) # Remove AUX from WH_Word AUX NP VP
183
+
184
+ elif len(children) >= 3 and "WH_Word" in child_labels and "NP" in child_labels and "VP" in child_labels and "AUX" not in child_labels:
185
+ return Tree("WhQ", [children[1], children[2], children[0]])
186
+
187
+ elif len(children) >= 2 and "WH_Word" in child_labels and "VP" in child_labels:
188
+ if len(children[1]) >= 2:
189
+ return Tree("WhQ", [children[1][1], children[1][0], children[0]]) # WH_Word VP -> WH_Word V NP
190
+
191
+ else:
192
+ return Tree("WhQ", children) # Default: preserve order
193
+
194
+ # Yes/No-Question: adjust word order
195
+ elif tree.label() == "YNQ":
196
+ children = [self.transfer_grammar(child) for child in tree]
197
+ child_labels = [child.label() if isinstance(child, Tree) else child for child in children]
198
+
199
+ if len(children) >= 3 and "AUX" in child_labels and "NP" in child_labels and "VP" in child_labels:
200
+ return Tree("YNQ", [children[1], children[2]])
201
+
202
+ elif len(children) >= 3 and "DO" in child_labels and "NP" in child_labels and "VP" in child_labels:
203
+ return Tree("YNQ", [children[1], children[2]])
204
+
205
+ elif len(children) >= 3 and "MD" in child_labels and "NP" in child_labels and "VP" in child_labels:
206
+ return Tree("YNQ", [children[1], children[2]])
207
+
208
+ return Tree("YNQ", children)
209
+
210
+
211
+ # Other labels: recurse through children
212
+ else:
213
+ return Tree(tree.label(), [self.transfer_grammar(child) for child in tree])
214
+
215
+
216
+ #################################################### STAGE 4: GENERATION STAGE ############################################################
217
+
218
+ def generate(self, tree):
219
+ """Generate Vietnamese output from the transformed tree."""
220
+ if not isinstance(tree, nltk.Tree):
221
+ return self._lexical_transfer(tree) # Translate leaf nodes
222
+
223
+ words = [w for w in (self.generate(child) for child in tree) if w] # Recurse once per child and keep non-empty results
224
+
225
+ # Handle questions specifically
226
+ if tree.label() == "WhQ":
227
+ words = self._process_wh_question(tree, words)
228
+ elif tree.label() == "YNQ":
229
+ words = self._process_yn_question(tree, words)
230
+ elif tree.label() == "NP": # Add classifiers for nouns
231
+ words = self._add_classifiers(tree, words)
232
+ elif tree.label() == "VP": # Apply tense/aspect/mood markers
233
+ words = self._apply_tam_mapping(tree, words)
234
+
235
+ words = self._apply_agreement(tree, words) # Handle agreement (e.g., plurals)
236
+ result = " ".join(words) # Join words into a string
237
+
238
+ return result
239
+
240
+
241
+ def _process_wh_question(self, tree, words):
242
+ """Process a Wh-question structure for Vietnamese."""
243
+ words = [w for w in words if w]
244
+
245
+ wh_word = None
246
+ for word in words:
247
+ if word in ["cái gì", "ai", "ở đâu", "khi nào", "tại sao", "như thế nào", "cái nào", "của ai"]:
248
+ wh_word = word
249
+ break
250
+
251
+ if wh_word == "tại sao":
252
+ if words and words[0] != "tại sao":
253
+ words.remove("tại sao")
254
+ words.insert(0, "tại sao")
255
+ elif wh_word == "như thế nào":
256
+ if "vậy" not in words:
257
+ words.append("vậy")
258
+
259
+ question_particles = ["vậy", "thế", "à", "hả"]
260
+ has_particle = any(particle in words for particle in question_particles)
261
+
262
+ if not has_particle and wh_word != "tại sao":
263
+ words.append("vậy")
264
+
265
+ return words
266
+
267
+
268
+ def _process_yn_question(self, tree, words):
269
+ """Process a Yes/No question structure for Vietnamese."""
270
+
271
+ words = [w for w in words if w not in ["", "do_vn", "does_vn", "did_vn"]]
272
+
273
+ has_question_particle = any(w in ["không", "à", "hả", "nhỉ", "chứ"] or
274
+ w in ["không_vn", "à_vn", "hả_vn", "nhỉ_vn", "chứ_vn"]
275
+ for w in words)
276
+
277
+ if not has_question_particle:
278
+ if "đã" in words or "đã_vn" in words:
279
+ words.append("phải không")
280
+ else:
281
+ words.append("không")
282
+
283
+ return words
284
+
285
+
286
+ def _lexical_transfer(self, word):
287
+ """Translate English words to Vietnamese using the dictionary."""
288
+ if word in self.dictionary:
289
+ return self.dictionary[word] # Return translation if in dictionary
290
+ return f"{word}_vn" # Mark untranslated words with _vn suffix
291
+
292
+
293
+ def _add_classifiers(self, np_tree, words):
294
+ """Add Vietnamese classifiers based on nouns."""
295
+ # noun_indices = [
296
+ # i for i, child in enumerate(np_tree) if isinstance(child, Tree)
297
+ # and child.label() in ["N", "NN", "NNS", "NNP", "NNPS"]
298
+ # ] # Find noun positions
299
+ # for i in noun_indices:
300
+ # if len(words) > i and not any(words[i].startswith(prefix) for prefix in ["một_vn", "những_vn", "các_vn"]): # Check if classifier is needed
301
+ # if words[i].endswith("_vn"): # Add default classifier for untranslated nouns
302
+ # words.insert(i, "cái_vn")
303
+ return words
304
+
305
+
306
+ def _apply_tam_mapping(self, vp_tree, words):
307
+ """Apply Vietnamese TAM (Tense, Aspect, Mood) markers to the word list.
308
+
309
+ Args:
310
+ vp_tree: A parse tree node representing the verb phrase.
311
+ words: List of words to be modified with TAM markers.
312
+
313
+ Returns:
314
+ List of words with appropriate Vietnamese TAM markers inserted.
315
+ """
316
+ verb_tense = None
317
+ mood = None
318
+
319
+ # Identify verb tense and mood from the verb phrase tree
320
+ for child in vp_tree:
321
+ if isinstance(child, Tree):
322
+ if child.label() in ["V", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ"]:
323
+ verb_tense = child.label()
324
+ if child.label() == "MD": # Modal verbs indicating mood
325
+ mood = "indicative"
326
+ elif child.label() == "TO": # Infinitive marker, often subjunctive
327
+ mood = "subjunctive"
328
+
329
+ if not verb_tense:
330
+ print("Warning: No verb tense identified in the verb phrase tree.")
331
+ return words
332
+
333
+ # Apply TAM markers based on verb tense
334
+ if verb_tense == "VBD":
335
+ words.insert(0, "đã_vn")
336
+ elif verb_tense == "VB":
337
+ if "will_vn" in words:
338
+ words = [w for w in words if w != "will_vn"]
339
+ words.insert(0, "sẽ_vn")
340
+ elif "going_to_vn" in words:
341
+ words = [w for w in words if w != "going_to_vn"]
342
+ words.insert(0, "sẽ_vn")
343
+ elif verb_tense == "VBG":
344
+ words.insert(0, "đang_vn")
345
+ if "đã_vn" in words:
346
+ words.insert(0, "đã_vn")
347
+ elif verb_tense == "VBN":
348
+ words.insert(0, "đã_vn")
349
+ elif verb_tense == "VBP" or verb_tense == "VBZ":
350
+ pass
351
+
352
+ # Handle future continuous (e.g., "will be running" -> "sẽ đang")
353
+ if verb_tense == "VBG" and "will_vn" in words:
354
+ words = [w for w in words if w != "will_vn"]
355
+ words.insert(0, "đang_vn") # Continuous marker
356
+ words.insert(0, "sẽ_vn") # Future marker
357
+
358
+ # Apply mood markers if applicable
359
+ if mood == "subjunctive":
360
+ words.insert(0, "nếu_vn") # Subjunctive marker (e.g., "if" clause)
361
+ elif mood == "indicative" and "must_vn" in words:
362
+ words = [w for w in words if w != "must_vn"]
363
+ words.insert(0, "phải_vn") # Necessity marker
364
+
365
+ return words
366
+
367
+
368
+ def _apply_agreement(self, tree, words):
369
+ """Apply agreement rules for Vietnamese (e.g., pluralization)."""
370
+ if tree.label() == "NP":
371
+ for i, word in enumerate(words):
372
+ if "_vn" in word and word.replace("_vn", "").endswith("s"): # Handle English plurals
373
+ base_word = word.replace("_vn", "")[:-1] + "_vn" # Remove 's'
374
+ words[i] = base_word
375
+ words.insert(i, "các_vn") # Add plural marker
376
+ return words
377
+
378
+
379
+ def _post_process_vietnamese(self, text):
380
+ """Post-process the Vietnamese output: remove _vn, fix punctuation, capitalize."""
381
+ text = text.replace("_vn", "") # Remove untranslated markers
382
+
383
+ def fix_entities(word):
384
+ if "_" in word:
385
+ word = " ".join(word.split("_"))
386
+ return word.title()
387
+ return word.lower() # Lowercase non-entity words
388
+
389
+ words = text.split()
390
+ words = [fix_entities(word) for word in words]
391
+
392
+ text = " ".join(words)
393
+ for punct in [".", ",", "!", "?", ":", ";"]: # Attach punctuation directly
394
+ text = text.replace(f" {punct}", punct)
395
+
396
+ if text:
397
+ words = text.split()
398
+ words[0] = words[0].capitalize() # Capitalize first word
399
+ text = ' '.join(words)
400
+ return text
401
+
402
+
403
+ def translate(self, english_sentence):
404
+ """Main translation function that applies all stages of the process."""
405
+ # Step 1: Preprocess input
406
+ preprocessed = self.preprocessing(english_sentence)
407
+
408
+ # Step 2: Parse English sentence
409
+ source_tree = self.analyze_source(preprocessed)
410
+ print("English parse tree:")
411
+ source_tree.pretty_print() # Display English parse tree
412
+
413
+ # Step 3: Transform to Vietnamese structure
414
+ target_tree = self.transfer_grammar(source_tree)
415
+ print("Vietnamese structure tree:")
416
+ target_tree.pretty_print() # Display Vietnamese parse tree
417
+
418
+ # Step 4: Generate final translation
419
+ raw_output = self.generate(target_tree)
420
+ vietnamese_output = self._post_process_vietnamese(raw_output)
421
+ return vietnamese_output
422
+
423
+
424
+ if __name__ == "__main__":
425
+ translator = TransferBasedMT()
426
+ test_sentences = [
427
+ "I read books.", "The student studies at school.",
428
+ "She has a beautiful house.", "They want to buy a new car.",
429
+ "This is a good computer.", "Are you ready to listen?",
430
+ "I want to eat.", "This is my book.","What is your name?",
431
+ "Do you like books?",
432
+ "Is she at school?",
433
+ "Are you ready to listen?",
434
+ "Can they buy a new car?",
435
+ "Did he read the book yesterday?",
436
+ "What is your name?",
437
+ "Where do you live?",
438
+ "Who is your teacher?",
439
+ "When will you go to school?",
440
+ "Why did he leave early?",
441
+ "How do you feel today?",
442
+ "I live in New York"
443
+ ]
444
+
445
+ test_sentences_2 = [
446
+ # YNQ -> BE NP
447
+ "Is the renowned astrophysicist still available for the conference?",
448
+ "Are those adventurous explorers currently in the remote jungle?",
449
+ "Was the mysterious stranger already gone by midnight?",
450
+ # YNQ -> BE NP Adj
451
+ "Is the vibrant annual festival exceptionally spectacular this season?",
452
+ "Are the newly discovered species remarkably resilient to harsh climates?",
453
+ "Were the ancient ruins surprisingly well-preserved after centuries?",
454
+ # YNQ -> BE NP NP
455
+ "Is she the brilliant leader of the innovative research team?",
456
+ "Are they the enthusiastic organizers of the grand charity event?",
457
+ "Was he the sole survivor of the perilous expedition?",
458
+ # YNQ -> BE NP PP
459
+ "Is the priceless artifact still hidden in the ancient underground chamber?",
460
+ "Are the colorful tropical birds nesting high above the lush rainforest canopy?",
461
+ "Was the historic manuscript carefully stored within the fortified library vault?"
462
+ ]
463
+
464
+ print("English to Vietnamese Translation Examples:")
465
+ print("-" * 50)
466
+ for sentence in test_sentences_2:
467
+ print(f"English: {sentence}")
468
+ translation = translator.translate(sentence)
469
+ print(f"Vietnamese: {translation}")
470
+ print()
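TransferBasedMT loads its sentence-level rules from grammar.txt, which is not part of this upload; analyze_source then appends terminal rules (e.g. N -> "book") generated from the POS-tagged input before parsing. A minimal sketch of the kind of CFG fragment the parser expects, using a few of the nonterminals the transfer rules match on; the productions below are illustrative assumptions, not the project's actual grammar.

from nltk.grammar import CFG
from nltk.parse import ChartParser

# Hypothetical grammar fragment in the style grammar.txt is assumed to use:
# nonterminal rules only, with terminal rules appended per sentence at parse time.
grammar_fragment = """
S -> NP VP
NP -> Det N | Det AdjP N | PRPS N | N
VP -> V NP | V To VP | V
"""

# Terminal rules of the kind analyze_source appends from the tagged tokens.
terminals = """
Det -> "the"
N -> "student" | "book"
V -> "read"
"""

parser = ChartParser(CFG.fromstring(grammar_fragment + terminals))
for tree in parser.parse(["the", "student", "read", "the", "book"]):
    tree.pretty_print()
    break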
models/statistical_mt.py ADDED
@@ -0,0 +1,884 @@
1
+ import pandas as pd
2
+ from nltk.translate import AlignedSent
3
+ from nltk.translate.ibm1 import IBMModel1
4
+ from nltk.lm import MLE
5
+ from nltk.lm.preprocessing import padded_everygram_pipeline
6
+ from collections import defaultdict, Counter
7
+ import math
8
+ import os
9
+ from tqdm import tqdm
10
+ import pickle
11
+ import random
12
+ import gc
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import contractions
16
+ BILINGUAL_DATA_PATH = "bilingual_cleaned_dataset.csv" # Default bilingual dataset path
17
+ VIE_DATA_PATH = "vie_cleaned_dataset.csv" # Default Vietnamese dataset path
18
+ VISUALIZATION_PATH = "visualizations" # Default visualization output path
19
+ BEAM_SIZE = 3
20
+ MAX_PHRASE_LENGTH = 7
21
+ LM_ORDER = 3
22
+ ALPHA = 0.7
23
+ BETA = 0.3
24
+ BATCH_SIZE = 1000 # For processing data in batches
25
+ MIN_PHRASE_COUNT = 3 # Increased threshold to reduce phrase table size
26
+ LIMIT_VOCAB = 100000 # Limit vocabulary to the 100,000 most frequent words
27
+ MODE_VISUALIZATION = False # Set to True to enable visualizations
28
+ from pyvi import ViTokenizer
29
+ from nltk.tokenize import word_tokenize
30
+
31
+
32
+
33
+
34
+ ################################################## 1. Language Model ##################################################
35
+ class LanguageModel:
36
+ """Memory-optimized Language Model"""
37
+ def __init__(self, order=LM_ORDER, MODE_VISUALIZATION=MODE_VISUALIZATION):
38
+ self.order = order
39
+ self.lm = None
40
+ self.vocab_size = 0
41
+ self.MODE_VISUALIZATION = MODE_VISUALIZATION
42
+
43
+ def preprocess(self, text):
44
+ """Tokenize Vietnamese words"""
45
+ # return text.lower().split()
46
+ return ViTokenizer.tokenize(text.lower()).split()
47
+
48
+ def visualize_iterations(self, word_freq, iteration, batch_tokens, output_dir="/kaggle/working/visualizations"):
49
+ if "KAGGLE_KERNEL_RUN_TYPE" in os.environ:
50
+ # Running on Kaggle
51
+ output_dir = "/kaggle/working/visualizations"
52
+ else:
53
+ output_dir = VISUALIZATION_PATH
54
+ os.makedirs(output_dir, exist_ok=True)
55
+
56
+ # Visualize word frequency for a given iteration
57
+ if not self.MODE_VISUALIZATION:
58
+ return
59
+
60
+ print(f"\nIteration {iteration} - Word Frequency (Top 5):")
61
+ top_words = word_freq.most_common(5)
62
+ for word, count in top_words:
63
+ print(f" {word}: {count}")
64
+
65
+ if not os.path.exists(output_dir):
66
+ os.makedirs(output_dir)
67
+
68
+ words, counts = zip(*word_freq.most_common(10)) if word_freq else ([], [])
69
+ if words:
70
+ plt.figure(figsize=(8, 6))
71
+ plt.bar(words, counts, color='purple', alpha=0.7)
72
+ plt.title(f'Word Frequency - Iteration {iteration}')
73
+ plt.xlabel('Words')
74
+ plt.ylabel('Frequency')
75
+ plt.xticks(rotation=45)
76
+ plt.grid(True, axis='y')
77
+ plt.savefig(os.path.join(output_dir, f'word_freq_iter_{iteration}.png'))
78
+ plt.close()
79
+
80
+ def get_probability(self, tokens):
81
+ """Calculate probability P(V) for a vietnamese tokens sequence"""
82
+ if not tokens or not self.lm:
83
+ return 0.0
84
+
85
+ start_tokens = ['<s>'] * (self.order - 1)
86
+ tokens = start_tokens + tokens
87
+ log_prob = 0.0
88
+
89
+ for i in range(self.order - 1, len(tokens)):
90
+ context = tokens[max(0, i - self.order + 1):i]
91
+ word = tokens[i]
92
+ prob = self.lm.score(word, context) or 1e-10
93
+ log_prob += math.log(prob)
94
+
95
+ return log_prob
96
+
97
+ def visualize_log_probabilities(self, sentences, max_sentences=100, output_dir="/kaggle/working/visualizations"):
98
+ if "KAGGLE_KERNEL_RUN_TYPE" in os.environ:
99
+ # Running on Kaggle
100
+ output_dir = "/kaggle/working/visualizations"
101
+ else:
102
+ # Running locally
103
+ output_dir = VISUALIZATION_PATH
104
+
105
+ os.makedirs(output_dir, exist_ok=True)
106
+ # Visualize the log probabilities of a sample of sentences
107
+ if not self.MODE_VISUALIZATION:
108
+ return
109
+
110
+ if not self.lm:
111
+ print("Cannot visualize log probabilities: Language model not trained.")
112
+ return
113
+
114
+ # Sample sentences to reduce computation
115
+ sample_size = min(len(sentences), max_sentences)
116
+ sample_sentences = random.sample(sentences, sample_size) if len(sentences) > max_sentences else sentences
117
+
118
+ # Compute log probabilities
119
+ log_probs = []
120
+ for sent in sample_sentences:
121
+ tokens = self.preprocess(sent)
122
+ log_prob = self.get_probability(tokens)
123
+ log_probs.append(log_prob)
124
+
125
+ # Print summary statistics
126
+ print(f"\nLog Probabilities for {len(log_probs)} sentences:")
127
+ print(f" Mean Log Probability: {np.mean(log_probs):.2f}")
128
+ print(f" Min Log Probability: {min(log_probs):.2f}")
129
+ print(f" Max Log Probability: {max(log_probs):.2f}")
130
+
131
+ # Plot histogram of log probabilities
132
+ if not os.path.exists(output_dir):
133
+ os.makedirs(output_dir)
134
+
135
+ plt.figure(figsize=(8, 6))
136
+ plt.hist(log_probs, bins=30, color='blue', alpha=0.7)
137
+ plt.title('Distribution of Log Probabilities for Sentences')
138
+ plt.xlabel('Log Probability')
139
+ plt.ylabel('Frequency')
140
+ plt.grid(True)
141
+ plt.savefig(os.path.join(output_dir, 'log_probabilities.png'))
142
+ plt.close()
143
+ print(f"Log probabilities visualization saved to {output_dir}/log_probabilities.png")
144
+
145
+ def train(self, vietnamese_sentences, max_sentences=200000):
146
+ """Training Language Model with memory optimization"""
147
+ print(f"Training Language Model on {min(len(vietnamese_sentences), max_sentences)} sentences...")
148
+
149
+ # Limit training data for LM to reduce memory
150
+ if len(vietnamese_sentences) > max_sentences:
151
+ print(f"Sampling {max_sentences} sentences from {len(vietnamese_sentences)} for LM training")
152
+ vietnamese_sentences = random.sample(vietnamese_sentences, max_sentences)
153
+
154
+ # Process in batches to reduce memory usage
155
+ all_tokens = []
156
+ batch_size = 10000
157
+ word_freq = Counter()
158
+ iteration = 0
159
+
160
+ for i in range(0, len(vietnamese_sentences), batch_size):
161
+ batch = vietnamese_sentences[i:i+batch_size]
162
+ batch_tokens = [self.preprocess(sent) for sent in batch]
163
+ all_tokens.extend(batch_tokens)
164
+
165
+ # Update word frequency for visualization
166
+ if self.MODE_VISUALIZATION and iteration < 2: # Limit to 2 iterations
167
+ for tokens in batch_tokens:
168
+ word_freq.update(tokens)
169
+ self.visualize_iterations(word_freq, iteration + 1, batch_tokens)
170
+ iteration += 1
171
+
172
+ # Force garbage collection
173
+ if i % (batch_size * 5) == 0:
174
+ gc.collect()
175
+
176
+ vocab = set()
177
+ for tokens in all_tokens:
178
+ vocab.update(tokens)
179
+
180
+ # Limit vocabulary size to most frequent words
181
+ if len(vocab) > LIMIT_VOCAB:
182
+ word_freq = Counter()
183
+ for tokens in all_tokens:
184
+ word_freq.update(tokens)
185
+
186
+ # Keep only top words
187
+ most_common = word_freq.most_common(LIMIT_VOCAB)
188
+ vocab = set(word for word, _ in most_common)
189
+ print(f"Limited vocabulary to {len(vocab)} most frequent words")
190
+
191
+ self.vocab_size = len(vocab)
192
+
193
+ # Filter sentences to contain only vocabulary words
194
+ filtered_sentences = []
195
+ for tokens in all_tokens:
196
+ filtered_tokens = [token for token in tokens if token in vocab]
197
+ if filtered_tokens: # Only add non-empty sentences
198
+ filtered_sentences.append(filtered_tokens)
199
+
200
+ # Clear original data
201
+ del all_tokens
202
+ gc.collect()
203
+
204
+ # Train N-gram model
205
+ train_data, padded_sents = padded_everygram_pipeline(self.order, filtered_sentences)
206
+ self.lm = MLE(self.order)
207
+ self.lm.fit(train_data, padded_sents)
208
+
209
+ # Visualize log probabilities after training
210
+ if self.MODE_VISUALIZATION:
211
+ self.visualize_log_probabilities(vietnamese_sentences)
212
+
213
+ # Clear training data
214
+ del filtered_sentences, train_data, padded_sents
215
+ gc.collect()
216
+
217
+ return {"vocab_size": self.vocab_size, "ngram_order": self.order}
218
+
219
+ ############################################# 2. Translation Model #############################################
220
+
221
+ class TranslationModel:
222
+ """Memory-optimized Translation Model"""
223
+ def __init__(self, max_phrase_length=MAX_PHRASE_LENGTH, MODE_VISUALIZATION=MODE_VISUALIZATION):
224
+ self.max_phrase_length = max_phrase_length
225
+ self.phrase_table = {}
226
+ self.word_alignments = []
227
+ self.MODE_VISUALIZATION = MODE_VISUALIZATION
228
+
229
+ def preprocess(self, text, lang):
230
+ """Preprocess text for both languages"""
231
+ text = text.lower()
232
+ if lang == 'eng':
233
+ text = contractions.fix(text)
234
+ return word_tokenize(text)
235
+ elif lang == 'vie':
236
+ return ViTokenizer.tokenize(text).split()
237
+ else:
238
+ return text.split()
239
+
240
+ def load_bilingual_data_batch(self, file_path, batch_size=BATCH_SIZE):
241
+ """Load bilingual data in batches to reduce memory usage"""
242
+ print(f"Loading bilingual data from {file_path} in batches")
243
+ # default = '/kaggle/input/general-data/bilingual_cleaned_dataset.csv'
244
+ try:
245
+ df = pd.read_csv(file_path)
246
+ except FileNotFoundError:
247
+ file_path = os.path.join('datatest', BILINGUAL_DATA_PATH)
248
+ df = pd.read_csv(file_path)
249
+ total_rows = len(df)
250
+ print(f"Total rows: {total_rows}")
251
+
252
+ for start_idx in range(0, total_rows, batch_size):
253
+ end_idx = min(start_idx + batch_size, total_rows)
254
+ batch_df = df.iloc[start_idx:end_idx]
255
+
256
+ aligned_sentences = []
257
+ for _, row in batch_df.iterrows():
258
+ eng_tokens = self.preprocess(row['en'], 'eng')
259
+ vie_tokens = self.preprocess(row['vi'], 'vie')
260
+
261
+ # Filter out very long sentences to save memory
262
+ if len(eng_tokens) <= 50 and len(vie_tokens) <= 50:
263
+ aligned_sentences.append(AlignedSent(eng_tokens, vie_tokens))
264
+
265
+ yield aligned_sentences
266
+
267
+ # Clean up batch
268
+ del batch_df, aligned_sentences
269
+ gc.collect()
270
+
271
+ def visualize_alignments(self, aligned_sentences, max_sentences=2, output_dir="/kaggle/working/visualizations"):
272
+ if "KAGGLE_KERNEL_RUN_TYPE" in os.environ:
273
+ # Running on Kaggle
274
+ output_dir = "/kaggle/working/visualizations"
275
+ else:
276
+ # Running locally
277
+ output_dir = VISUALIZATION_PATH
278
+
279
+ os.makedirs(output_dir, exist_ok=True)
280
+ # Visualize word alignments for a sample of sentence pairs
281
+ if not self.MODE_VISUALIZATION:
282
+ return
283
+
284
+ if not getattr(self, "ibm_model", None): # the trained IBM model is not stored on self, so guard against a missing attribute
285
+ print("Cannot visualize alignments: IBM Model 1 not trained.")
286
+ return
287
+
288
+ # Sample sentences to reduce computation
289
+ sample_size = min(len(aligned_sentences), max_sentences)
290
+ sample_sentences = random.sample(aligned_sentences, sample_size) if len(aligned_sentences) > max_sentences else aligned_sentences
291
+
292
+ if not os.path.exists(output_dir):
293
+ os.makedirs(output_dir)
294
+
295
+ for idx, sent in enumerate(sample_sentences):
296
+ src_words = sent.words # English
297
+ tgt_words = sent.mots # Vietnamese
298
+ alignment = sent.alignment
299
+
300
+ # Create alignment matrix
301
+ matrix = np.zeros((len(tgt_words), len(src_words)))
302
+ for src_idx, tgt_idx in alignment:
303
+ if tgt_idx is not None and src_idx < len(src_words) and tgt_idx < len(tgt_words):
304
+ matrix[tgt_idx, src_idx] = 1
305
+
306
+ # Plot alignment matrix
307
+ plt.figure(figsize=(8, 6))
308
+ plt.imshow(matrix, cmap='Blues', interpolation='nearest')
309
+ plt.title(f'Alignment Matrix - Sentence Pair {idx + 1}')
310
+ plt.xlabel('English Words')
311
+ plt.ylabel('Vietnamese Words')
312
+ plt.xticks(range(len(src_words)), src_words, rotation=45, ha='right')
313
+ plt.yticks(range(len(tgt_words)), tgt_words)
314
+ plt.tight_layout()
315
+ plt.savefig(os.path.join(output_dir, f'alignment_matrix_{idx + 1}.png'))
316
+ plt.close()
317
+
318
+ # Print alignment details
319
+ print(f"\nSentence Pair {idx + 1}:")
320
+ print(f" English: {' '.join(src_words)}")
321
+ print(f" Vietnamese: {' '.join(tgt_words)}")
322
+ print(f" Alignments: {[(src_words[src], tgt_words[tgt]) for src, tgt in alignment if tgt is not None]}")
323
+
324
+ print(f"Alignment visualizations saved to {output_dir}/")
325
+
326
+ def _extract_alignments_memory_efficient(self, aligned_sentences, ibm_model):
327
+ """Memory-efficient alignment extraction"""
328
+ alignments = []
329
+
330
+ # Process in smaller batches
331
+ batch_size = 5000
332
+ for i in range(0, len(aligned_sentences), batch_size):
333
+ batch_alignments = []
334
+ batch_sentences = aligned_sentences[i:i+batch_size]
335
+
336
+ for sent_pair in batch_sentences:
337
+ eng_tokens = sent_pair.words
338
+ vie_tokens = sent_pair.mots
339
+
340
+ # Only keep high-probability alignments
341
+ alignment = []
342
+ for eng_i, eng_word in enumerate(eng_tokens):
343
+ best_prob = 0
344
+ best_vie_i = -1
345
+
346
+ for vie_i, vie_word in enumerate(vie_tokens):
347
+ prob = ibm_model.translation_table.get(eng_word, {}).get(vie_word, 0)
348
+ if prob > best_prob:
349
+ best_prob = prob
350
+ best_vie_i = vie_i
351
+
352
+ # Only keep alignments above threshold
353
+ if best_prob > 0.01: # Increased threshold
354
+ alignment.append((eng_i, best_vie_i))
355
+
356
+ batch_alignments.append(alignment)
357
+
358
+ alignments.extend(batch_alignments)
359
+
360
+ # Periodic cleanup
361
+ if i % (batch_size * 10) == 0:
362
+ gc.collect()
363
+
364
+ return alignments
365
+
366
+ def extract_phrases_memory_efficient(self, aligned_sentences):
367
+ """Memory-efficient phrase extraction"""
368
+ print("Extracting phrase pairs with memory optimization...")
369
+
370
+ # Use smaller data structures
371
+ phrase_counts = defaultdict(lambda: defaultdict(int))
372
+
373
+ # Process in batches
374
+ batch_size = 5000
375
+ for i in range(0, len(aligned_sentences), batch_size):
376
+ batch_sentences = aligned_sentences[i:i+batch_size]
377
+ batch_alignments = self.word_alignments[i:i+batch_size]
378
+
379
+ for sent_pair, alignments in zip(batch_sentences, batch_alignments):
380
+ if not alignments: # Skip sentences with no alignments
381
+ continue
382
+
383
+ eng_tokens = sent_pair.words
384
+ vie_tokens = sent_pair.mots
385
+ alignment_set = set(alignments)
386
+
387
+ # Extract word-level translations first
388
+ for eng_i, vie_i in alignments:
389
+ if eng_i < len(eng_tokens) and vie_i < len(vie_tokens):
390
+ eng_word = eng_tokens[eng_i]
391
+ vie_word = vie_tokens[vie_i]
392
+ phrase_counts[eng_word][vie_word] += 1
393
+
394
+ # Extract short phrases only (max length 3 to save memory)
395
+ max_len = min(3, self.max_phrase_length)
396
+ consistent_phrases = self._extract_consistent_phrases(
397
+ eng_tokens, vie_tokens, alignment_set, max_len
398
+ )
399
+
400
+ for eng_phrase, vie_phrase in consistent_phrases:
401
+ phrase_counts[eng_phrase][vie_phrase] += 1
402
+
403
+ # Periodic cleanup
404
+ if i % (batch_size * 5) == 0:
405
+ gc.collect()
406
+ print(f"Processed {i+batch_size} sentences...")
407
+
408
+ # Calculate probabilities with higher threshold
409
+ self.phrase_table = {}
410
+ for eng_phrase, vie_phrases in phrase_counts.items():
411
+ total_count = sum(vie_phrases.values())
412
+ if total_count >= MIN_PHRASE_COUNT: # Higher threshold
413
+ # Keep only top 3 translations per phrase to save memory
414
+ sorted_phrases = sorted(vie_phrases.items(), key=lambda x: x[1], reverse=True)[:3]
415
+
416
+ filtered_phrases = {}
417
+ for vie_phrase, count in sorted_phrases:
418
+ if count >= MIN_PHRASE_COUNT:
419
+ filtered_phrases[vie_phrase] = count / total_count
420
+
421
+ if filtered_phrases:
422
+ self.phrase_table[eng_phrase] = filtered_phrases
423
+
424
+ print(f"Extracted {len(self.phrase_table)} phrase pairs (filtered)")
425
+ # Visualize phrase table if enabled
426
+ if self.MODE_VISUALIZATION:
427
+ self.visualize_phrase_table()
428
+
429
+ return self.phrase_table
430
+
431
+ def _extract_consistent_phrases(self, eng_tokens, vie_tokens, alignments, max_length):
432
+ """Extract consistent phrase pairs with length limit"""
433
+ consistent_phrases = []
434
+ eng_len = len(eng_tokens)
435
+
436
+ # Limit phrase extraction to reduce memory
437
+ for e_start in range(eng_len):
438
+ for e_end in range(e_start, min(eng_len, e_start + max_length)):
439
+ vie_positions = set()
440
+ for e_pos in range(e_start, e_end + 1):
441
+ for (eng_idx, vie_idx) in alignments:
442
+ if eng_idx == e_pos:
443
+ vie_positions.add(vie_idx)
444
+
445
+ if not vie_positions:
446
+ continue
447
+
448
+ v_start, v_end = min(vie_positions), max(vie_positions)
449
+
450
+ if v_end - v_start + 1 <= max_length:
451
+ if self._is_consistent_phrase_pair(e_start, e_end, v_start, v_end, alignments):
452
+ eng_phrase = ' '.join(eng_tokens[e_start:e_end+1])
453
+ vie_phrase = ' '.join(vie_tokens[v_start:v_end+1])
454
+ consistent_phrases.append((eng_phrase, vie_phrase))
455
+
456
+ return consistent_phrases
457
+
458
+ def _is_consistent_phrase_pair(self, e_start, e_end, v_start, v_end, alignments):
459
+ """Check if a phrase pair is consistent"""
460
+ for (eng_idx, vie_idx) in alignments:
461
+ if (e_start <= eng_idx <= e_end) and not (v_start <= vie_idx <= v_end):
462
+ return False
463
+ if (v_start <= vie_idx <= v_end) and not (e_start <= eng_idx <= e_end):
464
+ return False
465
+ return True
466
+
467
+ def train_ibm_model_incremental(self, file_path="/kaggle/input/general-data/bilingual_cleaned_dataset.csv", iterations=5):
468
+ """Train IBM Model 1 incrementally to reduce memory usage"""
469
+ if not os.path.exists(file_path):
470
+ file_path = os.path.join('datatest', BILINGUAL_DATA_PATH)
471
+ print(f"Training IBM Model 1 incrementally with {iterations} iterations...")
472
+
473
+ # First pass: collect vocabulary and create aligned sentences
474
+ all_aligned_sentences = []
475
+ eng_vocab = set()
476
+ vie_vocab = set()
477
+
478
+ for batch in self.load_bilingual_data_batch(file_path):
479
+ for sent_pair in batch:
480
+ eng_vocab.update(sent_pair.words)
481
+ vie_vocab.update(sent_pair.mots)
482
+ all_aligned_sentences.append(sent_pair)
483
+
484
+ # Limit total sentences to prevent memory issues
485
+ if len(all_aligned_sentences) >= 300000: # Reduced from 500k
486
+ print(f"Limited training to {len(all_aligned_sentences)} sentences")
487
+ break
488
+
489
+ print(f"Training on {len(all_aligned_sentences)} aligned sentences")
490
+ print(f"English vocab: {len(eng_vocab)}, Vietnamese vocab: {len(vie_vocab)}")
491
+
492
+ ibm_model = IBMModel1(all_aligned_sentences, iterations)
493
+
494
+ # Extract alignments with memory optimization
495
+ self.word_alignments = self._extract_alignments_memory_efficient(all_aligned_sentences, ibm_model)
496
+
497
+ # Clean up
498
+ del ibm_model
499
+ gc.collect()
500
+
501
+ return all_aligned_sentences
502
+
503
+ def visualize_phrase_table(self, max_phrases=10, output_dir="/kaggle/working/visualizations"):
504
+ if "KAGGLE_KERNEL_RUN_TYPE" in os.environ:
505
+ # Running on Kaggle
506
+ output_dir = "/kaggle/working/visualizations"
507
+ else:
508
+ # Running locally
509
+ output_dir = VISUALIZATION_PATH
510
+
511
+ os.makedirs(output_dir, exist_ok=True)
512
+ # Visualize the phrase table as a heatmap with English phrases as columns and Vietnamese phrases as rows
513
+ if not self.MODE_VISUALIZATION:
514
+ return
515
+
516
+ if not self.phrase_table:
517
+ print("Cannot visualize phrase table: Phrase table is empty.")
518
+ return
519
+
520
+ # Select top English phrases and their top Vietnamese translations
521
+ eng_phrases = sorted(self.phrase_table.keys(), key=lambda x: sum(self.phrase_table[x].values()), reverse=True)[:max_phrases]
522
+ vie_phrases = set()
523
+ for eng in eng_phrases:
524
+ vie_phrases.update(self.phrase_table[eng].keys())
525
+ vie_phrases = sorted(list(vie_phrases))[:max_phrases] # Limit Vietnamese phrases
526
+
527
+ # Create matrix for probabilities
528
+ matrix = np.zeros((len(vie_phrases), len(eng_phrases)))
529
+ for i, vie in enumerate(vie_phrases):
530
+ for j, eng in enumerate(eng_phrases):
531
+ matrix[i, j] = self.phrase_table.get(eng, {}).get(vie, 0)
532
+
533
+ # Create heatmap
534
+ if not os.path.exists(output_dir):
535
+ os.makedirs(output_dir)
536
+
537
+ plt.figure(figsize=(12, 8))
538
+ plt.imshow(matrix, cmap='Blues', interpolation='nearest')
539
+ plt.title('Phrase Table Translation Probabilities')
540
+ plt.xlabel('English Phrases')
541
+ plt.ylabel('Vietnamese Phrases')
542
+ plt.xticks(range(len(eng_phrases)), eng_phrases, rotation=45, ha='right')
543
+ plt.yticks(range(len(vie_phrases)), vie_phrases)
544
+ plt.colorbar(label='Translation Probability')
545
+ plt.tight_layout()
546
+ plt.savefig(os.path.join(output_dir, 'phrase_table.png'))
547
+ plt.close()
548
+
549
+ # Print sample phrase pairs
550
+ print("\nSample Phrase Table Entries (Top 5 English phrases):")
551
+ for eng in eng_phrases[:5]:
552
+ print(f" English: {eng}")
553
+ for vie, prob in sorted(self.phrase_table[eng].items(), key=lambda x: x[1], reverse=True)[:3]:
554
+ print(f" -> Vietnamese: {vie}, Probability: {prob:.4f}")
555
+
556
+ print(f"Phrase table visualization saved to {output_dir}/phrase_table.png")
557
+
558
+ ############################################# 3. Decoder Algorithm #############################################
559
+
560
+ class Decoder:
561
+ """Memory-optimized decoder"""
562
+ def __init__(self, phrase_table, language_model, beam_size=BEAM_SIZE):
563
+ self.phrase_table = phrase_table
564
+ self.lm = language_model
565
+ self.beam_size = beam_size
566
+ def translate(self, sentence):
567
+ """Translate sentence with memory optimization"""
568
+ tokens = sentence.lower().split()
569
+ if not tokens:
570
+ return ""
571
+ return self._greedy_translate(tokens)
572
+
573
+ def _greedy_translate(self, tokens):
574
+ """Greedy translation to save memory"""
575
+ translation = []
576
+ i = 0
577
+
578
+ while i < len(tokens):
579
+ best_phrase_len = 1
580
+ best_translation = tokens[i] # fallback
581
+
582
+ # Try phrases of different lengths
583
+ for phrase_len in range(min(3, len(tokens) - i), 0, -1): # Max length 3
584
+ eng_phrase = ' '.join(tokens[i:i+phrase_len])
585
+
586
+ if eng_phrase in self.phrase_table:
587
+ # Get best translation
588
+ vie_translations = self.phrase_table[eng_phrase]
589
+ if vie_translations:
590
+ best_vie_phrase = max(vie_translations.items(), key=lambda x: x[1])
591
+ best_translation = best_vie_phrase[0]
592
+ best_phrase_len = phrase_len
593
+ break
594
+
595
+ translation.append(best_translation)
596
+ i += best_phrase_len
597
+
598
+ return ' '.join(translation)
599
+
600
+ class Hypothesis:
601
+ """Lightweight hypothesis class"""
602
+ def __init__(self, translation, coverage, score, last_phrase_end):
603
+ self.translation = translation
604
+ self.coverage = coverage
605
+ self.score = score
606
+ self.last_phrase_end = last_phrase_end
607
+
608
+ ################################################# 4. Combine all SMT System #############################################
609
+ class SMT:
610
+ """Memory-optimized SMT system"""
611
+ def __init__(self):
612
+ self.lm = LanguageModel(order=LM_ORDER)
613
+ self.tm = TranslationModel(max_phrase_length=MAX_PHRASE_LENGTH)
614
+ self.decoder = None
615
+
616
+ def post_process(self, text):
617
+ """Replaces underscores with spaces in the translated text."""
618
+ return text.replace("_", " ")
619
+
620
+ def train(self):
621
+ bilingual_path = "/kaggle/input/general-data/bilingual_cleaned_dataset.csv"
622
+ vie_path = "/kaggle/input/general-data/vie_cleaned_dataset.csv"
623
+
624
+ if not os.path.exists(bilingual_path):
625
+ bilingual_path = os.path.join("datatest", BILINGUAL_DATA_PATH)
626
+ vie_path = os.path.join("datatest", VIE_DATA_PATH)
627
+ print("=== Training Translation Model ===")
628
+ aligned_sentences = self.tm.train_ibm_model_incremental(bilingual_path)
629
+ phrase_table = self.tm.extract_phrases_memory_efficient(aligned_sentences)
630
+
631
+ del aligned_sentences
632
+ gc.collect()
633
+
634
+ # Train language model
635
+ print("\n=== Training Language Model ===")
636
+ vie_df = pd.read_csv(vie_path)
637
+ vietnamese_sentences = vie_df['vi'].tolist()
638
+ del vie_df # Free memory
639
+ gc.collect()
640
+
641
+ lm_stats = self.lm.train(vietnamese_sentences, max_sentences=50000) # Limit LM training data
642
+ del vietnamese_sentences # Free memory
643
+ gc.collect()
644
+
645
+ # Initialize decoder
646
+ self.decoder = Decoder(phrase_table, self.lm)
647
+
648
+ # Save model immediately
649
+ self.save_model()
650
+
651
+ return {
652
+ "phrase_pairs": len(phrase_table),
653
+ "lm_stats": lm_stats
654
+ }
655
+
656
+ def translate_sentence(self, sentence):
657
+ """Translate a single sentence"""
658
+ if self.decoder is None:
659
+ raise ValueError("Model not trained or loaded.")
660
+ translated_text_with_underscores = self.decoder.translate(sentence)
661
+ return self.post_process(translated_text_with_underscores)
662
+
663
+ def save_model(self):
664
+ """Save the trained model"""
665
+ if "KAGGLE_KERNEL_RUN_TYPE" in os.environ:
666
+ # Running on Kaggle
667
+ model_dir = "/kaggle/working/checkpoints"
668
+ else:
669
+ # Running locally
670
+ model_dir = "checkpoints"
671
+
672
+ os.makedirs(model_dir, exist_ok=True)
673
+
674
+ # Save with compression
675
+ with open(os.path.join(model_dir, "phrase_table.pkl"), 'wb') as f:
676
+ pickle.dump(self.tm.phrase_table, f, protocol=pickle.HIGHEST_PROTOCOL)
677
+ with open(os.path.join(model_dir, "lm_object.pkl"), 'wb') as f:
678
+ pickle.dump(self.lm, f, protocol=pickle.HIGHEST_PROTOCOL)
679
+
680
+ print(f"Model saved to {model_dir}")
681
+
682
+ def load_model(self, model_dir='checkpoints'):
683
+ """Load a pre-trained model"""
684
+ with open(os.path.join(model_dir, "phrase_table.pkl"), 'rb') as f:
685
+ phrase_table = pickle.load(f)
686
+ with open(os.path.join(model_dir, "lm_object.pkl"), 'rb') as f:
687
+ self.lm = pickle.load(f)
688
+
689
+ self.decoder = Decoder(phrase_table, self.lm, BEAM_SIZE)
690
+ self.tm.phrase_table = phrase_table
691
+
692
+ print(f"Model loaded from {model_dir}")
693
+
694
+ def evaluate(self, test_file='/kaggle/input/general-data/test_cleaned_dataset.csv', sample_size=5):
695
+ """Evaluate model on test set"""
696
+ try:
697
+ df = pd.read_csv(test_file)
698
+ except FileNotFoundError:
699
+ test_file = 'datatest/test_cleaned_dataset.csv'
700
+ df = pd.read_csv(test_file)
701
+ sample_size = min(sample_size, len(df))
702
+ sample_indices = random.sample(range(len(df)), sample_size)
703
+
704
+ results = []
705
+ for idx in sample_indices:
706
+ try:
707
+ source = df.iloc[idx]['en']
708
+ reference = df.iloc[idx]['vi']
709
+ translation = self.translate_sentence(source)
710
+
711
+ results.append({
712
+ "source": source,
713
+ "reference": reference,
714
+ "translation": translation
715
+ })
716
+ except Exception as e:
717
+ print(f"Error translating sentence {idx}: {e}")
718
+ results.append({
719
+ "source": df.iloc[idx]['en'],
720
+ "reference": df.iloc[idx]['vi'],
721
+ "translation": "Translation failed"
722
+ })
723
+
724
+ return results
725
+
726
+ def save_predictions_batch(self, test_file="/kaggle/input/general-data/test_cleaned_dataset.csv", output_file="/kaggle/working/predicted.csv", batch_size=1000):
727
+ """Save predictions in batches to avoid memory issues"""
728
+ # Check if test_file exists, if not update to default path
729
+ if not os.path.exists(test_file):
730
+ test_file = "datatest/test_cleaned_dataset.csv"
731
+ output_file = "datatest/predicted1.csv"
732
+ print(f"Output file will be saved to: {output_file}")
733
+
734
+ df_info = pd.read_csv(test_file, nrows=0) # Just get column info
735
+ total_rows = len(pd.read_csv(test_file))
736
+
737
+ print(f"Processing {total_rows} sentences in batches of {batch_size}")
738
+
739
+ # Process in batches and write incrementally
740
+ first_batch = True
741
+ for start_idx in tqdm(range(0, total_rows, batch_size), desc="Processing batches"):
742
+ end_idx = min(start_idx + batch_size, total_rows)
743
+
744
+ # Read batch
745
+ batch_df = pd.read_csv(test_file, skiprows=range(1, start_idx+1), nrows=batch_size)
746
+
747
+ # Process batch
748
+ batch_predictions = []
749
+ for _, row in batch_df.iterrows():
750
+ try:
751
+ source = row['en']
752
+ reference = row['vi']
753
+ translation = self.translate_sentence(source)
754
+
755
+ batch_predictions.append({
756
+ "en": source,
757
+ "vi": reference,
758
+ "pre": translation
759
+ })
760
+ except Exception as e:
761
+ batch_predictions.append({
762
+ "en": row['en'],
763
+ "vi": row['vi'],
764
+ "pre": "Translation failed"
765
+ })
766
+
767
+ # Save batch
768
+ batch_pred_df = pd.DataFrame(batch_predictions)
769
+
770
+ if first_batch:
771
+ batch_pred_df.to_csv(output_file, index=False)
772
+ first_batch = False
773
+ else:
774
+ batch_pred_df.to_csv(output_file, mode='a', header=False, index=False)
775
+
776
+ # Clean up
777
+ del batch_df, batch_predictions, batch_pred_df
778
+ gc.collect()
779
+
780
+ print(f"Predictions saved to {output_file}")
781
+ return output_file
782
+
783
+ def main():
784
+ print("Starting Memory-Optimized SMT System...")
785
+ smt = SMT()
786
+ model_dir = "checkpoints"
787
+ if os.path.exists(model_dir) and os.path.isfile(os.path.join(model_dir, "phrase_table.pkl")):
788
+ print("Loading existing model...")
789
+ smt.load_model()
790
+ else:
791
+ print("Training new model...")
792
+ stats = smt.train()
793
+ print(f"Training complete: {stats}")
794
+
795
+ # Evaluate model
796
+ print("\nEvaluating model...")
797
+ results = smt.evaluate(sample_size=1)
798
+ print("\nExample translations:")
799
+ for i, result in enumerate(results):
800
+ print(f"\nExample {i+1}:")
801
+ print(f"English: {result['source']}")
802
+ print(f"Reference: {result['reference']}")
803
+ print(f"Translation: {result['translation']}")
804
+
805
+ # Save predictions in batches
806
+ print("\nSaving predictions in batches...")
807
+ output_file = smt.save_predictions_batch(batch_size=500) # Smaller batch size
808
+ print(f"All predictions saved to: {output_file}")
809
+
810
+ # Final memory cleanup
811
+ gc.collect()
812
+ print("Processing complete!")
813
+
814
+ class SMTExtended(SMT):
815
+ def infer(self, sentence):
816
+ """Translate a single arbitrary English sentence into Vietnamese using beam search"""
817
+ if self.decoder is None:
818
+ raise ValueError("Model not trained or loaded.")
819
+
820
+ # Preprocess input sentence
821
+ tokens = self.tm.preprocess(sentence, 'eng')
822
+ if not tokens:
823
+ return ""
824
+
825
+ # Initialize beam: (score, translation_tokens, last_pos, covered_positions)
826
+ beam = [(0.0, [], 0, set())] # Score, translation tokens, last position, covered positions
827
+ best_score = float('-inf')
828
+ best_translation = []
829
+
830
+ # Beam search
831
+ while beam:
832
+ new_beam = []
833
+ for score, trans_tokens, last_pos, covered in beam:
834
+ # Check if all positions are covered
835
+ if len(covered) == len(tokens):
836
+ if score > best_score:
837
+ best_score = score
838
+ best_translation = trans_tokens
839
+ continue
840
+
841
+ # Find next uncovered position
842
+ next_pos = last_pos
843
+ while next_pos in covered and next_pos < len(tokens):
844
+ next_pos += 1
845
+
846
+ if next_pos >= len(tokens):
847
+ if score > best_score:
848
+ best_score = score
849
+ best_translation = trans_tokens
850
+ continue
851
+
852
+ # Try phrases starting at next_pos
853
+ for phrase_len in range(1, min(self.tm.max_phrase_length + 1, len(tokens) - next_pos + 1)):
854
+ eng_phrase = ' '.join(tokens[next_pos:next_pos + phrase_len])
855
+
856
+ # Get possible translations from phrase table
857
+ vie_translations = self.tm.phrase_table.get(eng_phrase, {})
858
+ if not vie_translations and phrase_len == 1:
859
+ # Fallback for single unknown word
860
+ vie_translations = {tokens[next_pos]: 1.0}
861
+
862
+ for vie_phrase, trans_prob in vie_translations.items():
863
+ # Split Vietnamese phrase into tokens for LM scoring
864
+ vie_tokens = vie_phrase.split()
865
+ # Calculate new score: combine translation prob and LM prob
866
+ log_trans_prob = math.log(trans_prob) if trans_prob > 0 else math.log(1e-10)
867
+ lm_score = self.lm.get_probability(trans_tokens + vie_tokens)
868
+ new_score = ALPHA * log_trans_prob + BETA * lm_score
869
+
870
+ # Update covered positions
871
+ new_covered = covered | set(range(next_pos, next_pos + phrase_len))
872
+ # Add to new beam
873
+ new_beam.append((score + new_score, trans_tokens + vie_tokens, next_pos + phrase_len, new_covered))
874
+
875
+ # Keep top BEAM_SIZE hypotheses
876
+ new_beam.sort(key=lambda x: x[0], reverse=True)
877
+ beam = new_beam[:self.decoder.beam_size]
878
+
879
+ # Return best translation
880
+ return ' '.join(best_translation) if best_translation else "Translation failed"
881
+
882
+ if __name__ == "__main__":
883
+ main()
884
+
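A hypothetical usage sketch for the SMT system above, assuming a previously trained checkpoint already exists in ./checkpoints (phrase_table.pkl and lm_object.pkl, as written by save_model). translate_sentence uses the greedy decoder, while SMTExtended.infer runs the beam search that scores each hypothesis extension as ALPHA * log(translation probability) + BETA * (language-model log probability) and keeps the top BEAM_SIZE candidates.

# Hypothetical usage, assuming ./checkpoints holds a trained model.
smt = SMTExtended()
smt.load_model("checkpoints")

sentence = "I read books"
print(smt.translate_sentence(sentence))        # greedy phrase-table decoding
print(smt.post_process(smt.infer(sentence)))   # beam-search decoding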