import gc
import os
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository
from typing import Tuple
from sklearn.model_selection import train_test_split
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only)")
# ============================================================
# 1️⃣ Model
# ============================================================
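# GemmaTrainer is a small 3-layer MLP that maps a short-prompt embedding to a
# long-prompt embedding. Inheriting FlashPackMixin adds the save_flashpack()/
# from_flashpack() serialization used further below.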
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
# ============================================================
# 2️⃣ Encoder with batch mean+max pooling
# ============================================================
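# Each prompt is encoded by concatenating mean-pooled and max-pooled GPT-2
# hidden states, giving one fixed-size vector of 2 * hidden_size per prompt
# (2 * 768 = 1536 for the default "gpt2" checkpoint).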
def build_encoder(model_name="gpt2", max_length=128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode_batch(prompts: list, batch_size=16) -> torch.Tensor:
        embeddings = []
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]
            inputs = tokenizer(batch, return_tensors="pt", truncation=True,
                               padding="max_length", max_length=max_length).to(device)
            last_hidden = embed_model(**inputs).last_hidden_state
            mean_pool = last_hidden.mean(dim=1)
            max_pool, _ = last_hidden.max(dim=1)
            batch_emb = torch.cat([mean_pool, max_pool], dim=1)
            embeddings.append(batch_emb.cpu())
        return torch.vstack(embeddings)

    return tokenizer, embed_model, encode_batch
# ============================================================
# 3️⃣ Push model to HF
# ============================================================
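# Note: newer versions of huggingface_hub deprecate Repository in favour of
# upload_folder/upload_file. Either way, pushing requires a token with write
# access to hf_repo (e.g. HF_TOKEN or `huggingface-cli login`).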
def push_flashpack_model_to_hf(model, hf_repo: str):
    logs = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        logs.append(f"📂 Using temporary directory: {tmp_dir}")
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        pack_path = os.path.join(tmp_dir, "model.flashpack")
        model.save_flashpack(pack_path, target_dtype=torch.float32)
        readme_path = os.path.join(tmp_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write("# FlashPack Model\nThis repo contains a FlashPack model trained for short→long prompt mapping.")
        repo.push_to_hub()
        logs.append(f"✅ Model pushed to Hugging Face repo: {hf_repo}")
    return logs
# ============================================================
# 4️⃣ Train with train/test split & detailed logging
# ============================================================
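# Training setup: embeddings are the 1536-dim mean+max pooled GPT-2 states, the
# loss is 1 - cosine_similarity between mapped short-prompt vectors and their
# long-prompt targets, and training stops early once the held-out test loss
# falls below target_test_loss.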
def train_flashpack_model(
    dataset_name="rahul7star/prompt-enhancer-dataset",
    max_encode=1000,
    hidden_dim=1024,
    hf_repo="rahul7star/FlashPack",
    push_to_hub=True,
    test_split=0.1,
    batch_size=32,
    max_epochs=50,
    target_test_loss=0.01
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
    print("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    limit = min(max_encode, len(dataset))
    dataset = dataset.select(range(limit))
    print(f"⚡ Using {len(dataset)} prompts for training")

    short_prompts = [item["short_prompt"] for item in dataset]
    long_prompts = [item["long_prompt"] for item in dataset]

    # Split
    train_short, test_short, train_long, test_long = train_test_split(
        short_prompts, long_prompts, test_size=test_split, random_state=42
    )
    print(f"🔹 Train size: {len(train_short)}, Test size: {len(test_short)}")

    tokenizer, embed_model, encode_batch = build_encoder("gpt2", max_length=128)

    # Encode
    print("⚡ Encoding training prompts...")
    train_short_emb = encode_batch(train_short)
    train_long_emb = encode_batch(train_long)
    print(f"✅ Train embeddings shape: {train_short_emb.shape}, {train_long_emb.shape}")

    print("⚡ Encoding test prompts...")
    test_short_emb = encode_batch(test_short)
    test_long_emb = encode_batch(test_long)
    print(f"✅ Test embeddings shape: {test_short_emb.shape}, {test_long_emb.shape}")

    # Encode every long prompt in dataset order so the retrieval index used at
    # inference (long_embeddings[i] ↔ dataset[i]["long_prompt"]) stays aligned;
    # the shuffled train split alone would not line up with dataset indices.
    print("⚡ Encoding all long prompts for retrieval...")
    all_long_emb = encode_batch(long_prompts)

    input_dim = train_short_emb.shape[1]
    output_dim = train_long_emb.shape[1]
    model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
    criterion = nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    n_train = train_short_emb.shape[0]

    print("🚀 Training model...")
    for epoch in range(max_epochs):
        model.train()
        epoch_loss = 0.0
        perm = torch.randperm(n_train)
        for start in range(0, n_train, batch_size):
            idx = perm[start:start + batch_size]
            inputs = train_short_emb[idx].to(device)
            targets = train_long_emb[idx].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = 1 - criterion(outputs, targets).mean()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)
        epoch_loss /= n_train

        # Evaluate on test
        model.eval()
        with torch.no_grad():
            test_outputs = model(test_short_emb.to(device))
            test_loss = (1 - criterion(test_outputs, test_long_emb.to(device)).mean()).item()
        print(f"Epoch {epoch+1}/{max_epochs} → Train loss: {epoch_loss:.6f}, Test loss: {test_loss:.6f}")

        # Check if model is good enough
        if test_loss <= target_test_loss:
            print(f"✅ Target test loss reached ({test_loss:.6f}) – stopping training early.")
            break

    # Push to HF if trained well
    logs = []
    if push_to_hub and test_loss <= target_test_loss:
        logs = push_flashpack_model_to_hf(model, hf_repo)
        for log in logs:
            print(log)
    elif push_to_hub:
        print(f"⚠️ Test loss too high ({test_loss:.6f}); skipping HF upload.")

    # Return values in the order unpacked at module level:
    # model, tokenizer, embed_model, dataset, long_embeddings
    return model, tokenizer, embed_model, dataset, all_long_emb
# ============================================================
# 5️⃣ Load or train
# ============================================================
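# get_flashpack_model() tries to load a previously published FlashPack model
# from the Hub; if that fails for any reason it falls back to training one
# locally with train_flashpack_model(). Both paths return the same 5-tuple:
# (model, tokenizer, embed_model, dataset, long_embeddings).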
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    try:
        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
        model = GemmaTrainer.from_flashpack(hf_repo)
        model.eval()
        tokenizer, embed_model, encode_batch = build_encoder("gpt2", max_length=128)
        # Rebuild the retrieval index so a loaded model is used exactly like a
        # freshly trained one (same dataset slice and long-prompt embeddings).
        dataset = load_dataset("rahul7star/prompt-enhancer-dataset", split="train")
        dataset = dataset.select(range(min(1000, len(dataset))))
        long_embeddings = encode_batch([item["long_prompt"] for item in dataset])
        return model, tokenizer, embed_model, dataset, long_embeddings
    except Exception as e:
        print(f"⚠️ Load failed: {e}")
        print("⏬ Training a new FlashPack model locally...")
        return train_flashpack_model(hf_repo=hf_repo)
# ============================================================
# 6️⃣ Initialise model, encoder, dataset, and embeddings at startup
# ============================================================
model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
# ============================================================
# 7️⃣ Inference helpers
# ============================================================
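# Inference path: the short prompt is pooled exactly as during training,
# mapped through the trained MLP, and matched against the precomputed
# long-prompt embeddings by cosine similarity.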
@torch.no_grad()
def encode_for_inference(prompt: str) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=128).to(device)
    last_hidden = embed_model(**inputs).last_hidden_state
    mean_pool = last_hidden.mean(dim=1)
    max_pool, _ = last_hidden.max(dim=1)
    return torch.cat([mean_pool, max_pool], dim=1).cpu()
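# Note: the temperature and max_tokens sliders are wired into this handler but
# are not used by the retrieval step; the enhanced prompt is a deterministic
# nearest-neighbour lookup over the dataset's long-prompt embeddings.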
def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
    chat_history = chat_history or []
    short_emb = encode_for_inference(user_prompt)
    mapped = model(short_emb.to(device)).cpu()

    sims = (long_embeddings @ mapped.t()).squeeze(1)
    long_norms = long_embeddings.norm(dim=1)
    mapped_norm = mapped.norm()
    sims = sims / (long_norms * (mapped_norm + 1e-12))
    best_idx = int(sims.argmax().item())
    enhanced_prompt = dataset[best_idx]["long_prompt"]

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history
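# Illustrative usage outside the UI (assumes the module-level model/dataset
# have been initialised above):
#   history = enhance_prompt("a cat in space", 0.7, 128, [])
#   print(history[-1]["content"])  # the retrieved long prompt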
# ============================================================
# 8️⃣ Gradio UI
# ============================================================
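# The Chatbot is created with type="messages", so enhance_prompt returns the
# chat history as a list of {"role", "content"} dicts, which the click/submit
# handlers write back into the chatbot component.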
with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (FlashPack mapper)
        Enter a short prompt, and the model will **expand it with details and creative context**.
        (CPU-only mode.)
        """
    )
    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)
if __name__ == "__main__":
    demo.launch(show_error=True)