TymaaHammouda committed on
Commit
244b2f5
·
verified ·
1 Parent(s): e3cbe7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -3
app.py CHANGED
@@ -2,6 +2,167 @@ from fastapi import FastAPI
2
 
3
  app = FastAPI()
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
5
+ import json
6
+ from torch.utils.data.dataloader import DataLoader
7
+ import pandas as pd
8
+ import torch
9
+ from pydantic import BaseModel
10
+ from fastapi.responses import JSONResponse
11
+ import os
12
+
13
+ from transformers import (
14
+ AutoConfig,
15
+ AutoModelForSequenceClassification,
16
+ AutoTokenizer,
17
+ default_data_collator,
18
+ set_seed
19
+ )
20
+
21
+
22
+
23
# Directory that the model path and the label-map paths below are built from.
DATA_DIR = r"./bank_intenet_model"
model_name_or_path = os.path.join(DATA_DIR, "model")

# One id -> label JSON file per language code; `id_to_label_files` is the
# lookup table used to pick a file from a language key.
ar_id_to_label_file = os.path.join(DATA_DIR, "id_to_label_ar.json")
en_id_to_label_file = os.path.join(DATA_DIR, "id_to_label_en.json")
id_to_label_files = dict(
    ar=ar_id_to_label_file,
    en=en_id_to_label_file,
)

# Tokenization / inference settings read by the code further down the module.
seed = 42
max_length = 128
per_device_eval_batch_size = 16
use_slow_tokenizer = True
pad_to_max_length = True
36
+
37
+
38
def load_id_to_label(lang):
    """
    Read the label map registered for ``lang`` and return it keyed by integer id.

    Args:
        - lang (str): key into ``id_to_label_files`` selecting which JSON file to read
    Returns:
        - id_to_label (dict): the entries of that JSON object, with every key
          converted to ``int`` and every value left as loaded.
    Raises:
        - KeyError: if ``lang`` has no entry in ``id_to_label_files``.
    """
    with open(id_to_label_files[lang], mode="r", encoding="utf-8") as handle:
        raw_mapping = json.load(handle)

    # JSON object keys are always strings, so convert them back to ints.
    id_to_label = {}
    for raw_id, label in raw_mapping.items():
        id_to_label[int(raw_id)] = label
    return id_to_label
55
+
56
+
57
+
58
def pridct(my_list, lang):
    """
    Classify texts with the module-level ``model`` and ``tokenizer``.

    Args:
        - my_list (list | str): the texts to classify; a single string is
          treated as a one-element list
        - lang (str): key into ``id_to_label_files`` selecting the language of
          the returned labels

    Returns:
        - output_dict (dict): ``DataFrame.to_dict()`` of a frame with the columns
          ``text``, ``predicted_ids``, ``predicted_label`` and ``prob_value``
          (the highest softmax probability, rounded to 3 decimals), i.e.
          ``{column: {row_index: value}}``.

    Raises:
        - KeyError: if ``lang`` is not a key of ``id_to_label_files``.
    """
    # Resolve the label map first so an unsupported language fails before any
    # model work is done.
    id_to_label = load_id_to_label(lang)

    texts = [my_list] if isinstance(my_list, str) else list(my_list)
    if not texts:
        # Nothing to classify: return the same column layout with no rows.
        empty = pd.DataFrame(
            {"text": [], "predicted_ids": [], "predicted_label": [], "prob_value": []}
        )
        return empty.to_dict()

    padding = "max_length" if pad_to_max_length else False
    # `return_tensors='pt'` already yields torch tensors, so the encoding can be
    # passed to the model directly without re-wrapping each value.
    encoded = tokenizer(
        texts,
        padding=padding,
        truncation=True,
        add_special_tokens=True,
        max_length=max_length,
        return_tensors='pt',
    )

    model.eval()
    with torch.no_grad():
        logits = model(**encoded).logits

    predicted_ids = logits.argmax(dim=-1).cpu().tolist()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    top_probs = [
        round(prob, 3)
        for prob in probabilities.max(dim=-1).values.cpu().tolist()
    ]
    predicted_labels = [id_to_label[pred] for pred in predicted_ids]

    df = pd.DataFrame(
        {
            "text": texts,
            "predicted_ids": predicted_ids,
            "predicted_label": predicted_labels,
            "prob_value": top_probs,
        }
    )
    return df.to_dict()
