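"""Interactive US address extractor.

Transcribes a .wav file with whisper.cpp (via pywhispercpp) or accepts raw
text, then prompts the microsoft/bitnet-b1.58-2B-4T chat model to extract a
US postal address from the input. Run the script and paste either an audio
file path or text at the prompt; type `exit` to quit.
"""
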
import sys
from os import path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pywhispercpp.model import Model


class AddressExtractor:
    """Extracts US addresses from free-form text or transcribed audio."""

    def __init__(self):
        model_id = "microsoft/bitnet-b1.58-2B-4T"

        # Load the BitNet tokenizer and model for CPU inference
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.bitnet_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
        )

        # Use the EOS token as the padding token
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.bitnet_model.config.pad_token_id = self.tokenizer.pad_token_id

        # Quantized whisper.cpp model for transcription (alternative sizes below)
        self.whisper_model = Model('small.en-q5_1', n_threads=16, language='en')
        # self.whisper_model = Model('small.en', n_threads=16, language='en')
        # self.whisper_model = Model('tiny.en', n_threads=16, language='en')

        self.system_prompt_speech = """
            Your task is to extract the US address from the ASR-transcribed text without generating any additional text description. Only extract the address-related entities and generate the final address from the extracted content.
        """

        self.system_prompt_text = """
            Your task is to extract the US address from the input text without generating any additional text description. Only extract the address-related entities and generate the final address from the extracted content.
        """

        # self.sample_files_path = "./one_sentence_us_address/"

    def compute_latency(self, start_time, end_time):
        """Format the elapsed time between two timestamps as a readable message."""
        tr_duration = end_time - start_time
        hours = int(tr_duration // 3600)
        minutes = int((tr_duration - hours * 3600) // 60)
        seconds = tr_duration - (hours * 3600 + minutes * 60)
        msg = f'inference elapsed time was {hours} hours, {minutes} minutes, {seconds:4.2f} seconds'

        return msg

    def infer_text_sample(self, input_text):
        """Extract a US address from plain input text with the BitNet model."""
        if input_text.strip() == "":
            return

        messages = [
            {"role": "system", "content": self.system_prompt_text},
            {"role": "user", "content": input_text},
        ]

        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        chat_input = self.tokenizer(prompt, return_tensors="pt").to(self.bitnet_model.device)

        # Generate the response and decode only the newly generated tokens
        chat_outputs = self.bitnet_model.generate(**chat_input, max_new_tokens=256)
        generated_text = self.tokenizer.decode(chat_outputs[0][chat_input['input_ids'].shape[-1]:], skip_special_tokens=True)

        if generated_text.strip() != "":
            print("\n\n", "="*100)
            print("Address Extracted: ", generated_text)
            print("="*100, "\n\n")

    def preprocess_text(self, input_text):
        """Normalize comma usage in the ASR-generated text."""
        input_tokens = []
        for word in input_text.split(" "):
            word = word.strip()
            if word != "":
                if "," in word:
                    try:
                        # Numeric token such as comma-separated digits:
                        # replace the commas with spaces
                        int(word.replace(",", ""))
                        word = word.replace(",", " ")
                    except ValueError:
                        # Regular word: ensure a space follows the comma
                        word = word.replace(",", ", ")
                input_tokens.append(word)
        input_text = " ".join(input_tokens)

        return input_text

    def infer_audio_sample(self, audio_input_file_path):
        """Transcribe a .wav file with whisper.cpp, then extract the US address."""
        segments = self.whisper_model.transcribe(audio_input_file_path)
        # Join segments with spaces so words at segment boundaries do not merge
        input_text = " ".join(segment.text.strip() for segment in segments)

        input_text = self.preprocess_text(input_text)

        print("\n\n", "="*100)
        print("Transcribed Text: ", input_text)
        print("="*100, "\n")

        if input_text.strip() == "":
            return

        messages = [
            {"role": "system", "content": self.system_prompt_speech},
            {"role": "user", "content": input_text},
        ]

        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        chat_input = self.tokenizer(prompt, return_tensors="pt").to(self.bitnet_model.device)

        # Generate the response and decode only the newly generated tokens
        chat_outputs = self.bitnet_model.generate(**chat_input, max_new_tokens=256)
        generated_text = self.tokenizer.decode(chat_outputs[0][chat_input['input_ids'].shape[-1]:], skip_special_tokens=True)

        if generated_text.strip() != "":
            print("\n\n", "="*100)
            print("Address Extracted: ", generated_text)
            print("="*100, "\n\n")


def main():
    """Interactive loop: accept a .wav path or raw text and extract the address."""
    address_extract = AddressExtractor()

    while True:
        input_data = input("Paste an audio path or text (type `exit` to quit): ").strip()

        if input_data == "exit":
            sys.exit(0)

        if input_data.endswith(".wav"):
            # Treat .wav paths as audio input; everything else is plain text
            if not path.exists(input_data):
                print(f"Error: The audio file '{input_data}' does not exist.")
            else:
                address_extract.infer_audio_sample(input_data)

        elif input_data != "":
            address_extract.infer_text_sample(input_data)

        else:
            print("Error: Please provide a valid input.")

if __name__ == "__main__":
    main()