import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository, list_repo_files, hf_hub_download
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"πŸ”§ Using device: {device} (CPU-only mode)")
# ===========================
# Model Definition
# ===========================
class GemmaTrainer(nn.Module, FlashPackMixin):
def __init__(self):
super().__init__()
input_dim = 1536
hidden_dim = 1024
output_dim = 1536
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
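# Shape sketch (illustrative comment, not executed): the MLP maps a pooled
# GPT-2 embedding to a vector of the same dimensionality.
#   model = GemmaTrainer()
#   x = torch.randn(1, 1536)   # mean+max pooled hidden states (see encoder)
#   y = model(x)               # -> torch.Size([1, 1536])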
# ===========================
# Encoder
# ===========================
def build_encoder(model_name="gpt2", max_length=128):
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
embed_model = AutoModel.from_pretrained(model_name).to(device)
embed_model.eval()
@torch.no_grad()
def encode(prompt: str) -> torch.Tensor:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
padding="max_length", max_length=max_length).to(device)
hidden = embed_model(**inputs).last_hidden_state
mean_pool = hidden.mean(dim=1)
max_pool, _ = hidden.max(dim=1)
return torch.cat([mean_pool, max_pool], dim=1).cpu()
return tokenizer, embed_model, encode
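# Usage sketch: GPT-2's hidden size is 768, so concatenating mean and max
# pooling yields a 1536-dim vector, matching GemmaTrainer's input_dim.
#   tokenizer, embed_model, encode = build_encoder("gpt2")
#   vec = encode("a cat in space")   # -> torch.Size([1, 1536])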
# ===========================
# Push model to HF
# ===========================
def push_flashpack_model_to_hf(model, hf_repo, log_fn):
with tempfile.TemporaryDirectory() as tmp_dir:
log_fn(f"πŸ“¦ Preparing repository {hf_repo}...")
repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
with open(os.path.join(tmp_dir, "README.md"), "w") as f:
f.write("# FlashPack Model\nTrained locally and pushed to HF.")
log_fn("⏳ Pushing model to Hugging Face...")
repo.push_to_hub()
log_fn(f"βœ… Model pushed to {hf_repo}")
# ===========================
# Training
# ===========================
def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
hf_repo="rahul7star/FlashPack",
max_encode=1000):
logs = []
def log_fn(msg):
logs.append(msg)
print(msg)
log_fn("πŸ“¦ Loading dataset...")
dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
log_fn(f"βœ… Loaded {len(dataset)} samples")
tokenizer, embed_model, encode_fn = build_encoder("gpt2")
    # Encode short and long prompts into pooled embeddings
s_list, l_list = [], []
for i, item in enumerate(dataset):
s_list.append(encode_fn(item["short_prompt"]))
l_list.append(encode_fn(item["long_prompt"]))
if (i + 1) % 50 == 0:
log_fn(f" β†’ Encoded {i + 1}/{len(dataset)}")
gc.collect()
short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)
model = GemmaTrainer()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
    cos_sim = nn.CosineSimilarity(dim=1)  # a similarity measure, not a loss by itself
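    # Objective: maximize cosine similarity between predicted and target
    # embeddings, i.e. minimize loss = 1 - mean(cos_sim(model(short), long)).
    # A loss near 0 means predictions point in the same direction as the
    # long-prompt embeddings.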
log_fn("πŸš€ Training model...")
for epoch in range(20):
model.train()
optimizer.zero_grad()
preds = model(short_emb)
        loss = 1 - cos_sim(preds, long_emb).mean()
loss.backward()
optimizer.step()
log_fn(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
if loss.item() < 0.01:
log_fn("🎯 Early stopping.")
break
push_flashpack_model_to_hf(model, hf_repo, log_fn)
    # Reuse the tokenizer/encoder built above instead of reloading GPT-2.
@torch.no_grad()
def enhance_fn(prompt, chat):
chat = chat or []
short_emb = encode_fn(prompt)
        # NOTE: the mapped embedding is computed but not yet decoded to text;
        # the "enhanced" prompt below is a placeholder template.
        mapped = model(short_emb.to(device)).cpu()
        long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
chat.append({"role": "user", "content": prompt})
chat.append({"role": "assistant", "content": long_prompt})
return chat
return model, tokenizer, embed_model, enhance_fn, logs
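# Pipeline sketch (illustrative): train, push to the Hub, and obtain a chat
# callback in one call.
#   model, tok, enc_model, enhance, logs = train_flashpack_model(max_encode=100)
#   chat = enhance("a cat", [])   # appends user + assistant messages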
# ===========================
# Lazy Load / Get Model
# ===========================
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
local_model_path = "model.flashpack"
if os.path.exists(local_model_path):
print("βœ… Loading local model")
else:
try:
files = list_repo_files(hf_repo)
if "model.flashpack" in files:
print("βœ… Downloading model from HF")
local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
else:
print("🚫 No pretrained model found")
return None, None, None, None
except Exception as e:
print(f"⚠️ Error accessing HF: {e}")
return None, None, None, None
    # from_flashpack is a classmethod; call it on the class directly.
    model = GemmaTrainer.from_flashpack(local_model_path)
model.eval()
tokenizer, embed_model, encode_fn = build_encoder("gpt2")
@torch.no_grad()
def enhance_fn(prompt, chat):
chat = chat or []
short_emb = encode_fn(prompt).to(device)
        # NOTE: as in training, the mapped embedding is unused for now;
        # the enhanced prompt is a placeholder template.
        mapped = model(short_emb).cpu()
        long_prompt = f"🌟 Enhanced prompt: {prompt} (creatively expanded)"
chat.append({"role": "user", "content": prompt})
chat.append({"role": "assistant", "content": long_prompt})
return chat
return model, tokenizer, embed_model, enhance_fn
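# Usage sketch: lazily load a previously trained checkpoint; all four return
# values are None when no checkpoint exists locally or on the Hub.
#   model, tokenizer, embed_model, enhance = get_flashpack_model()
#   if enhance is not None:
#       chat = enhance("a cat", [])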
# ===========================
# Gradio UI
# ===========================
with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort β†’ Long prompt expander")
chatbot = gr.Chatbot(height=400, type="messages")
user_input = gr.Textbox(label="Your prompt")
    send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
clear_btn = gr.Button("🧹 Clear")
train_btn = gr.Button("🧩 Train Model", variant="secondary")
log_output = gr.Textbox(label="Logs", lines=15)
# Lazy load model
model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
logs = []
if enhance_fn is None:
def enhance_fn(prompt, chat):
chat = chat or []
chat.append({"role": "assistant",
"content": "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
return chat
logs.append("⚠️ No pretrained model found. Ready to train.")
else:
logs.append("βœ… Model loaded β€” ready to enhance.")
    # Button callbacks. Dispatch through a wrapper so retraining can swap in
    # the new enhance_fn (Gradio binds the exact function object it is given).
    def respond(prompt, chat):
        return enhance_fn(prompt, chat)
    send_btn.click(respond, [user_input, chatbot], chatbot)
    user_input.submit(respond, [user_input, chatbot], chatbot)
clear_btn.click(lambda: [], None, chatbot)
def retrain():
global model, tokenizer, embed_model, enhance_fn, logs
logs = ["πŸš€ Training model, please wait..."]
model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model()
logs.extend(train_logs)
        # Return the value directly (gr.Textbox.update was removed in Gradio 4).
        return "\n".join(logs)
train_btn.click(retrain, None, log_output)
if __name__ == "__main__":
demo.launch(show_error=True)