add models dropdown
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
app.py CHANGED
@@ -428,10 +428,25 @@ from svgpathtools import parse_path
 # ======================
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+# # ======================
+# # Base multi-dialect model (B2BERT)
+# # ======================
+# base_model_name = "Mohamedelzeftawy/b2bert_baseline"
+# base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
+# base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+
 # ======================
-# Base multi-dialect model (B2BERT)
+# Multi-dialect model registry
 # ======================
-base_model_name = "Mohamedelzeftawy/b2bert_baseline"
+MODEL_CHOICES = {
+    "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline",  # default (current)
+    "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi",
+    "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty",
+}
+
+# Load default model at startup (LahjatBERT)
+_current_model_key = "LahjatBERT"
+base_model_name = MODEL_CHOICES[_current_model_key]
 base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name).to(DEVICE)
 base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 
@@ -488,6 +503,24 @@ SVG_PATH = Path("assets/world-map.svg")
 SVG_NS = "http://www.w3.org/2000/svg"
 ET.register_namespace("", SVG_NS)
 
+def load_multidialect_model(model_key: str):
+    """
+    Load the selected multi-dialect model + tokenizer.
+    Uses global variables so the rest of your pipeline stays unchanged.
+    """
+    global base_model, base_tokenizer, base_model_name, _current_model_key
+
+    if model_key == _current_model_key:
+        return  # already loaded
+
+    repo = MODEL_CHOICES[model_key]
+    base_model_name = repo
+
+    base_model = AutoModelForSequenceClassification.from_pretrained(repo).to(DEVICE)
+    base_tokenizer = AutoTokenizer.from_pretrained(repo)
+
+    _current_model_key = model_key
+
 def _merge_style(old_style: str, updates: dict) -> str:
     """
     Merge CSS style strings (e.g., "fill:#000;stroke:#fff") with updates dict.
@@ -693,13 +726,15 @@ def predict_dialects_with_confidence(text, threshold=0.3):
     return df
 
 
-def predict_wrapper(text, threshold):
+
+def predict_wrapper(model_key, text, threshold):
     """
     Returns:
       df (table),
      summary (markdown),
       map_html (HTML)
     """
+    load_multidialect_model(model_key)
     df = predict_dialects_with_confidence(text, threshold)
 
     predicted_dialects = df[df["Prediction"] == "✓ Valid"]["Dialect"].tolist()
@@ -775,6 +810,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     with gr.Row():
         with gr.Column(scale=1):
+            model_dropdown = gr.Dropdown(
+                choices=list(MODEL_CHOICES.keys()),
+                value="LahjatBERT",
+                label="Model",
+                info="Select which LahjatBERT variant to use for prediction."
+            )
+
             text_input = gr.Textbox(
                 label="Arabic Text Input",
                 placeholder="أدخل نصًا عربيًا هنا... مثال: شلونك؟ / إزيك يا عم؟ / شو أخبارك؟",
@@ -870,7 +912,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
     predict_button.click(
         fn=predict_wrapper,
-        inputs=[text_input, threshold_slider],
+        inputs=[model_dropdown, text_input, threshold_slider],
         outputs=[results_output, summary_output, map_output],
     )
 
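For anyone who wants to sanity-check the registry entries outside the Gradio UI, below is a minimal standalone sketch; it is not part of app.py. It assumes the Mohamedelzeftawy/* checkpoints listed in MODEL_CHOICES are public sequence-classification models and that per-dialect scores come from a sigmoid over the logits, mirroring the per-dialect threshold used by predict_dialects_with_confidence; the helper name smoke_test and the sigmoid/threshold choice are illustrative assumptions, not confirmed behavior of these checkpoints.

# Hypothetical smoke test for the MODEL_CHOICES registry (not part of app.py).
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_CHOICES = {
    "LahjatBERT": "Mohamedelzeftawy/b2bert_baseline",
    "LahjatBERT-CL-ALDI": "Mohamedelzeftawy/b2bert_cl_aldi",
    "LahjatBERT-CL-Cardinality": "Mohamedelzeftawy/b2bert_cl_cardinalty",
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def smoke_test(model_key: str, text: str, threshold: float = 0.3):
    """Load one registered checkpoint and return dialects scoring above the threshold."""
    repo = MODEL_CHOICES[model_key]
    tokenizer = AutoTokenizer.from_pretrained(repo)
    model = AutoModelForSequenceClassification.from_pretrained(repo).to(DEVICE).eval()

    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        logits = model(**inputs).logits[0]

    # Sigmoid + cutoff mirrors the app's per-dialect threshold (assumption);
    # switch to softmax/argmax if a checkpoint turns out to be single-label.
    probs = torch.sigmoid(logits)
    labels = model.config.id2label
    return {labels[i]: float(p) for i, p in enumerate(probs) if p >= threshold}

if __name__ == "__main__":
    print(smoke_test("LahjatBERT", "شلونك؟"))

Running this once per key also confirms that every repo id in MODEL_CHOICES resolves before the dropdown ships.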