# BankIntent-api / app.py
# Last commit by TymaaHammouda: "Update app.py" (244b2f5, verified)
from fastapi import FastAPI
app = FastAPI()  # FastAPI application; the POST /predict route is registered on it below.
import json
from torch.utils.data.dataloader import DataLoader
import pandas as pd
import torch
from pydantic import BaseModel
from fastapi.responses import JSONResponse
import os
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
default_data_collator,
set_seed
)
# Location of the fine-tuned classifier and its label maps.
# NOTE(review): "intenet" looks like a typo, but the path presumably has to
# match the directory shipped with the app, so it is left unchanged.
DATA_DIR = r"./bank_intenet_model"
model_name_or_path = os.path.join(DATA_DIR, "model")

# One JSON id -> label map per supported language (read by load_id_to_label).
ar_id_to_label_file, en_id_to_label_file = (
    os.path.join(DATA_DIR, "id_to_label_ar.json"),
    os.path.join(DATA_DIR, "id_to_label_en.json"),
)
id_to_label_files = {
    'ar': ar_id_to_label_file,
    'en': en_id_to_label_file,
}

# Tokenizer / inference settings.
seed = 42  # NOTE(review): not referenced elsewhere in this file.
max_length = 128  # token length used for truncation and padding
per_device_eval_batch_size = 16  # NOTE(review): not referenced elsewhere in this file.
use_slow_tokenizer = True  # load the non-fast tokenizer (use_fast=False)
pad_to_max_length = True  # pad every input to max_length instead of not padding
# Label maps already read from disk, keyed by language code. The files do not
# change while the app runs, so each one is parsed at most once.
_id_to_label_cache = {}


def load_id_to_label(lang):
    """
    Return the mapping from integer class id to string label for a language.

    The JSON file for ``lang`` (whose keys are stringified integers) is read on
    first use and cached; later calls reuse the parsed result.

    Args:
    - lang (str): the selected language; must be a key of ``id_to_label_files``.

    Returns:
    - id_to_label (dict): a new dictionary with integer keys and string values
      for the selected language.

    Raises:
    - KeyError: if ``lang`` has no configured label file.
    """
    if lang not in _id_to_label_cache:
        try:
            json_file_path = id_to_label_files[lang]
        except KeyError:
            raise KeyError(
                f"Unsupported language {lang!r}; expected one of {sorted(id_to_label_files)}"
            ) from None
        with open(json_file_path, "r", encoding="utf-8") as f:
            content_dict = json.load(f)
        # JSON object keys are always strings; the model predicts integer ids.
        _id_to_label_cache[lang] = {int(key): value for key, value in content_dict.items()}
    # Hand out a copy so a caller mutating the result cannot corrupt the cache.
    return dict(_id_to_label_cache[lang])
def pridct(my_list, lang):
    """
    Classify one or more texts with the loaded sequence-classification model.

    Args:
    - my_list (list of str | str): the texts to classify. A single string is
      treated as a one-element list.
    - lang (str): the language of the returned labels; must be a key of
      ``id_to_label_files``.

    Returns:
    - dict: ``DataFrame.to_dict()`` output with the columns "text",
      "predicted_ids", "predicted_label" and "prob_value", each mapping
      row index -> value. "prob_value" is the highest softmax probability of
      the row, rounded to 3 decimals.

    Raises:
    - KeyError: if ``lang`` has no label file (raised by load_id_to_label).
    """
    texts = [my_list] if isinstance(my_list, str) else list(my_list)
    if not texts:
        # Skip the model entirely for an empty batch.
        return pd.DataFrame(
            {"text": [], "predicted_ids": [], "predicted_label": [], "prob_value": []}
        ).to_dict()

    # Resolve the label map first so an unsupported language fails before
    # spending time on inference.
    id_to_label = load_id_to_label(lang)

    padding = "max_length" if pad_to_max_length else False
    # return_tensors="pt" already yields torch tensors; no re-wrapping needed.
    inputs = tokenizer(
        texts,
        padding=padding,
        truncation=True,
        add_special_tokens=True,
        max_length=max_length,
        return_tensors="pt",
    )

    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_ids = logits.argmax(dim=-1).tolist()
    top_probs = torch.nn.functional.softmax(logits, dim=-1).max(dim=-1).values.tolist()

    df = pd.DataFrame({
        "text": texts,
        "predicted_ids": predicted_ids,
        "predicted_label": [id_to_label[pred] for pred in predicted_ids],
        "prob_value": [round(prob, 3) for prob in top_probs],
    })
    return df.to_dict()
