import gc
import os
import tempfile
from typing import Tuple

import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from flashpack import FlashPackMixin
from huggingface_hub import Repository
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel

device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only)")
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
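

# build_encoder loads a frozen GPT-2 encoder and returns a batched helper that
# embeds prompts as the concatenation of mean- and max-pooled hidden states
# (2 × 768 = 1536 dimensions for gpt2).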
def build_encoder(model_name="gpt2", max_length=128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode_batch(prompts: list, batch_size=16) -> torch.Tensor:
        embeddings = []
        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]
            inputs = tokenizer(batch, return_tensors="pt", truncation=True,
                               padding="max_length", max_length=max_length).to(device)
            last_hidden = embed_model(**inputs).last_hidden_state
            mean_pool = last_hidden.mean(dim=1)
            max_pool, _ = last_hidden.max(dim=1)
            batch_emb = torch.cat([mean_pool, max_pool], dim=1)
            embeddings.append(batch_emb.cpu())
        return torch.vstack(embeddings)

    return tokenizer, embed_model, encode_batch
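

# push_flashpack_model_to_hf serializes the trained mapper with FlashPack and
# pushes it, along with a minimal README, to the given Hugging Face repo.
# It assumes an authenticated environment (e.g. a stored token from
# `huggingface-cli login` or HF_TOKEN).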
def push_flashpack_model_to_hf(model, hf_repo: str):
    logs = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        logs.append(f"📂 Using temporary directory: {tmp_dir}")
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)

        pack_path = os.path.join(tmp_dir, "model.flashpack")
        model.save_flashpack(pack_path, target_dtype=torch.float32)

        readme_path = os.path.join(tmp_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write("# FlashPack Model\nThis repo contains a FlashPack model trained for short→long prompt mapping.")

        repo.push_to_hub()
        logs.append(f"✅ Model pushed to Hugging Face repo: {hf_repo}")
    return logs
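

# train_flashpack_model: load the prompt-enhancer dataset, embed short/long
# prompt pairs with the frozen GPT-2 encoder, and train the mapper with a
# cosine-distance loss until the held-out loss reaches the target (or the
# epoch budget runs out). Returns everything the inference path needs.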
def train_flashpack_model(
    dataset_name="rahul7star/prompt-enhancer-dataset",
    max_encode=1000,
    hidden_dim=1024,
    hf_repo="rahul7star/FlashPack",
    push_to_hub=True,
    test_split=0.1,
    batch_size=32,
    max_epochs=50,
    target_test_loss=0.01,
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
    print("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    limit = min(max_encode, len(dataset))
    dataset = dataset.select(range(limit))
    print(f"⚡ Using {len(dataset)} prompts for training")

    short_prompts = [item["short_prompt"] for item in dataset]
    long_prompts = [item["long_prompt"] for item in dataset]

    train_short, test_short, train_long, test_long = train_test_split(
        short_prompts, long_prompts, test_size=test_split, random_state=42
    )
    print(f"🔹 Train size: {len(train_short)}, Test size: {len(test_short)}")

    tokenizer, embed_model, encode_batch = build_encoder("gpt2", max_length=128)

    print("⚡ Encoding training prompts...")
    train_short_emb = encode_batch(train_short)
    train_long_emb = encode_batch(train_long)
    print(f"✅ Train embeddings shape: {train_short_emb.shape}, {train_long_emb.shape}")

    print("⚡ Encoding test prompts...")
    test_short_emb = encode_batch(test_short)
    test_long_emb = encode_batch(test_long)
    print(f"✅ Test embeddings shape: {test_short_emb.shape}, {test_long_emb.shape}")

    input_dim = train_short_emb.shape[1]
    output_dim = train_long_emb.shape[1]

    model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)

    # Cosine-distance loss: minimize 1 - cos(model(short_emb), long_emb).
    criterion = nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    n_train = train_short_emb.shape[0]

    print("🚀 Training model...")
    for epoch in range(max_epochs):
        model.train()
        epoch_loss = 0.0
        perm = torch.randperm(n_train)
        for start in range(0, n_train, batch_size):
            idx = perm[start:start + batch_size]
            inputs = train_short_emb[idx].to(device)
            targets = train_long_emb[idx].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = 1 - criterion(outputs, targets).mean()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)

        epoch_loss /= n_train

        # Evaluate on the held-out split after every epoch.
        model.eval()
        with torch.no_grad():
            test_outputs = model(test_short_emb.to(device))
            test_loss = (1 - criterion(test_outputs, test_long_emb.to(device)).mean()).item()

        print(f"Epoch {epoch + 1}/{max_epochs} → Train loss: {epoch_loss:.6f}, Test loss: {test_loss:.6f}")

        if test_loss <= target_test_loss:
            print(f"✅ Target test loss reached ({test_loss:.6f}) – stopping training early.")
            break

    if push_to_hub and test_loss <= target_test_loss:
        for log in push_flashpack_model_to_hf(model, hf_repo):
            print(log)
    elif push_to_hub:
        print(f"⚠️ Test loss too high ({test_loss:.6f}); skipping HF upload.")

    # Encode every long prompt in dataset order so that retrieval indices in
    # enhance_prompt() line up with `dataset` rows (train_long_emb alone is
    # shuffled by the train/test split and would not align).
    long_embeddings = encode_batch(long_prompts)

    # Return order matches the module-level unpacking below.
    return model, tokenizer, embed_model, dataset, long_embeddings
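

# get_flashpack_model tries to load a previously pushed FlashPack checkpoint
# from the Hub; if that fails (missing repo, no network, missing file) it falls
# back to training a fresh mapper locally. Either way it returns
# (model, tokenizer, embed_model, dataset, long_embeddings).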
def get_flashpack_model(hf_repo="rahul7star/FlashPack",
                        dataset_name="rahul7star/prompt-enhancer-dataset",
                        max_encode=1000):
    try:
        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
        model = GemmaTrainer.from_flashpack(hf_repo)
        model.eval()
        tokenizer, embed_model, encode_batch = build_encoder("gpt2", max_length=128)

        # The FlashPack artifact only stores the mapper weights, so rebuild the
        # retrieval index: reload the dataset and re-encode its long prompts.
        dataset = load_dataset(dataset_name, split="train")
        dataset = dataset.select(range(min(max_encode, len(dataset))))
        long_embeddings = encode_batch([item["long_prompt"] for item in dataset])
        return model, tokenizer, embed_model, dataset, long_embeddings
    except Exception as e:
        print(f"⚠️ Load failed: {e}")
        print("⏬ Training a new FlashPack model locally...")
        return train_flashpack_model(dataset_name=dataset_name, max_encode=max_encode, hf_repo=hf_repo)
# Load (or train) the mapper once at startup, together with the retrieval index.
model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
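

# Inference helpers: encode_for_inference embeds a single prompt exactly as at
# training time; enhance_prompt maps it through the model and retrieves the
# nearest stored long prompt by cosine similarity.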
@torch.no_grad()
def encode_for_inference(prompt: str) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=128).to(device)
    last_hidden = embed_model(**inputs).last_hidden_state
    mean_pool = last_hidden.mean(dim=1)
    max_pool, _ = last_hidden.max(dim=1)
    return torch.cat([mean_pool, max_pool], dim=1).cpu()


def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
    # temperature and max_tokens are exposed in the UI but unused here:
    # enhancement is retrieval-based rather than generative.
    chat_history = chat_history or []
    short_emb = encode_for_inference(user_prompt)
    mapped = model(short_emb.to(device)).cpu()

    # Cosine similarity between the mapped embedding and every stored long-prompt embedding.
    sims = (long_embeddings @ mapped.t()).squeeze(1)
    long_norms = long_embeddings.norm(dim=1)
    mapped_norm = mapped.norm()
    sims = sims / (long_norms * (mapped_norm + 1e-12))

    best_idx = int(sims.argmax().item())
    enhanced_prompt = dataset[best_idx]["long_prompt"]

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history
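

# Gradio UI: a chat-style interface where each user prompt is answered with its
# retrieved enhanced version.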
with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (FlashPack mapper)
        Enter a short prompt, and the model will **expand it with details and creative context**.
        (CPU-only mode.)
        """
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)


if __name__ == "__main__":
    demo.launch(show_error=True)