import sys
import time
from os import path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pywhispercpp.model import Model


class AddressExtractor:
    def __init__(self):
        model_id = "microsoft/bitnet-b1.58-2B-4T"

        # Load tokenizer and model on CPU (BitNet b1.58 is a ternary-weight
        # model intended for efficient CPU inference).
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.bitnet_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
        )

        # Set pad_token_id to eos_token_id so generation does not warn about
        # a missing pad token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.bitnet_model.config.pad_token_id = self.tokenizer.pad_token_id

        self.whisper_model = Model('small.en-q5_1', n_threads=16, language='en')
        # self.whisper_model = Model('small.en', n_threads=16, language='en')
        # self.whisper_model = Model('tiny.en', n_threads=16, language='en')

        self.system_prompt_speech = (
            "Your task is to extract the US address from the ASR-transcribed text "
            "(produced by a Whisper model) without generating any additional "
            "description. Extract only the address-related entities and produce "
            "the final address from them."
        )
        self.system_prompt_text = (
            "Your task is to extract the US address from the input text without "
            "generating any additional description. Extract only the "
            "address-related entities and produce the final address from them."
        )
        # self.sample_files_path = "./one_sentence_us_address/"

    def compute_latency(self, start_time, end_time):
        duration = end_time - start_time
        hours = duration // 3600
        minutes = (duration - hours * 3600) // 60
        seconds = duration - (hours * 3600 + minutes * 60)
        return (f'inference elapsed time was {hours:.0f} hours, '
                f'{minutes:4.1f} minutes, {seconds:4.2f} seconds')

    def _generate_address(self, messages):
        # Build the chat prompt, run generation, and decode only the newly
        # generated tokens (everything after the prompt).
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        chat_input = self.tokenizer(prompt, return_tensors="pt").to(self.bitnet_model.device)
        chat_outputs = self.bitnet_model.generate(**chat_input, max_new_tokens=256)
        generated_text = self.tokenizer.decode(
            chat_outputs[0][chat_input['input_ids'].shape[-1]:],
            skip_special_tokens=True,
        )
        if generated_text.strip() != "":
            print("\n\n", "=" * 100)
            print("Address Extracted: ", generated_text)
            print("=" * 100, "\n\n")

    def infer_text_sample(self, input_text):
        if input_text.strip() == "":
            return
        messages = [
            {"role": "system", "content": self.system_prompt_text},
            {"role": "user", "content": input_text},
        ]
        self._generate_address(messages)

    def preprocess_text(self, input_text):
        # Clean up comma artifacts in the ASR output: turn the thousands
        # separators inside numbers into spaces, and add the missing space
        # after every other comma.
        input_tokens = []
        for word in input_text.split(" "):
            word = word.strip()
            if word == "":
                continue
            if "," in word:
                try:
                    int(word.replace(",", ""))  # numeric, e.g. "1,234"
                    word = word.replace(",", " ")
                except ValueError:
                    word = word.replace(",", ", ")
            input_tokens.append(word)
        return " ".join(input_tokens)
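    # A minimal illustration of preprocess_text (hypothetical input; the first
    # token exercises the numeric branch, the last the comma-spacing branch):
    #
    #   extractor.preprocess_text("1,234 Main Street,Springfield")
    #   # -> "1 234 Main Street, Springfield"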
    def infer_audio_sample(self, audio_input_file_path):
        # Transcribe the audio with whisper.cpp, normalize the text, then hand
        # it to the BitNet model for address extraction. Segments are joined
        # with spaces so words at segment boundaries are not glued together.
        segments = self.whisper_model.transcribe(audio_input_file_path)
        input_text = " ".join(segment.text.strip() for segment in segments)
        input_text = self.preprocess_text(input_text)

        print("\n\n", "=" * 100)
        print("Transcribed Text: ", input_text)
        print("=" * 100, "\n")

        if input_text.strip() == "":
            return
        messages = [
            {"role": "system", "content": self.system_prompt_speech},
            {"role": "user", "content": input_text},
        ]
        self._generate_address(messages)


def main():
    address_extract = AddressExtractor()
    while True:
        input_data = input("Paste audio path or text (type `exit` to quit): ").strip()
        if input_data == "exit":
            sys.exit(0)
        if input_data.lower().endswith(".wav"):
            if not path.exists(input_data):
                print(f"Error: The audio file '{input_data}' does not exist.")
            else:
                address_extract.infer_audio_sample(input_data)
        elif input_data != "":
            address_extract.infer_text_sample(input_data)
        else:
            print("Error: Please provide a valid input.")


if __name__ == "__main__":
    main()
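# compute_latency is defined on AddressExtractor but never called in this
# script. A minimal sketch of how it could be wired in (the input string is
# illustrative only):
#
#   extractor = AddressExtractor()
#   start_time = time.time()
#   extractor.infer_text_sample("I live at 1600 Pennsylvania Avenue NW, Washington, DC 20500")
#   end_time = time.time()
#   print(extractor.compute_latency(start_time, end_time))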