105
+
106
def remove_empty_values(sentences):
    """Return a new list holding every item of ``sentences`` that is not ``''``."""
    return list(filter(lambda sentence: sentence != '', sentences))
108
+
109
+
110
def sent_tokenize(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
    """
    Split ``text`` into sentences on the enabled separators.

    Separators are applied in this order: newline, ``'. '`` (dot followed by a
    space), ``'?'`` and the Arabic question mark ``'؟'``, then ``'!'``. Each
    separator stays attached to the end of the sentence it terminates.

    Args:
        - text (str): the text to split
        - dot (bool): split on ``'. '``
        - new_line (bool): split on ``'\\n'``
        - question_mark (bool): split on ``'?'`` and ``'؟'``
        - exclamation_mark (bool): split on ``'!'``

    Returns:
        - sentences (list): the sentences in order, each stripped of leading and
          trailing whitespace; empty and whitespace-only segments are dropped.
    """
    separators = []
    if new_line:
        separators.append('\n')
    if dot:
        separators.append('. ')
    if question_mark:
        separators.extend(('?', '؟'))
    if exclamation_mark:
        separators.append('!')

    segments = [text]
    for sep in separators:
        next_segments = []
        for segment in segments:
            *heads, tail = segment.split(sep)
            # Re-attach the separator to every piece it terminated.
            next_segments.extend(head + sep for head in heads)
            next_segments.append(tail)
        segments = next_segments

    # Strip every segment once at the end so the result does not depend on
    # which separator happened to be applied last, then drop blank segments.
    stripped = (segment.strip() for segment in segments)
    return [segment for segment in stripped if segment]
134
+
135
+
136
# Build the config, tokenizer and sequence-classification model once, at
# import time, from `model_name_or_path`.
config = AutoConfig.from_pretrained(model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=not use_slow_tokenizer,
)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name_or_path,
    # Only request TensorFlow-checkpoint loading when the path mentions ".ckpt".
    from_tf=".ckpt" in model_name_or_path,
    config=config,
)
141
+
142
+
143
class BankRequest(BaseModel):
    """Request body taken by the ``/predict`` endpoint."""

    # Language key used to choose the label file for the returned labels.
    lang: str
    # Raw input text; the endpoint splits it into sentences before predicting.
    text: str
146
+
147
@app.post("/predict")
def predict(request: BankRequest):
    """
    Classify the first sentence of ``request.text``.

    The text is split with ``sent_tokenize`` and only the first non-blank
    sentence is passed to ``pridct``.

    Args:
        - request (BankRequest): carries ``lang`` (label language key) and ``text``

    Returns:
        - JSONResponse: on success, HTTP 200 with
          ``{"resp": [<pridct output>], "statusText": "OK", "statusCode": 0}``.
          For an unsupported ``lang`` or a text with no sentence in it, HTTP 400
          with an empty ``resp`` and a non-zero ``statusCode``.
    """
    lang = request.lang
    text = request.text

    def _bad_request(message):
        # A non-zero statusCode marks a failure; the success path uses 0.
        return JSONResponse(
            content={"resp": [], "statusText": message, "statusCode": -1},
            media_type="application/json",
            status_code=400,
        )

    # Reject unknown languages up front instead of failing with a KeyError
    # (HTTP 500) after the model has already run.
    if lang not in id_to_label_files:
        supported = ", ".join(sorted(id_to_label_files))
        return _bad_request(f"Unsupported lang '{lang}'. Supported values: {supported}")

    sentences = sent_tokenize(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True)

    # Empty or whitespace-only text produces no sentence; indexing [0] would
    # raise IndexError, so report a client error instead.
    sentence = next((s for s in sentences if s.strip()), None)
    if sentence is None:
        return _bad_request("No text to classify")

    results = [pridct(sentence, lang)]
    content = {"resp": results, "statusText": "OK", "statusCode": 0}

    return JSONResponse(
        content=content,
        media_type="application/json",
        status_code=200,
    )
168
+