import gradio as gr
import os
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_mistralai import MistralAIEmbeddings
from langchain_groq import ChatGroq
from langchain_chroma import Chroma
from langchain.chains import RetrievalQA
from pypdf import PdfReader
from docx import Document
import time
import fitz  # PyMuPDF
# Configuration for batch processing and rate limiting
BATCH_SIZE = 20       # Reduced for embeddings (was 50)
BATCH_DELAY = 1.0     # Increased delay for API rate limiting (was 0.2)
MAX_RETRIES = 3       # Increased retries for API stability
CACHE_ENABLED = True  # Enable caching for embeddings
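# NOTE: CACHE_ENABLED is currently just a flag; nothing below reads it yet.
# A minimal in-process cache (hypothetical sketch, keyed on chunk text) could
# look like:
#
#   _embedding_cache = {}
#   def cached_embed(text):
#       if CACHE_ENABLED and text in _embedding_cache:
#           return _embedding_cache[text]
#       vec = embeddings.embed_query(text)  # one API call per uncached chunk
#       _embedding_cache[text] = vec
#       return vec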
# Load environment variables from .env file
load_dotenv()
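# The .env file is expected to provide the API keys that these LangChain
# integrations read from the environment by default: MISTRAL_API_KEY for
# MistralAIEmbeddings and GROQ_API_KEY for ChatGroq.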
def process_chunks_in_batches_with_progress(chunks, metadatas, file_name, progress_callback=None):
    """Process document chunks in batches with retries and rate-limit-aware delays."""
    # file_name is accepted for labeling but not currently used in the body.
    total_chunks = len(chunks)
    processed = 0
    print(f"Processing {total_chunks} chunks in batches of {BATCH_SIZE}")
    if progress_callback:
        progress_callback(f"Starting batch processing: {total_chunks} chunks total")
    for i in range(0, total_chunks, BATCH_SIZE):
        batch_end = min(i + BATCH_SIZE, total_chunks)
        batch_chunks = chunks[i:batch_end]
        batch_metadatas = metadatas[i:batch_end]
        # Enhanced retry logic with exponential backoff
        success = False
        for attempt in range(MAX_RETRIES):
            try:
                batch_start_time = time.time()
                vectorstore.add_texts(batch_chunks, metadatas=batch_metadatas)
                batch_time = time.time() - batch_start_time
                processed += len(batch_chunks)
                progress_msg = (
                    f"✓ Batch {i // BATCH_SIZE + 1}/{(total_chunks - 1) // BATCH_SIZE + 1}: "
                    f"{len(batch_chunks)} chunks ({processed}/{total_chunks}) - {batch_time:.2f}s"
                )
                print(progress_msg)
                if progress_callback:
                    progress_callback(progress_msg)
                success = True
                break
            except Exception as e:
                wait_time = min(2 ** attempt, 10)  # Exponential backoff, max 10s
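                # With MAX_RETRIES = 3 the waits are 1s, 2s, 4s; the 10s cap
                # only takes effect if MAX_RETRIES exceeds 4 (2**4 = 16 > 10).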
| if "rate limit" in str(e).lower() or "quota" in str(e).lower(): | |
| print(f"Rate limit hit, waiting {wait_time}s...") | |
| time.sleep(wait_time) | |
| elif attempt < MAX_RETRIES - 1: | |
| print(f"Attempt {attempt + 1} failed, retrying in {wait_time}s...") | |
| time.sleep(wait_time) | |
| else: | |
| print(f"β Failed batch {i//BATCH_SIZE + 1} after {MAX_RETRIES} attempts: {str(e)[:100]}...") | |
| # Respectful delay between successful batches to avoid rate limits | |
| if success and batch_end < total_chunks: | |
| time.sleep(BATCH_DELAY) | |
| print(f"β Processing complete: {processed}/{total_chunks} chunks processed") | |
| return processed | |
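# Example call (hypothetical inputs), e.g. from a REPL:
#   process_chunks_in_batches_with_progress(
#       ["first chunk of text", "second chunk of text"],
#       [{"source": "paper.pdf"}, {"source": "paper.pdf"}],
#       "paper.pdf",
#       progress_callback=print,
#   )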
# Initialize optimized embeddings with performance settings
try:
    embeddings = MistralAIEmbeddings(
        model="mistral-embed",  # Specify model for consistency
        # Timeout and retry settings could be added here
    )
    print("✓ MistralAI embeddings initialized successfully")
except Exception as e:
    print(f"✗ Error initializing MistralAI embeddings: {e}")
    print("Retrying with default settings...")
    embeddings = MistralAIEmbeddings()
# Initialize optimized vector store
vectorstore = Chroma(
    collection_name="research_papers",
    embedding_function=embeddings,
    persist_directory="./chroma_db",
    # Collection metadata for optimization
    collection_metadata={"hnsw:space": "cosine"}  # Cosine distance for similarity search
)
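# persist_directory makes the Chroma index survive restarts, so previously
# uploaded papers stay queryable; "hnsw:space": "cosine" sets the HNSW
# distance metric for the collection (Chroma's default is l2).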
# Initialize optimized text splitter for faster processing
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=400,    # Reduced for faster embedding (was 600)
    chunk_overlap=50,  # Minimal overlap for faster processing
    length_function=len,
    separators=["\n\n", "\n", ".", " ", ""]  # Simplified separators for speed
)
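# chunk_size and chunk_overlap are measured by length_function, i.e. in
# characters here (roughly 100 tokens per 400-character English chunk).
# The separators are tried in order: split on blank lines first, then
# newlines, sentence boundaries, spaces, and finally single characters.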
def parse_pdf(file_path):
    """Parse PDF file and extract text using PyMuPDF, falling back to pypdf."""
    text = ""
    try:
        # Try PyMuPDF first (better extraction)
        doc = fitz.open(file_path)
        for page in doc:
            text += page.get_text() + "\n"
        doc.close()
    except Exception as e:
        print(f"PyMuPDF failed: {e}, falling back to pypdf")
        # Fallback to pypdf
        reader = PdfReader(file_path)
        for page in reader.pages:
            text += page.extract_text() + "\n"
    return text
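# Note: both extractors read the PDF's embedded text layer only; scanned PDFs
# without a text layer come back (near-)empty and would need OCR.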
def parse_docx(file_path):
    """Parse DOCX file and extract text."""
    doc = Document(file_path)
    text = ""
    for paragraph in doc.paragraphs:
        text += paragraph.text + "\n"
    return text
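# Note: doc.paragraphs only covers top-level body paragraphs; text inside
# DOCX tables, headers, and footers is skipped. Iterating doc.tables as well
# would be needed to capture tabular content.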
def process_document(file):
    """Process uploaded document and add to vector store."""
    if file is None:
        return "No file uploaded."
    # Get the file path from Gradio's file input
    file_path = file.name if hasattr(file, 'name') else str(file)
    try:
        # Determine file type and parse
        if file_path.lower().endswith('.pdf'):
            text = parse_pdf(file_path)
        elif file_path.lower().endswith('.docx'):
            text = parse_docx(file_path)
        else:
            return "Unsupported file format. Please upload PDF or DOCX files."
        # Debug: Print text length and preview
        print(f"Extracted text length: {len(text)} characters")
        print(f"Text preview: {text[:500]}...")
        # Split text into chunks
        chunks = text_splitter.split_text(text)
        # Debug: Print chunk information
        print(f"Number of chunks: {len(chunks)}")
        avg_len = sum(len(chunk) for chunk in chunks) / len(chunks) if chunks else 0
        print(f"Average chunk length: {avg_len:.0f} characters")
        # Prepare metadata for chunks
        metadatas = [{"source": os.path.basename(file_path)} for _ in chunks]
        # Add to vector store in batches
        try:
            processed_chunks = process_chunks_in_batches_with_progress(
                chunks, metadatas, os.path.basename(file_path)
            )
            print(f"Successfully processed {processed_chunks}/{len(chunks)} chunks")
        except Exception as e:
            print(f"Error in batch processing: {e}")
            return f"Error storing document chunks: {str(e)}"
        return (
            f"Successfully processed {os.path.basename(file_path)} with {len(chunks)} chunks "
            f"({processed_chunks} stored). Total text length: {len(text)} characters."
        )
    except Exception as e:
        return f"Error processing document: {str(e)}"
def query_documents(question):
    """Query the documents using RAG with performance monitoring."""
    start_time = time.time()
    try:
        # Check if there are documents
        if vectorstore._collection.count() == 0:
            return "No documents uploaded yet. Please upload some research papers first."
    except Exception:
        # Treat an unreadable collection the same as an empty one
        return "No documents uploaded yet. Please upload some research papers first."
    # Initialize LLM
    try:
        llm = ChatGroq(model="llama-3.1-8b-instant")
        print("✓ Llama 3.1 8B Instant LLM initialized successfully")
    except Exception as e:
        print(f"✗ Error initializing LLM: {e}")
        return f"Error initializing language model: {str(e)}"
    # Create optimized RAG chain with faster retrieval
    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={
            "k": 3  # Reduced back to 3 for faster retrieval
        }
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False  # Faster response
    )
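    # chain_type="stuff" simply concatenates the k retrieved chunks into a
    # single prompt for the LLM; with k=3 and ~400-character chunks this stays
    # well inside the model's context window.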
    # Get answer with error handling and timing
    try:
        chain_start = time.time()
        result = qa_chain.invoke({"query": question})
        total_time = time.time() - start_time
        chain_time = time.time() - chain_start  # Covers retrieval plus generation
        print(f"✓ Question answered in {total_time:.2f}s (chain: {chain_time:.2f}s)")
        return result["result"] if isinstance(result, dict) else str(result)
    except Exception as e:
        print(f"✗ Error generating answer: {e}")
        if "rate limit" in str(e).lower() or "quota" in str(e).lower():
            return "API rate limit reached. Please wait a moment and try again."
        if "fetch_k" in str(e).lower():
            return "Vector database configuration error. Please restart the application."
        return f"Error generating answer: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="Research Paper RAG System") as iface:
    gr.Markdown("# 📚 Optimized Research Paper RAG System")
    gr.Markdown("Upload PDF or DOCX research papers and ask questions about them.")
    gr.Markdown(f"**⚡ Performance Optimized**: {BATCH_SIZE}-chunk batches, {BATCH_DELAY}s delays, optimized retrieval.")
    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload Research Paper (PDF/DOCX)", file_types=[".pdf", ".docx"])
            upload_button = gr.Button("Process Document")
            upload_output = gr.Textbox(label="Upload Status", interactive=False)
            progress_bar = gr.Textbox(label="Processing Progress", interactive=False, visible=False)
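            # NOTE: progress_bar is created hidden and never updated; wiring it
            # up would mean having process_document pass a progress_callback
            # through and streaming its messages to this component.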
        with gr.Column():
            question_input = gr.Textbox(label="Ask a question about the uploaded papers", lines=3)
            query_button = gr.Button("Ask Question")
            answer_output = gr.Textbox(label="Answer", lines=10, interactive=False)
    upload_button.click(process_document, inputs=file_input, outputs=upload_output)
    query_button.click(query_documents, inputs=question_input, outputs=answer_output)
if __name__ == "__main__":
    iface.launch()