AHAAM committed
Commit 5e25f9c · Parent: 62b019e

add models dropdown

Files changed (2)
  1. .DS_Store +0 -0
  2. app.py +46 -4
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -428,10 +428,25 @@ from svgpathtools import parse_path
 # ======================
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# # ======================
+# # Base multi-dialect model (B2BERT)
+# # ======================
+# base_model_name = "Mohamedelzeftawy/b2bert_baseline"
+# base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
+# base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+
 # ======================
-# Base multi-dialect model (B2BERT)
+# Multi-dialect model registry
 # ======================
-base_model_name = "Mohamedelzeftawy/b2bert_baseline"
+MODEL_CHOICES = {
+    "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline",  # default (current)
+    "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi",
+    "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty",
+}
+
+# Load default model at startup (LahjatBERT)
+_current_model_key = "LahjatBERT"
+base_model_name = MODEL_CHOICES[_current_model_key]
 base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
 base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 
@@ -488,6 +503,24 @@ SVG_PATH = Path("assets/world-map.svg")
 SVG_NS = "http://www.w3.org/2000/svg"
 ET.register_namespace("", SVG_NS)
 
+def load_multidialect_model(model_key: str):
+    """
+    Load the selected multi-dialect model + tokenizer.
+    Uses global variables so the rest of your pipeline stays unchanged.
+    """
+    global base_model, base_tokenizer, base_model_name, _current_model_key
+
+    if model_key == _current_model_key:
+        return  # already loaded
+
+    repo = MODEL_CHOICES[model_key]
+    base_model_name = repo
+
+    base_model = AutoModelForSequenceClassification.from_pretrained(repo).to(DEVICE)
+    base_tokenizer = AutoTokenizer.from_pretrained(repo)
+
+    _current_model_key = model_key
+
 def _merge_style(old_style: str, updates: dict) -> str:
     """
     Merge CSS style strings (e.g., "fill:#000;stroke:#fff") with updates dict.
@@ -693,13 +726,15 @@ def predict_dialects_with_confidence(text, threshold=0.3):
     return df
 
 
-def predict_wrapper(text, threshold):
+
+def predict_wrapper(model_key, text, threshold):
     """
     Returns:
         df (table),
         summary (markdown),
         map_html (HTML)
     """
+    load_multidialect_model(model_key)
     df = predict_dialects_with_confidence(text, threshold)
 
     predicted_dialects = df[df["Prediction"] == "✓ Valid"]["Dialect"].tolist()
@@ -775,6 +810,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     with gr.Row():
         with gr.Column(scale=1):
+            model_dropdown = gr.Dropdown(
+                choices=list(MODEL_CHOICES.keys()),
+                value="LahjatBERT",
+                label="Model",
+                info="Select which LahjatBERT variant to use for prediction."
+            )
+
             text_input = gr.Textbox(
                 label="Arabic Text Input",
                 placeholder="أدخل نصًا عربيًا هنا... مثال: شلونك؟ / إزيك يا عم؟ / شو أخبارك؟",
@@ -870,7 +912,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     predict_button.click(
         fn=predict_wrapper,
-        inputs=[text_input, threshold_slider],
+        inputs=[model_dropdown, text_input, threshold_slider],
         outputs=[results_output, summary_output, map_output],
     )
 
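
The change follows a small registry-plus-lazy-reload pattern: MODEL_CHOICES maps display names to Hub repo ids, load_multidialect_model swaps the global model and tokenizer only when the selection changes, and the gr.Dropdown value is passed to predict_wrapper as its first input. Below is a minimal, self-contained sketch of that pattern, not part of the commit: the model keys, repo ids, and echo-style predict() are placeholders, and only the Dropdown/Button wiring mirrors app.py.

import gradio as gr

# Placeholder registry; in app.py these are the LahjatBERT variants on the Hub.
MODEL_CHOICES = {"variant-a": "org/model-a", "variant-b": "org/model-b"}
_current_model_key = None

def load_model(model_key: str):
    """Reload only when the selection actually changes (lazy swap)."""
    global _current_model_key
    if model_key == _current_model_key:
        return  # already loaded
    # In app.py this is where AutoModelForSequenceClassification and
    # AutoTokenizer would be re-created from MODEL_CHOICES[model_key].
    _current_model_key = model_key

def predict(model_key, text):
    load_model(model_key)
    return f"[{_current_model_key}] {text}"  # stand-in for the real classifier

with gr.Blocks() as demo:
    model_dropdown = gr.Dropdown(
        choices=list(MODEL_CHOICES.keys()), value="variant-a", label="Model"
    )
    text_input = gr.Textbox(label="Input")
    output = gr.Textbox(label="Output")
    gr.Button("Predict").click(
        fn=predict, inputs=[model_dropdown, text_input], outputs=output
    )

if __name__ == "__main__":
    demo.launch()

Keeping the active model in module-level globals leaves the rest of the pipeline untouched; the trade-off is that changing the selection triggers a full from_pretrained reload, and concurrent users picking different variants can race on the shared globals.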