import gradio as gr
import librosa
import soundfile as sf

from address_extractor import AddressExtractor

# Instantiate the AddressExtractor, which exposes the BitNet chat model,
# its tokenizer, and the Whisper transcription model.
address_extractor = AddressExtractor()


def extract_from_text(input_text):
    """Extract a US address from free-form text with the BitNet chat model."""
    if not input_text.strip():
        return "Error: No text provided."

    # Build the chat prompt from the text-specific system prompt and the user input.
    messages = [
        {"role": "system", "content": address_extractor.system_prompt_text},
        {"role": "user", "content": input_text},
    ]
    prompt = address_extractor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    chat_input = address_extractor.tokenizer(prompt, return_tensors="pt").to(
        address_extractor.bitnet_model.device
    )
    chat_outputs = address_extractor.bitnet_model.generate(**chat_input, max_new_tokens=256)

    # Decode only the newly generated tokens, skipping the prompt.
    generated_text = address_extractor.tokenizer.decode(
        chat_outputs[0][chat_input["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )
    return generated_text.strip() or "No address detected."


def extract_from_audio(audio_file):
    """Transcribe a .wav file with Whisper, then extract a US address from the transcript."""
    if audio_file is None:
        return "Error: No audio provided."

    # Resample the uploaded file to 16 kHz in place, the rate Whisper expects.
    audio, _ = librosa.load(audio_file, sr=16000)
    sf.write(audio_file, audio, 16000)

    # Transcribe and normalize the spoken text before prompting the model.
    segments = address_extractor.whisper_model.transcribe(audio_file)
    input_text = " ".join(seg.text.strip() for seg in segments)
    input_text = address_extractor.preprocess_text(input_text)

    # Build the chat prompt from the speech-specific system prompt and the transcript.
    messages = [
        {"role": "system", "content": address_extractor.system_prompt_speech},
        {"role": "user", "content": input_text},
    ]
    prompt = address_extractor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    chat_input = address_extractor.tokenizer(prompt, return_tensors="pt").to(
        address_extractor.bitnet_model.device
    )
    chat_outputs = address_extractor.bitnet_model.generate(**chat_input, max_new_tokens=256)

    # Decode only the newly generated tokens, skipping the prompt.
    generated_text = address_extractor.tokenizer.decode(
        chat_outputs[0][chat_input["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )
    return generated_text.strip() or "No address detected."


# Gradio UI: one tab for text input, one for .wav audio input.
with gr.Blocks() as demo:
    gr.Markdown("## 📦 US Address Extractor")

    with gr.Tab("Text Input"):
        text_input = gr.Textbox(lines=3, label="Enter Text")
        text_output = gr.Textbox(label="Extracted Address")
        text_button = gr.Button("Extract Address")
        text_button.click(fn=extract_from_text, inputs=text_input, outputs=text_output)

    with gr.Tab("Audio Input (.wav)"):
        audio_input = gr.Audio(type="filepath", label="Upload a .wav Audio File")
        audio_output = gr.Textbox(label="Extracted Address")
        audio_button = gr.Button("Extract Address")
        audio_button.click(fn=extract_from_audio, inputs=audio_input, outputs=audio_output)

demo.launch(server_name="0.0.0.0", server_port=7860, share=True)