rahul7star committed
Commit f3268bd · verified · 1 Parent(s): 25b932e

Update app_flash1.py

Files changed (1)
  1. app_flash1.py +111 -178
app_flash1.py CHANGED
@@ -1,5 +1,5 @@
-import gc
 import os
+import gc
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -8,18 +8,18 @@ import gradio as gr
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModel
 from flashpack import FlashPackMixin
-from huggingface_hub import Repository
+from huggingface_hub import Repository, list_repo_files, hf_hub_download
 from typing import Tuple
 
 # ============================================================
-# 🖥 Device setup (CPU-only)
+# 🖥 Device Setup
 # ============================================================
 device = torch.device("cpu")
 torch.set_num_threads(4)
 print(f"🔧 Using device: {device} (CPU-only mode)")
 
 # ============================================================
-# 1️⃣ FlashPack model
+# 1️⃣ Model Definition
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
     def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
@@ -29,7 +29,7 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
         self.fc3 = nn.Linear(hidden_dim, output_dim)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor):
         x = self.fc1(x)
         x = self.relu(x)
         x = self.fc2(x)
@@ -37,241 +37,174 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         x = self.fc3(x)
         return x
 
+
 # ============================================================
-# 2️⃣ Encoder (mean + max pooling)
+# 2️⃣ Encoder Setup
 # ============================================================
-def build_encoder(model_name="gpt2", max_length: int = 128):
-    print(f"📦 Loading tokenizer and model for {model_name}...")
+def build_encoder(model_name="gpt2", max_length=128):
+    print(f"📦 Loading encoder: {model_name}")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
-
     embed_model = AutoModel.from_pretrained(model_name).to(device)
     embed_model.eval()
-    print(f"✅ Encoder ready: {model_name}")
 
     @torch.no_grad()
     def encode(prompt: str) -> torch.Tensor:
         inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                            padding="max_length", max_length=max_length).to(device)
-        last_hidden = embed_model(**inputs).last_hidden_state
-        mean_pool = last_hidden.mean(dim=1)
-        max_pool, _ = last_hidden.max(dim=1)
-        return torch.cat([mean_pool, max_pool], dim=1).cpu()  # double dimension
+        hidden = embed_model(**inputs).last_hidden_state
+        mean_pool = hidden.mean(dim=1)
+        max_pool, _ = hidden.max(dim=1)
+        return torch.cat([mean_pool, max_pool], dim=1).cpu()
 
     return tokenizer, embed_model, encode
 
+
 # ============================================================
-# 3️⃣ Push FlashPack model to HF
+# 3️⃣ Push to Hugging Face
 # ============================================================
-def push_flashpack_model_to_hf(model, hf_repo: str):
-    logs = []
+def push_flashpack_model_to_hf(model, hf_repo):
     with tempfile.TemporaryDirectory() as tmp_dir:
-        logs.append(f"📂 Temporary directory: {tmp_dir}")
         repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
-        pack_path = os.path.join(tmp_dir, "model.flashpack")
-        model.save_flashpack(pack_path, target_dtype=torch.float32)
-        readme_path = os.path.join(tmp_dir, "README.md")
-        with open(readme_path, "w") as f:
-            f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
+        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"))
+        with open(os.path.join(tmp_dir, "README.md"), "w") as f:
+            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
         repo.push_to_hub()
-    logs.append(f"✅ Model pushed to HF: {hf_repo}")
-    return logs
+    print(f"✅ Model pushed to {hf_repo}")
+
 
 # ============================================================
-# 4️⃣ Train FlashPack model
+# 4️⃣ Training Logic
 # ============================================================
-def train_flashpack_model(
-    dataset_name: str = "rahul7star/prompt-enhancer-dataset",
-    max_encode: int = 1000,
-    hidden_dim: int = 1024,
-    hf_repo: str = "rahul7star/FlashPack",
-    push_to_hub: bool = True
-) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
-
+def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
+                          hf_repo="rahul7star/FlashPack",
+                          max_encode=1000):
     print("📦 Loading dataset...")
-    dataset = load_dataset(dataset_name, split="train")
-    limit = min(max_encode, len(dataset))
-    dataset = dataset.select(range(limit))
-    print(f"⚡ Using {len(dataset)} prompts for training")
-
-    # Split train/test
-    train_size = int(0.9 * len(dataset))
-    dataset_train = dataset.select(range(train_size))
-    dataset_test = dataset.select(range(train_size, len(dataset)))
-    print(f"🧪 Train/Test split: {len(dataset_train)} / {len(dataset_test)}")
+    dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
+    print(f"✅ Loaded {len(dataset)} samples")
 
-    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
+    tokenizer, embed_model, encode_fn = build_encoder("gpt2")
 
     def encode_dataset(ds):
-        short_list, long_list = [], []
+        s_list, l_list = [], []
         for i, item in enumerate(ds):
-            short_list.append(encode_fn(item["short_prompt"]))
-            long_list.append(encode_fn(item["long_prompt"]))
-            if (i+1) % 50 == 0 or (i+1) == len(ds):
-                print(f" → Encoded {i+1}/{len(ds)} prompts")
+            s_list.append(encode_fn(item["short_prompt"]))
+            l_list.append(encode_fn(item["long_prompt"]))
+            if (i + 1) % 50 == 0:
+                print(f" → Encoded {i + 1}/{len(ds)}")
            gc.collect()
-        return torch.vstack(short_list), torch.vstack(long_list)
-
-    short_train, long_train = encode_dataset(dataset_train)
-    short_test, long_test = encode_dataset(dataset_test)
-    print(f"✅ Embeddings shapes: train {short_train.shape}/{long_train.shape}, test {short_test.shape}/{long_test.shape}")
+        return torch.vstack(s_list), torch.vstack(l_list)
 
-    input_dim = short_train.shape[1]
-    output_dim = long_train.shape[1]
-    model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
+    short_emb, long_emb = encode_dataset(dataset)
+    input_dim, output_dim = short_emb.shape[1], long_emb.shape[1]
+    model = GemmaTrainer(input_dim, 1024, output_dim)
 
-    criterion = nn.CosineSimilarity(dim=1)
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
-    max_epochs = 50
-    batch_size = 32
-    n = short_train.shape[0]
+    loss_fn = nn.CosineSimilarity(dim=1)
 
     print("🚀 Training model...")
-    for epoch in range(max_epochs):
+    for epoch in range(20):
         model.train()
-        epoch_loss = 0.0
-        perm = torch.randperm(n)
-        for start in range(0, n, batch_size):
-            idx = perm[start:start+batch_size]
-            inputs = short_train[idx].to(device)
-            targets = long_train[idx].to(device)
-            optimizer.zero_grad()
-            outputs = model(inputs)
-            loss = 1 - criterion(outputs, targets).mean()
-            loss.backward()
-            optimizer.step()
-            epoch_loss += loss.item() * inputs.size(0)
-        epoch_loss /= n
-
-        model.eval()
-        with torch.no_grad():
-            test_outputs = model(short_test.to(device))
-            test_loss = 1 - criterion(test_outputs, long_test.to(device)).mean()
-
-        if epoch % 5 == 0 or epoch == max_epochs - 1:
-            print(f"Epoch {epoch+1}/{max_epochs} — Train Loss: {epoch_loss:.6f}, Test Loss: {test_loss:.6f}")
-
-        if test_loss < 0.01:
-            print("🎯 Early stopping — loss threshold reached.")
+        optimizer.zero_grad()
+        preds = model(short_emb)
+        loss = 1 - loss_fn(preds, long_emb).mean()
+        loss.backward()
+        optimizer.step()
+        print(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
+        if loss.item() < 0.01:
+            print("🎯 Early stopping.")
             break
 
-    print("✅ Training complete!")
-
-    if push_to_hub and test_loss < 0.05:
-        push_flashpack_model_to_hf(model, hf_repo)
-    else:
-        print("⚠️ Model not pushed — test loss not low enough.")
-
-    return model, dataset, embed_model, tokenizer, long_train
+    push_flashpack_model_to_hf(model, hf_repo)
+    return model, tokenizer, embed_model, dataset, long_emb
 
 
 # ============================================================
-# 5️⃣ Load or train
+# 5️⃣ Load or Train
 # ============================================================
-def get_flashpack_model(hf_repo="rahul7star/FlashPack", input_dim=1536, output_dim=1536):
+def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
+    print(f"🔍 Checking for model in repo: {hf_repo}")
     try:
-        print(f"🔁 Loading FlashPack model from {hf_repo}...")
-        dummy_model = GemmaTrainer(input_dim=input_dim, hidden_dim=1024, output_dim=output_dim)
-        model = GemmaTrainer.from_flashpack(hf_repo, model=dummy_model)
-        model.eval()
-        print("✅ Model loaded successfully.")
-        tokenizer, embed_model, _ = build_encoder("gpt2", max_length=128)
-        return model, tokenizer, embed_model, None, None
+        files = list_repo_files(hf_repo)
+        if "model.flashpack" in files:
+            print("✅ Found model.flashpack in repo — downloading and loading it.")
+            local_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
+            dummy = GemmaTrainer(1536, 1024, 1536)
+            model = GemmaTrainer.from_flashpack(local_path, model=dummy)
+            model.eval()
+            tokenizer, embed_model, _ = build_encoder("gpt2")
+            return model, tokenizer, embed_model, None, None
+        else:
+            print("🚫 model.flashpack not found — starting training.")
+            return train_flashpack_model(hf_repo=hf_repo)
     except Exception as e:
-        print(f"⚠️ Load failed: {e}")
-        print("⏬ Training new model...")
-        return train_flashpack_model()
+        print(f"⚠️ Error checking repo: {e}")
+        print("⏬ Training new model instead.")
+        return train_flashpack_model(hf_repo=hf_repo)
+
 
 # ============================================================
-# 6️⃣ Encode and Enhance
+# 6️⃣ Encode & Enhance Functions
 # ============================================================
 @torch.no_grad()
 def encode_prompt(prompt, tokenizer, embed_model):
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                        padding="max_length", max_length=128).to(device)
-    last_hidden = embed_model(**inputs).last_hidden_state
-    mean_pool = last_hidden.mean(dim=1)
-    max_pool, _ = last_hidden.max(dim=1)
+    hidden = embed_model(**inputs).last_hidden_state
+    mean_pool = hidden.mean(dim=1)
+    max_pool, _ = hidden.max(dim=1)
     return torch.cat([mean_pool, max_pool], dim=1).cpu()
 
-def make_enhance_fn(model, long_embeddings, dataset, tokenizer, embed_model):
+
+def make_enhance_fn(model, tokenizer, embed_model, long_emb, dataset):
     @torch.no_grad()
-    def fn(user_prompt, chat_history):
-        if model is None or dataset is None or long_embeddings is None:
-            return [{"role": "system", "content": "⚠️ Model not loaded. Please train or reload first."}]
-        chat_history = chat_history or []
-        short_emb = encode_prompt(user_prompt, tokenizer, embed_model)
+    def fn(prompt, chat):
+        chat = chat or []
+        short_emb = encode_prompt(prompt, tokenizer, embed_model)
         mapped = model(short_emb.to(device)).cpu()
-        sims = (long_embeddings @ mapped.t()).squeeze(1)
-        long_norms = long_embeddings.norm(dim=1)
-        mapped_norm = mapped.norm()
-        sims = sims / (long_norms * (mapped_norm + 1e-12))
-        best_idx = int(sims.argmax().item())
-        enhanced_prompt = dataset[best_idx]["long_prompt"]
-        chat_history.append({"role": "user", "content": user_prompt})
-        chat_history.append({"role": "assistant", "content": enhanced_prompt})
-        return chat_history
+        sims = (long_emb @ mapped.t()).squeeze(1)
+        best = int(sims.argmax())
+        enhanced = dataset[best]["long_prompt"]
+        chat.append({"role": "user", "content": prompt})
+        chat.append({"role": "assistant", "content": enhanced})
+        return chat
     return fn
 
+
 # ============================================================
-# 7️⃣ Gradio UI with Train Button
+# 7️⃣ Gradio UI
 # ============================================================
-with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# Prompt Enhancer (FlashPack mapper)")
+with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
+    gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort → Long prompt expander")
 
-    status = gr.Textbox(value="Model loading...", label="Status", interactive=False)
-
-    chatbot = gr.Chatbot(label="Enhanced Prompts", type="messages", height=400)
-    user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
+    chatbot = gr.Chatbot(height=400)
+    user_input = gr.Textbox(label="Your prompt")
     send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
-    train_btn = gr.Button("🧠 Train Model")
-    clear_btn = gr.Button("🧹 Clear Chat")
-
-    model_state = gr.State(None)
-    tokenizer_state = gr.State(None)
-    embed_model_state = gr.State(None)
-    dataset_state = gr.State(None)
-    long_embeddings_state = gr.State(None)
-
-    def train_model():
-        status_text = "🔄 Training or loading model..."
-        model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
-        status_text = "✅ Model ready!"
-        return model, tokenizer, embed_model, dataset, long_embeddings, status_text
-
-    train_btn.click(
-        train_model,
-        inputs=[],
-        outputs=[model_state, tokenizer_state, embed_model_state, dataset_state, long_embeddings_state, status],
-    )
-
-    def enhance(user_prompt, chat_history, model, tokenizer, embed_model, dataset, long_embeddings):
-        fn = make_enhance_fn(model, long_embeddings, dataset, tokenizer, embed_model)
-        return fn(user_prompt, chat_history)
-
-    send_btn.click(
-        enhance,
-        inputs=[user_prompt, chatbot, model_state, tokenizer_state, embed_model_state, dataset_state, long_embeddings_state],
-        outputs=chatbot,
-    )
-    user_prompt.submit(
-        enhance,
-        inputs=[user_prompt, chatbot, model_state, tokenizer_state, embed_model_state, dataset_state, long_embeddings_state],
-        outputs=chatbot,
-    )
+    clear_btn = gr.Button("🧹 Clear")
+    train_btn = gr.Button("🧩 Train Model", variant="secondary")
 
-    clear_btn.click(lambda: [], None, chatbot)
+    status = gr.Markdown("Status: Ready")
 
-    # Auto-load model on startup
-    demo.load(
-        lambda: train_model(),
-        inputs=None,
-        outputs=[model_state, tokenizer_state, embed_model_state, dataset_state, long_embeddings_state, status],
-    )
+    # Load model initially
+    model, tokenizer, embed_model, dataset, long_emb = get_flashpack_model()
+    enhance_fn = make_enhance_fn(model, tokenizer, embed_model, long_emb, dataset) if dataset else None
+
+    def enhance(prompt, chat):
+        if not enhance_fn:
+            return chat + [{"role": "assistant", "content": "⚠️ Model not ready. Please train first."}]
+        return enhance_fn(prompt, chat)
+
+    def retrain():
+        global model, tokenizer, embed_model, dataset, long_emb, enhance_fn
+        model, tokenizer, embed_model, dataset, long_emb = train_flashpack_model()
+        enhance_fn = make_enhance_fn(model, tokenizer, embed_model, long_emb, dataset)
+        return "✅ Model retrained and pushed to HF!"
+
+    send_btn.click(enhance, [user_input, chatbot], chatbot)
+    user_input.submit(enhance, [user_input, chatbot], chatbot)
+    clear_btn.click(lambda: [], None, chatbot)
+    train_btn.click(retrain, None, status)
 
-# ============================================================
-# 9️⃣ Launch
-# ============================================================
 if __name__ == "__main__":
     demo.launch(show_error=True)
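
Note: this commit replaces the Repository-clone load path with a Hub-first check using list_repo_files and hf_hub_download, both standard huggingface_hub calls. Below is a minimal standalone sketch of that load flow, assuming the same repo id and the GemmaTrainer / flashpack API used in this file; the 1536 dims follow from concatenating GPT-2's 768-dim mean-pooled and max-pooled hidden states.

    # Minimal sketch of the Hub-first load path used by get_flashpack_model().
    # Assumes huggingface_hub is installed and GemmaTrainer is defined as in app_flash1.py.
    from huggingface_hub import list_repo_files, hf_hub_download

    repo_id = "rahul7star/FlashPack"
    if "model.flashpack" in list_repo_files(repo_id):
        # Download the packed weights (cached locally by huggingface_hub).
        path = hf_hub_download(repo_id=repo_id, filename="model.flashpack")
        # 1536 = 2 * 768: mean pooling + max pooling over GPT-2 hidden states.
        dummy = GemmaTrainer(1536, 1024, 1536)
        model = GemmaTrainer.from_flashpack(path, model=dummy)
        model.eval()
    else:
        print("model.flashpack not on the Hub yet; train and push first.")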