import os
import time

import gradio as gr
import fitz  # PyMuPDF
from docx import Document
from dotenv import load_dotenv
from pypdf import PdfReader
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_groq import ChatGroq
from langchain_mistralai import MistralAIEmbeddings

# Configuration for batch processing and rate limiting
BATCH_SIZE = 20       # Chunks per embedding batch (reduced from 50)
BATCH_DELAY = 1.0     # Delay between batches to respect API rate limits (increased from 0.2)
MAX_RETRIES = 3       # Retries per failed batch for API stability
CACHE_ENABLED = True  # Enable caching for embeddings (currently unused)

# Load environment variables from .env file
load_dotenv()
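
# NOTE: MistralAIEmbeddings and ChatGroq read their API keys from the
# environment, so MISTRAL_API_KEY and GROQ_API_KEY must be available
# (via .env or the host environment) before the clients below are built.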


def process_chunks_in_batches_with_progress(chunks, metadatas, file_name, progress_callback=None):
    """Process document chunks in batches with retries and rate-limit handling."""
    total_chunks = len(chunks)
    processed = 0
    print(f"Processing {total_chunks} chunks from {file_name} in batches of {BATCH_SIZE}")
    if progress_callback:
        progress_callback(f"Starting batch processing: {total_chunks} chunks total")

    for i in range(0, total_chunks, BATCH_SIZE):
        batch_end = min(i + BATCH_SIZE, total_chunks)
        batch_chunks = chunks[i:batch_end]
        batch_metadatas = metadatas[i:batch_end]

        # Retry logic with exponential backoff
        success = False
        for attempt in range(MAX_RETRIES):
            try:
                batch_start_time = time.time()
                vectorstore.add_texts(batch_chunks, metadatas=batch_metadatas)
                batch_time = time.time() - batch_start_time
                processed += len(batch_chunks)
                progress_msg = (
                    f"✓ Batch {i // BATCH_SIZE + 1}/{(total_chunks - 1) // BATCH_SIZE + 1}: "
                    f"{len(batch_chunks)} chunks ({processed}/{total_chunks}) - {batch_time:.2f}s"
                )
                print(progress_msg)
                if progress_callback:
                    progress_callback(progress_msg)
                success = True
                break
            except Exception as e:
                wait_time = min(2 ** attempt, 10)  # Exponential backoff, capped at 10s
                if attempt < MAX_RETRIES - 1:
                    if "rate limit" in str(e).lower() or "quota" in str(e).lower():
                        print(f"Rate limit hit, waiting {wait_time}s...")
                    else:
                        print(f"Attempt {attempt + 1} failed, retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    print(f"✗ Failed batch {i // BATCH_SIZE + 1} after {MAX_RETRIES} attempts: {str(e)[:100]}...")

        # Respectful delay between successful batches to avoid rate limits
        if success and batch_end < total_chunks:
            time.sleep(BATCH_DELAY)

    print(f"✓ Processing complete: {processed}/{total_chunks} chunks processed")
    return processed
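
# Illustrative call (hypothetical values), assuming `chunks` and `metadatas`
# were produced as in process_document() below:
#   process_chunks_in_batches_with_progress(chunks, metadatas, "paper.pdf",
#                                           progress_callback=print)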

# Initialize MistralAI embeddings
try:
    embeddings = MistralAIEmbeddings(
        model="mistral-embed",  # Pin the model for consistent vectors
        # TODO: tune timeout and retry settings here if needed
    )
    print("✓ MistralAI embeddings initialized successfully")
except Exception as e:
    print(f"✗ Error initializing MistralAI embeddings: {e}")
    print("Retrying with default settings...")
    embeddings = MistralAIEmbeddings()

# Initialize the persistent vector store
vectorstore = Chroma(
    collection_name="research_papers",
    embedding_function=embeddings,
    persist_directory="./chroma_db",
    collection_metadata={"hnsw:space": "cosine"},  # Cosine distance (Chroma defaults to L2)
)

# Text splitter tuned for fast embedding of small chunks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=400,    # Small chunks embed faster (reduced from 600)
    chunk_overlap=50,  # Minimal overlap between adjacent chunks
    length_function=len,
    separators=["\n\n", "\n", ".", " ", ""],  # Split on paragraphs, lines, sentences, words
)
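
# Illustrative behavior: split_text() returns plain strings, each at most
# ~400 characters with up to ~50 characters shared between neighbors, e.g.
#   text_splitter.split_text(long_text) -> ["Abstract. We propose ...", ...]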


def parse_pdf(file_path):
    """Parse a PDF file and extract text, preferring PyMuPDF."""
    text = ""
    try:
        # Try PyMuPDF first (generally better extraction)
        doc = fitz.open(file_path)
        for page in doc:
            text += page.get_text() + "\n"
        doc.close()
    except Exception as e:
        print(f"PyMuPDF failed: {e}, falling back to pypdf")
        # Fall back to pypdf, discarding any partial PyMuPDF output
        text = ""
        reader = PdfReader(file_path)
        for page in reader.pages:
            text += (page.extract_text() or "") + "\n"
    return text


def parse_docx(file_path):
    """Parse a DOCX file and extract text."""
    doc = Document(file_path)
    text = ""
    for paragraph in doc.paragraphs:
        text += paragraph.text + "\n"
    return text
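
# Note: Document.paragraphs only covers body paragraphs; text inside DOCX
# tables is not extracted here. Iterate doc.tables as well if that matters.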


def process_document(file):
    """Process an uploaded document and add it to the vector store."""
    if file is None:
        return "No file uploaded."

    # Get the file path from Gradio's file input
    file_path = file.name if hasattr(file, 'name') else str(file)

    try:
        # Determine file type and parse
        if file_path.lower().endswith('.pdf'):
            text = parse_pdf(file_path)
        elif file_path.lower().endswith('.docx'):
            text = parse_docx(file_path)
        else:
            return "Unsupported file format. Please upload PDF or DOCX files."

        # Debug: print text length and preview
        print(f"Extracted text length: {len(text)} characters")
        print(f"Text preview: {text[:500]}...")

        # Split text into chunks
        chunks = text_splitter.split_text(text)

        # Debug: print chunk information
        print(f"Number of chunks: {len(chunks)}")
        avg_len = sum(len(chunk) for chunk in chunks) / len(chunks) if chunks else 0
        print(f"Average chunk length: {avg_len:.1f} characters")

        # Prepare per-chunk metadata
        metadatas = [{"source": os.path.basename(file_path)} for _ in chunks]

        # Add to vector store in batches
        try:
            processed_chunks = process_chunks_in_batches_with_progress(
                chunks, metadatas, os.path.basename(file_path)
            )
            print(f"Successfully processed {processed_chunks}/{len(chunks)} chunks")
        except Exception as e:
            print(f"Error in batch processing: {e}")
            return f"Error storing document chunks: {str(e)}"

        return (
            f"Successfully processed {os.path.basename(file_path)} with {len(chunks)} chunks "
            f"({processed_chunks} stored). Total text length: {len(text)} characters."
        )
    except Exception as e:
        return f"Error processing document: {str(e)}"


def query_documents(question):
    """Query the documents using RAG, with performance monitoring."""
    start_time = time.time()

    # Check whether any documents have been indexed yet
    # (_collection is Chroma's underlying collection handle)
    try:
        if vectorstore._collection.count() == 0:
            return "No documents uploaded yet. Please upload some research papers first."
    except Exception:
        return "No documents uploaded yet. Please upload some research papers first."

    # Initialize the LLM
    try:
        llm = ChatGroq(model="llama-3.1-8b-instant")
        print("✓ Llama 3.1 8B Instant LLM initialized successfully")
    except Exception as e:
        print(f"✗ Error initializing LLM: {e}")
        return f"Error initializing language model: {str(e)}"

    # Build the RAG chain with a small k for fast retrieval
    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3},  # Top 3 chunks keeps retrieval fast
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,  # Skip returning sources for a faster response
    )
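
    # "stuff" concatenates the retrieved chunks into a single prompt; with
    # k=3 chunks of ~400 characters each, this fits comfortably in the
    # model's context window.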

    # Get the answer, with error handling and timing
    try:
        chain_start = time.time()
        result = qa_chain.invoke({"query": question})
        total_time = time.time() - start_time
        chain_time = time.time() - chain_start  # Retrieval + generation
        print(f"✓ Question answered in {total_time:.2f}s (chain: {chain_time:.2f}s)")
        return result["result"] if isinstance(result, dict) else str(result)
    except Exception as e:
        print(f"✗ Error generating answer: {e}")
        if "rate limit" in str(e).lower() or "quota" in str(e).lower():
            return "API rate limit reached. Please wait a moment and try again."
        if "fetch_k" in str(e).lower():
            return "Vector database configuration error. Please restart the application."
        return f"Error generating answer: {str(e)}"


# Create the Gradio interface
with gr.Blocks(title="Research Paper RAG System") as iface:
    gr.Markdown("# 🚀 Optimized Research Paper RAG System")
    gr.Markdown("Upload PDF or DOCX research papers and ask questions about them.")
    gr.Markdown(f"**⚡ Performance Optimized**: {BATCH_SIZE}-chunk batches, {BATCH_DELAY}s delays, optimized retrieval.")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload Research Paper (PDF/DOCX)", file_types=[".pdf", ".docx"])
            upload_button = gr.Button("Process Document")
            upload_output = gr.Textbox(label="Upload Status", interactive=False)
            progress_bar = gr.Textbox(label="Processing Progress", interactive=False, visible=False)
        with gr.Column():
            question_input = gr.Textbox(label="Ask a question about the uploaded papers", lines=3)
            query_button = gr.Button("Ask Question")
            answer_output = gr.Textbox(label="Answer", lines=10, interactive=False)

    upload_button.click(process_document, inputs=file_input, outputs=upload_output)
    query_button.click(query_documents, inputs=question_input, outputs=answer_output)


if __name__ == "__main__":
    iface.launch()
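    # Note: launch(share=True) would additionally create a temporary public
    # Gradio link when running locally; the default launch() is fine on Spaces.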