rahul7star committed · verified
Commit d9e93e9 · Parent: 7bef75a

Create app_flash1.py

Files changed (1)
  1. app_flash1.py +253 -0
app_flash1.py ADDED
@@ -0,0 +1,253 @@
+ import gc
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import tempfile
+ import gradio as gr
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModel
+ from flashpack import FlashPackMixin
+ from huggingface_hub import Repository
+ from typing import Tuple
+ from sklearn.model_selection import train_test_split
+
+ device = torch.device("cpu")
+ torch.set_num_threads(4)
+ print(f"🔧 Using device: {device} (CPU-only)")
+
+ # ============================================================
+ # 1️⃣ Model
+ # ============================================================
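+ # A small MLP "mapper": it projects pooled short-prompt embeddings into the
+ # long-prompt embedding space. FlashPackMixin contributes the
+ # save_flashpack()/from_flashpack() serialization helpers used below.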
+ class GemmaTrainer(nn.Module, FlashPackMixin):
+     def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
+         super().__init__()
+         self.fc1 = nn.Linear(input_dim, hidden_dim)
+         self.relu = nn.ReLU()
+         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+         self.fc3 = nn.Linear(hidden_dim, output_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc1(x)
+         x = self.relu(x)
+         x = self.fc2(x)
+         x = self.relu(x)
+         x = self.fc3(x)
+         return x
+
+ # ============================================================
+ # 2️⃣ Encoder with batch mean+max pooling
+ # ============================================================
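+ # GPT-2 acts as a frozen feature extractor. Concatenating mean pooling and
+ # max pooling over the token axis doubles its 768-dim hidden size, giving
+ # 1536-dim prompt vectors (matching the mapper's default output_dim).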
+ def build_encoder(model_name="gpt2", max_length=128):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     embed_model = AutoModel.from_pretrained(model_name).to(device)
+     embed_model.eval()
+
+     @torch.no_grad()
+     def encode_batch(prompts: list, batch_size=16) -> torch.Tensor:
+         embeddings = []
+         for i in range(0, len(prompts), batch_size):
+             batch = prompts[i:i+batch_size]
+             inputs = tokenizer(batch, return_tensors="pt", truncation=True,
+                                padding="max_length", max_length=max_length).to(device)
+             last_hidden = embed_model(**inputs).last_hidden_state
+             mean_pool = last_hidden.mean(dim=1)
+             max_pool, _ = last_hidden.max(dim=1)
+             batch_emb = torch.cat([mean_pool, max_pool], dim=1)
+             embeddings.append(batch_emb.cpu())
+         return torch.vstack(embeddings)
+
+     return tokenizer, embed_model, encode_batch
+
+ # ============================================================
+ # 3️⃣ Push model to HF
+ # ============================================================
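+ # NOTE: newer huggingface_hub releases deprecate the git-based Repository
+ # class; if cloning fails here, HfApi.upload_folder is the maintained
+ # upload path.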
+ def push_flashpack_model_to_hf(model, hf_repo: str):
+     logs = []
+     with tempfile.TemporaryDirectory() as tmp_dir:
+         logs.append(f"📂 Using temporary directory: {tmp_dir}")
+         repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
+         pack_path = os.path.join(tmp_dir, "model.flashpack")
+         model.save_flashpack(pack_path, target_dtype=torch.float32)
+         readme_path = os.path.join(tmp_dir, "README.md")
+         with open(readme_path, "w") as f:
+             f.write("# FlashPack Model\nThis repo contains a FlashPack model trained for short→long prompt mapping.")
+         repo.push_to_hub()
+         logs.append(f"✅ Model pushed to Hugging Face repo: {hf_repo}")
+     return logs
+
+ # ============================================================
+ # 4️⃣ Train with train/test split & detailed logging
+ # ============================================================
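+ # The objective below is 1 - cosine_similarity(model(short_emb), long_emb):
+ # it aligns the *direction* of the mapped vector with the target embedding,
+ # so a loss of 0 means the pair points the same way regardless of magnitude.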
+ def train_flashpack_model(
+     dataset_name="rahul7star/prompt-enhancer-dataset",
+     max_encode=1000,
+     hidden_dim=1024,
+     hf_repo="rahul7star/FlashPack",
+     push_to_hub=True,
+     test_split=0.1,
+     batch_size=32,
+     max_epochs=50,
+     target_test_loss=0.01
+ ) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
+
+     print("📦 Loading dataset...")
+     dataset = load_dataset(dataset_name, split="train")
+     limit = min(max_encode, len(dataset))
+     dataset = dataset.select(range(limit))
+     print(f"⚡ Using {len(dataset)} prompts for training")
+
+     short_prompts = [item["short_prompt"] for item in dataset]
+     long_prompts = [item["long_prompt"] for item in dataset]
+
+     tokenizer, embed_model, encode_batch = build_encoder("gpt2", max_length=128)
+
+     # Encode everything once, in dataset order, so the long-prompt embeddings
+     # can double as the retrieval index at inference time.
+     print("⚡ Encoding prompts...")
+     short_emb = encode_batch(short_prompts)
+     long_emb = encode_batch(long_prompts)
+     print(f"✅ Embeddings shape: {short_emb.shape}, {long_emb.shape}")
+
+     # Split on row indices (not on the prompt lists) so the full-order
+     # embeddings above stay aligned with `dataset` for retrieval.
+     train_idx, test_idx = train_test_split(
+         list(range(len(dataset))), test_size=test_split, random_state=42
+     )
+     print(f"🔹 Train size: {len(train_idx)}, Test size: {len(test_idx)}")
+
+     train_short_emb, train_long_emb = short_emb[train_idx], long_emb[train_idx]
+     test_short_emb, test_long_emb = short_emb[test_idx], long_emb[test_idx]
+
+     input_dim = train_short_emb.shape[1]
+     output_dim = train_long_emb.shape[1]
+
+     model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
+
+     criterion = nn.CosineSimilarity(dim=1)
+     optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+     n_train = train_short_emb.shape[0]
+
+     print("🚀 Training model...")
+     for epoch in range(max_epochs):
+         model.train()
+         epoch_loss = 0.0
+         perm = torch.randperm(n_train)
+         for start in range(0, n_train, batch_size):
+             idx = perm[start:start+batch_size]
+             inputs = train_short_emb[idx].to(device)
+             targets = train_long_emb[idx].to(device)
+
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = 1 - criterion(outputs, targets).mean()
+             loss.backward()
+             optimizer.step()
+             epoch_loss += loss.item() * inputs.size(0)
+
+         epoch_loss /= n_train
+
+         # Evaluate on the held-out split
+         model.eval()
+         with torch.no_grad():
+             test_outputs = model(test_short_emb.to(device))
+             test_loss = (1 - criterion(test_outputs, test_long_emb.to(device)).mean()).item()
+
+         print(f"Epoch {epoch+1}/{max_epochs} → Train loss: {epoch_loss:.6f}, Test loss: {test_loss:.6f}")
+
+         # Stop early once the test loss is good enough
+         if test_loss <= target_test_loss:
+             print(f"✅ Target test loss reached ({test_loss:.6f}) – stopping training early.")
+             break
+
+     # Push to HF only if the model trained well
+     if push_to_hub and test_loss <= target_test_loss:
+         for log in push_flashpack_model_to_hf(model, hf_repo):
+             print(log)
+     elif push_to_hub:
+         print(f"⚠️ Test loss too high ({test_loss:.6f}); skipping HF upload.")
+
+     # Return order matches the unpacking in section 6 below.
+     return model, tokenizer, embed_model, dataset, long_emb
+
+ # ============================================================
+ # 5️⃣ Load or train
+ # ============================================================
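+ # Try the Hub first; any failure (missing repo, incompatible weights, auth)
+ # falls through to a fresh local training run.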
+ def get_flashpack_model(hf_repo="rahul7star/FlashPack",
+                         dataset_name="rahul7star/prompt-enhancer-dataset",
+                         max_encode=1000):
+     try:
+         print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
+         model = GemmaTrainer.from_flashpack(hf_repo)
+         model.eval()
+         tokenizer, embed_model, encode_batch = build_encoder("gpt2", max_length=128)
+         # The retrieval index is not stored with the weights, so rebuild the
+         # dataset and its long-prompt embeddings even on a successful load.
+         dataset = load_dataset(dataset_name, split="train")
+         dataset = dataset.select(range(min(max_encode, len(dataset))))
+         long_embeddings = encode_batch([item["long_prompt"] for item in dataset])
+         return model, tokenizer, embed_model, dataset, long_embeddings
+     except Exception as e:
+         print(f"⚠️ Load failed: {e}")
+         print("⏬ Training a new FlashPack model locally...")
+         return train_flashpack_model(hf_repo=hf_repo)
+
+ # ============================================================
+ # 6️⃣ Bootstrap: load (or train) the model and retrieval index
+ # ============================================================
+ model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
+
+ # ============================================================
+ # 7️⃣ Inference helpers
+ # ============================================================
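+ # Inference is retrieval, not generation: the mapped embedding is compared
+ # against the stored long-prompt embeddings by cosine similarity, and the
+ # closest dataset entry is returned verbatim.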
+ @torch.no_grad()
+ def encode_for_inference(prompt: str) -> torch.Tensor:
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
+                        padding="max_length", max_length=128).to(device)
+     last_hidden = embed_model(**inputs).last_hidden_state
+     mean_pool = last_hidden.mean(dim=1)
+     max_pool, _ = last_hidden.max(dim=1)
+     return torch.cat([mean_pool, max_pool], dim=1).cpu()
+
+ def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
+     # temperature and max_tokens come from the UI sliders but are unused here:
+     # retrieval returns a stored long prompt rather than generating new text.
+     chat_history = chat_history or []
+     short_emb = encode_for_inference(user_prompt)
+     mapped = model(short_emb.to(device)).cpu()
+
+     # Cosine similarity between the mapped vector and every stored embedding
+     sims = (long_embeddings @ mapped.t()).squeeze(1)
+     long_norms = long_embeddings.norm(dim=1)
+     mapped_norm = mapped.norm()
+     sims = sims / (long_norms * (mapped_norm + 1e-12))
+
+     best_idx = int(sims.argmax().item())
+     enhanced_prompt = dataset[best_idx]["long_prompt"]
+
+     chat_history.append({"role": "user", "content": user_prompt})
+     chat_history.append({"role": "assistant", "content": enhanced_prompt})
+     return chat_history
+
+ # ============================================================
+ # 8️⃣ Gradio UI
+ # ============================================================
+ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # ✨ Prompt Enhancer (FlashPack mapper)
+         Enter a short prompt, and the model will **expand it with details and creative context**.
+         (CPU-only mode.)
+         """
+     )
+
+     with gr.Row():
+         chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
+         with gr.Column(scale=1):
+             user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
+             temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
+             max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
+             send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
+             clear_btn = gr.Button("🧹 Clear Chat")
+
+     send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+     user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+     clear_btn.click(lambda: [], None, chatbot)
+
+ if __name__ == "__main__":
+     demo.launch(show_error=True)
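+
+ # Hypothetical reuse sketch (mirrors get_flashpack_model above): load the
+ # pushed mapper straight from the Hub repo for standalone inference.
+ #   model = GemmaTrainer.from_flashpack("rahul7star/FlashPack")
+ #   model.eval()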