# Residue from the Hugging Face Spaces file viewer (not part of the program):
# Spaces: Sleeping
# File size: 5,693 Bytes
# edf14f4 (presumably the revision hash shown by the viewer)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time, sys
from os import path, listdir
from pywhispercpp.model import Model
class AddressExtractor:
    """Extract US addresses from plain text or from speech recordings.

    Audio is transcribed with a whisper.cpp model (via ``pywhispercpp``); the
    transcript, or text supplied directly, is then sent to a BitNet causal
    language model together with a system prompt that asks for the address
    only.
    """

    def __init__(
        self,
        model_id="microsoft/bitnet-b1.58-2B-4T",
        whisper_model_name="small.en-q5_1",
        n_threads=16,
    ):
        """Load the tokenizer, the language model and the ASR model.

        Args:
            model_id: Hugging Face model id of the causal language model.
            whisper_model_name: Name of the whisper.cpp model to load
                (e.g. ``'small.en-q5_1'``, ``'small.en'`` or ``'tiny.en'``).
            n_threads: Number of threads passed to the whisper.cpp model.
        """
        # Load tokenizer and model (bfloat16 weights, placed on the CPU).
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.bitnet_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
        )

        # Set pad_token_id to eos_token_id.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.bitnet_model.config.pad_token_id = self.tokenizer.pad_token_id

        self.whisper_model = Model(
            whisper_model_name, n_threads=n_threads, language='en'
        )

        # NOTE(review): this prompt names whisper-large-v3-turbo, while the
        # default ASR model loaded above is small.en-q5_1 -- confirm which
        # one is intended before changing the prompt text.
        self.system_prompt_speech = (
            "\n"
            "Your task is to extract the US address given the ASR inferred text (using whisper-large-v3-turbo model) without generating any additional text description. Only extract the address related entities and generate the final address from the extracted content.\n"
        )
        self.system_prompt_text = (
            "\n"
            "Your task is to extract the US address given the input text without generating any additional text description. Only extract the address related entities and generate the final address from the extracted content.\n"
        )

    def compute_latency(self, start_time, end_time):
        """Format the elapsed time between two timestamps given in seconds.

        Args:
            start_time: Start timestamp in seconds.
            end_time: End timestamp in seconds.

        Returns:
            A message reporting the duration as hours, minutes and seconds.
        """
        hours, remainder = divmod(end_time - start_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        msg = f'inference elapsed time was {str(hours)} hours, {minutes:4.1f} minutes, {seconds:4.2f} seconds'
        return msg

    def _extract_address(self, system_prompt, input_text, max_new_tokens=256):
        """Run the language model on ``input_text`` and print the extracted address.

        Args:
            system_prompt: System message describing the extraction task.
            input_text: User text the address should be extracted from.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded model response, or ``""`` when ``input_text`` is blank
            (the model is not called in that case).
        """
        if not input_text.strip():
            return ""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input_text},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        chat_input = self.tokenizer(prompt, return_tensors="pt").to(
            self.bitnet_model.device
        )

        # Generate response.
        chat_outputs = self.bitnet_model.generate(
            **chat_input, max_new_tokens=max_new_tokens
        )

        # Decode only the response part: skip the tokens that were the prompt.
        prompt_length = chat_input['input_ids'].shape[-1]
        generated_text = self.tokenizer.decode(
            chat_outputs[0][prompt_length:], skip_special_tokens=True
        )

        if generated_text.strip():
            print("\n\n", "=" * 100)
            print("Address Extracted: ", generated_text)
            print("=" * 100, "\n\n")
        return generated_text

    def infer_text_sample(self, input_text):
        """Extract and print the address contained in ``input_text``.

        Args:
            input_text: Free text that may contain a US address.

        Returns:
            The text generated by the model, or ``""`` for blank input.
        """
        return self._extract_address(self.system_prompt_text, input_text)

    def preprocess_text(self, input_text):
        """Normalise comma usage and whitespace in ASR-generated text.

        A comma between digit groups (``"12,345"``) is replaced by a space;
        every other comma is followed by exactly one space.  Runs of
        whitespace are collapsed to a single space.

        Args:
            input_text: Raw transcript text.

        Returns:
            The cleaned-up text.
        """
        input_tokens = []
        for word in input_text.split():
            if "," in word:
                # int(word) can never succeed on a word that still contains a
                # comma, so the digit groups are checked individually instead.
                groups = word.split(",")
                if all(group.isdecimal() for group in groups):
                    word = word.replace(",", " ")
                else:
                    word = word.replace(",", ", ")
            input_tokens.append(word)
        # Re-split so a trailing ", " does not leave a double space behind.
        return " ".join(" ".join(input_tokens).split())

    def infer_audio_sample(self, audio_input_file_path):
        """Transcribe an audio file, then extract and print the address in it.

        Args:
            audio_input_file_path: Path of the audio file to transcribe.

        Returns:
            The text generated by the model, or ``""`` when the transcript is
            blank.
        """
        segments = self.whisper_model.transcribe(audio_input_file_path)
        # Join with a space so words at segment boundaries are not fused.
        input_text = " ".join(segment.text.strip() for segment in segments)
        input_text = self.preprocess_text(input_text)

        print("\n\n", "=" * 100)
        print("Transcribe Text: ", input_text)
        print("=" * 100, "\n")

        return self._extract_address(self.system_prompt_speech, input_text)
def main():
    """Run an interactive prompt that extracts addresses from text or audio.

    Each entered line is handled as follows:

    * ``exit`` terminates the program;
    * a value ending in ``.wav`` (case-insensitive) is treated as the path of
      an audio file to transcribe;
    * any other non-empty value is treated as text containing an address.

    End-of-input (Ctrl-D) or an interrupt (Ctrl-C) at the prompt also
    terminates the program.
    """
    address_extract = AddressExtractor()
    while True:
        try:
            input_data = input("Paste audio path or Text (type `exit` to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Leave quietly instead of dumping a traceback on the user.
            print()
            sys.exit(0)

        if input_data == "exit":
            sys.exit(0)

        if not input_data:
            print("Error: Please provide the valid input")
            continue

        if input_data.lower().endswith(".wav"):
            audio_path = input_data
            if not path.exists(audio_path):
                print(f"Error: The audio file '{audio_path}' does not exist.")
            else:
                address_extract.infer_audio_sample(audio_path)
        else:
            address_extract.infer_text_sample(input_data)
# Start the interactive prompt only when the file is executed as a script.
if __name__ == "__main__":
    main()