def remove_empty_values(sentences):
    """Return a new list holding every item of ``sentences`` except exact empty strings ('')."""
    return list(filter(lambda item: item != '', sentences))
def sent_tokenize(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
    """
    Split ``text`` into sentences on simple separators.

    Separators are applied in the order: newline, ". ", "?" / Arabic "؟", "!".
    Each separator stays attached to the end of the sentence it terminates.

    Args:
    - text (str): the text to split.
    - dot (bool): split on ". ".
    - new_line (bool): split on "\\n".
    - question_mark (bool): split on "?" and "؟".
    - exclamation_mark (bool): split on "!".

    Returns:
    - list of str: whitespace-stripped sentences. Segments that are empty or
      consist only of separator characters (e.g. a stray "!") are dropped.
    """
    separators = []
    if new_line:
        separators.append('\n')
    if dot:
        separators.append('. ')
    if question_mark:
        separators.extend(['?', '؟'])
    if exclamation_mark:
        separators.append('!')

    split_text = [text]
    for sep in separators:
        new_split_text = []
        for part in split_text:
            tokens = part.split(sep)
            # Re-attach the separator to every piece it terminated.
            new_split_text.extend(token + sep for token in tokens[:-1])
            new_split_text.append(tokens[-1].strip())
        split_text = new_split_text

    # Characters that on their own do not make a sentence.
    separator_chars = set(''.join(separators))
    sentences = []
    for segment in split_text:
        segment = segment.strip()
        if segment and not all(ch in separator_chars for ch in segment):
            sentences.append(segment)
    return sentences
# Load the classifier once at import time; pridct uses these module-level
# objects for every request.
config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=not use_slow_tokenizer,
)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    # from_tf is True only when the model path contains ".ckpt".
    from_tf=".ckpt" in model_name_or_path,
    config=config,
)
class BankRequest(BaseModel):
    """Request body for POST /predict."""

    # Label language; looked up in id_to_label_files ('ar' or 'en').
    lang: str
    # Raw input text; /predict classifies only its first sentence (per sent_tokenize).
    text: str
@app.post("/predict")
def predict(request: BankRequest):
    """
    Classify the intent of the first sentence of ``request.text``.

    The text is split with ``sent_tokenize`` and only the first resulting
    sentence is sent to the model.

    Responses:
    - 200: {"resp": [<pridct output>], "statusText": "OK", "statusCode": 0}
    - 422: ``lang`` is not supported, or ``text`` contains no sentence.
    """
    from fastapi import HTTPException

    lang = request.lang
    # Reject unknown languages up front instead of failing with a 500 after
    # running the model.
    if lang not in id_to_label_files:
        raise HTTPException(
            status_code=422,
            detail=f"Unsupported lang {lang!r}; expected one of {sorted(id_to_label_files)}",
        )

    sentences = sent_tokenize(request.text, dot=True, new_line=True, question_mark=True, exclamation_mark=True)
    # Empty or whitespace-only text yields no sentences; indexing [0] would
    # otherwise raise IndexError (HTTP 500).
    if not sentences:
        raise HTTPException(status_code=422, detail="text contains no sentence to classify")

    # NOTE(review): only the first sentence is classified; any further
    # sentences are ignored. Confirm this is intended.
    results = [pridct([sentences[0]], lang)]
    content = {"resp": results, "statusText": "OK", "statusCode": 0}
    return JSONResponse(
        content=content,
        media_type="application/json",
        status_code=200,
    )