# rag_agent.py, copied from the Hugging Face file view (uploaded by SamerPF).
# Commit 33384e5 (verified): "Update rag_agent.py", file size 2.42 kB.
import os
import requests
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# ==== 0. Authenticate with the Hugging Face Hub ====
# Shell syntax such as `export NAME=value` is not valid Python; environment
# variables are read and written through os.environ instead. The access token
# is expected in HF_TOKEN, falling back to HUGGINGFACEHUB_API_TOKEN.
_hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if _hf_token:
    # Expose the token under HUGGINGFACEHUB_API_TOKEN as well, for any code
    # that looks it up by that name.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = _hf_token
    # Passing the token explicitly keeps the login non-interactive.
    login(token=_hf_token)
else:
    # No token configured: keep the plain login() call, which is documented
    # to ask the user for a token.
    login()
# ==== 1. Set up Hugging Face Embedding Model ====
# Build the llama_index HuggingFaceEmbedding wrapper and register it on the
# global Settings object.
_embedding_model = HuggingFaceEmbedding(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
Settings.embed_model = _embedding_model

# ==== 2. Load Hugging Face LLM (Locally Installed or Remote Hosted) ====
# The model and its tokenizer share one repository id, so it is stated once.
_MISTRAL_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"  # You must have access!
_llm_options = dict(
    model_name=_MISTRAL_REPO_ID,
    tokenizer_name=_MISTRAL_REPO_ID,
    context_window=2048,
    max_new_tokens=512,
    generate_kwargs={"temperature": 0.1},
    tokenizer_kwargs={"padding_side": "left"},
    device_map="auto",
)
llm = HuggingFaceLLM(**_llm_options)

# Apply to global settings.
Settings.llm = llm
# ==== 3. Validate & Download ArXiv PDFs (if needed) ====
def download_pdf(arxiv_id, save_dir="kb", *, timeout=30):
    """Download an arXiv paper as a PDF into ``save_dir``.

    Args:
        arxiv_id: arXiv identifier, e.g. ``"2312.03840"``.
        save_dir: Directory the PDF is written to; created if missing.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The path of the saved file on success, or ``None`` when the request
        failed or the response was not a PDF.
    """
    url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
    try:
        # Without a timeout, requests can wait indefinitely on a stalled server.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        print(f"❌ Failed to download PDF for {arxiv_id}: {exc}")
        return None

    content_type = response.headers.get("Content-Type", "")
    if not response.ok or "application/pdf" not in content_type:
        print(f"❌ Failed to download PDF for {arxiv_id}: Not a valid PDF")
        return None

    os.makedirs(save_dir, exist_ok=True)
    # Old-style identifiers contain a slash (e.g. "hep-th/9901001"); flatten it
    # so the file lands directly inside save_dir instead of a missing subfolder.
    file_name = f"{arxiv_id}.pdf".replace("/", "_")
    file_path = os.path.join(save_dir, file_name)
    with open(file_path, "wb") as f:
        f.write(response.content)
    print(f"✅ Downloaded {file_path}")
    return file_path
# Example: download_pdf("2312.03840")
# ==== 4. Load Knowledge Base ====
# The reader is pointed at the local "kb" directory and restricted to the
# ".pdf" extension.
_kb_reader = SimpleDirectoryReader("kb", required_exts=[".pdf"])
documents = _kb_reader.load_data()
index = VectorStoreIndex.from_documents(documents)

# ==== 5. Create Query Engine ====
query_engine = index.as_query_engine()

# ==== 6. Wrap as a Tool ====
# The metadata carries the name and description attached to the tool.
_rag_metadata = ToolMetadata(
    name="RAGSearch",
    description="Answers from a local HF-based RAG system.",
)
rag_tool = QueryEngineTool(query_engine=query_engine, metadata=_rag_metadata)
# ==== 7. Basic Agent ====
class BasicAgent:
    """Minimal agent that answers questions through a query-engine tool.

    Args:
        tool: Object exposing a ``query_engine`` attribute whose
            ``query(question)`` method produces the answer. When omitted,
            the module-level ``rag_tool`` is used.
    """

    def __init__(self, tool=None):
        # The default is resolved at call time, so a caller may inject a
        # different tool without touching the module-level one.
        self.tool = rag_tool if tool is None else tool

    def __call__(self, question: str) -> str:
        """Send ``question`` to the tool's query engine and return the answer as text."""
        print(f"🧠 RAG Agent received: {question}")
        response = self.tool.query_engine.query(question)
        # The engine's response object is converted to plain text for the caller.
        return str(response)