rahul7star committed
Commit 9f53409 · verified · 1 Parent(s): 3589a82

Update app_flash1.py

Files changed (1)
  1. app_flash1.py +124 -87
app_flash1.py CHANGED
@@ -12,17 +12,17 @@ from huggingface_hub import Repository
 from typing import Tuple
 
 # ============================================================
-# 🖥 CPU device setup
+# 🖥 Device setup (CPU-only)
 # ============================================================
 device = torch.device("cpu")
 torch.set_num_threads(4)
 print(f"🔧 Using device: {device} (CPU-only mode)")
 
 # ============================================================
-# 1️⃣ FlashPack MLP model (CPU-friendly)
+# 1️⃣ FlashPack model
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self, input_dim: int, hidden_dim: int = 512, output_dim: int = 768):
+    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
         super().__init__()
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
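Note: the widened defaults follow from the encoder change in the next hunk. GPT-2 hidden states are 768-dimensional, and concatenating mean and max pooling doubles that to 1536, so output_dim moves from 768 to 1536 (and hidden_dim from 512 to 1024). A minimal sketch of the mapper's shape contract; MapperSketch is a hypothetical stand-in, since the real class body (lines 29-37) is outside this diff:

    import torch
    import torch.nn as nn

    class MapperSketch(nn.Module):
        # Hypothetical stand-in for GemmaTrainer with the new defaults.
        def __init__(self, input_dim: int = 1536, hidden_dim: int = 1024, output_dim: int = 1536):
            super().__init__()
            self.fc1 = nn.Linear(input_dim, hidden_dim)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_dim, output_dim)

        def forward(self, x):
            return self.fc2(self.relu(self.fc1(x)))

    x = torch.randn(2, 1536)        # a batch of pooled short-prompt embeddings
    print(MapperSketch()(x).shape)  # torch.Size([2, 1536])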
@@ -38,39 +38,36 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         return x
 
 # ============================================================
-# 2️⃣ Lazy-loading GPT-2 encoder
+# 2️⃣ Encoder (mean + max pooling)
 # ============================================================
-_embed_model = None
-_tokenizer = None
+def build_encoder(model_name="gpt2", max_length: int = 128):
+    print(f"📦 Loading tokenizer and model for {model_name}...")
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
 
-def get_encoder(model_name="gpt2", max_length=64):
-    global _embed_model, _tokenizer
-    if _embed_model is None or _tokenizer is None:
-        print("⚡ Loading GPT-2 encoder model...")
-        _tokenizer = AutoTokenizer.from_pretrained(model_name)
-        if _tokenizer.pad_token is None:
-            _tokenizer.pad_token = _tokenizer.eos_token
-        _embed_model = AutoModel.from_pretrained(model_name).to(device)
-        _embed_model.eval()
-    return _tokenizer, _embed_model
+    embed_model = AutoModel.from_pretrained(model_name).to(device)
+    embed_model.eval()
+    print(f"✅ Encoder ready: {model_name}")
 
-@torch.no_grad()
-def encode_prompt(prompt: str) -> torch.Tensor:
-    tokenizer, embed_model = get_encoder()
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
-                       padding="max_length", max_length=64).to(device)
-    last_hidden = embed_model(**inputs).last_hidden_state
-    mean_pool = last_hidden.mean(dim=1)
-    max_pool, _ = last_hidden.max(dim=1)
-    return torch.cat([mean_pool, max_pool], dim=1).cpu()
+    @torch.no_grad()
+    def encode(prompt: str) -> torch.Tensor:
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
+                           padding="max_length", max_length=max_length).to(device)
+        last_hidden = embed_model(**inputs).last_hidden_state
+        mean_pool = last_hidden.mean(dim=1)
+        max_pool, _ = last_hidden.max(dim=1)
+        return torch.cat([mean_pool, max_pool], dim=1).cpu()  # double dimension
+
+    return tokenizer, embed_model, encode
 
 # ============================================================
-# 3️⃣ Push FlashPack model to Hugging Face Hub
+# 3️⃣ Push FlashPack model to HF
 # ============================================================
 def push_flashpack_model_to_hf(model, hf_repo: str):
     logs = []
     with tempfile.TemporaryDirectory() as tmp_dir:
-        logs.append(f"📂 Using temp dir: {tmp_dir}")
+        logs.append(f"📂 Temporary directory: {tmp_dir}")
         repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
         pack_path = os.path.join(tmp_dir, "model.flashpack")
         model.save_flashpack(pack_path, target_dtype=torch.float32)
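Note: build_encoder replaces the lazy module-level singletons with a closure that returns the tokenizer, the model, and an encode function in one call, and raises max_length from 64 to 128. A minimal standalone sketch of the mean+max pooling it performs (gpt2's hidden size is 768, so the concatenated vector is 1536-dimensional):

    import torch
    from transformers import AutoModel, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token                  # gpt2 ships without a pad token
    mdl = AutoModel.from_pretrained("gpt2").eval()

    with torch.no_grad():
        inputs = tok("a cat on a mat", return_tensors="pt", truncation=True,
                     padding="max_length", max_length=128)
        h = mdl(**inputs).last_hidden_state        # shape [1, 128, 768]
        pooled = torch.cat([h.mean(dim=1), h.max(dim=1).values], dim=1)
    print(pooled.shape)                            # torch.Size([1, 1536])

As in the committed code, padding positions are included in both pooling statistics; masking them out with the attention mask would be the more common choice, but would change the embeddings the mapper is trained on.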
@@ -86,52 +83,60 @@ def push_flashpack_model_to_hf(model, hf_repo: str):
 # ============================================================
 def train_flashpack_model(
     dataset_name: str = "rahul7star/prompt-enhancer-dataset",
-    max_encode: int = 500,
-    hidden_dim: int = 512,
-    push_to_hub: bool = True,
+    max_encode: int = 1000,
+    hidden_dim: int = 1024,
     hf_repo: str = "rahul7star/FlashPack",
-    early_stop_threshold: float = 0.001
+    push_to_hub: bool = True
 ) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
-
+
     print("📦 Loading dataset...")
     dataset = load_dataset(dataset_name, split="train")
-    dataset = dataset.select(range(min(max_encode, len(dataset))))
-    n_train = int(0.8 * len(dataset))
-    n_test = len(dataset) - n_train
-    train_dataset = dataset.select(range(n_train))
-    test_dataset = dataset.select(range(n_train, len(dataset)))
-    print(f"⚡ Train: {n_train}, Test: {n_test}")
-
-    # Encode prompts lazily
-    def batch_encode(ds):
+    limit = min(max_encode, len(dataset))
+    dataset = dataset.select(range(limit))
+    print(f"⚡ Using {len(dataset)} prompts for training")
+
+    # Split train/test
+    train_size = int(0.9 * len(dataset))
+    test_size = len(dataset) - train_size
+    dataset_train = dataset.select(range(train_size))
+    dataset_test = dataset.select(range(train_size, len(dataset)))
+    print(f"🧪 Train/Test split: {len(dataset_train)} / {len(dataset_test)}")
+
+    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
+
+    # Encode embeddings
+    def encode_dataset(ds):
         short_list, long_list = [], []
         for i, item in enumerate(ds):
-            short_list.append(encode_prompt(item["short_prompt"]))
-            long_list.append(encode_prompt(item["long_prompt"]))
-            if (i+1) % 20 == 0 or (i+1) == len(ds):
+            short_list.append(encode_fn(item["short_prompt"]))
+            long_list.append(encode_fn(item["long_prompt"]))
+            if (i+1) % 50 == 0 or (i+1) == len(ds):
                 print(f" → Encoded {i+1}/{len(ds)} prompts")
                 gc.collect()
         return torch.vstack(short_list), torch.vstack(long_list)
 
-    short_train, long_train = batch_encode(train_dataset)
-    short_test, long_test = batch_encode(test_dataset)
-    print(f"✅ Embeddings shapes: short_train={short_train.shape}, long_train={long_train.shape}")
+    short_train, long_train = encode_dataset(dataset_train)
+    short_test, long_test = encode_dataset(dataset_test)
+    print(f"✅ Embeddings shapes: train {short_train.shape}/{long_train.shape}, test {short_test.shape}/{long_test.shape}")
 
+    # Build model
     input_dim = short_train.shape[1]
     output_dim = long_train.shape[1]
     model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
 
+    # Loss and optimizer
     criterion = nn.CosineSimilarity(dim=1)
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
     max_epochs = 50
-    batch_size = 16
+    batch_size = 32
+    n = short_train.shape[0]
 
+    # Training loop
     print("🚀 Training model...")
-    n = short_train.shape[0]
     for epoch in range(max_epochs):
         model.train()
-        perm = torch.randperm(n)
         epoch_loss = 0.0
+        perm = torch.randperm(n)
         for start in range(0, n, batch_size):
             idx = perm[start:start+batch_size]
             inputs = short_train[idx].to(device)
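Note: the per-batch loss computation (old lines 138-143) falls between this hunk and the next, but the test-loss line below implies the objective is 1 minus the mean cosine similarity between mapped short-prompt embeddings and target long-prompt embeddings. A minimal sketch of that objective, with the 1536-dimensional shapes assumed from the encoder above:

    import torch
    import torch.nn as nn

    criterion = nn.CosineSimilarity(dim=1)
    pred = torch.randn(32, 1536)               # mapper outputs for one batch
    target = torch.randn(32, 1536)             # encoded long prompts
    loss = 1 - criterion(pred, target).mean()  # 0 when aligned, up to 2 when opposed
    print(float(loss))                         # ≈ 1.0 for random vectors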
@@ -144,68 +149,97 @@ def train_flashpack_model(
             epoch_loss += loss.item() * inputs.size(0)
         epoch_loss /= n
 
-        # Evaluate on test set
+        # Evaluate on test
         model.eval()
         with torch.no_grad():
-            outputs_test = model(short_test.to(device))
-            test_loss = 1 - criterion(outputs_test, long_test.to(device)).mean().item()
-
-        print(f"Epoch {epoch+1}/{max_epochs} | Train Loss={epoch_loss:.6f} | Test Loss={test_loss:.6f}")
+            test_outputs = model(short_test.to(device))
+            test_loss = 1 - criterion(test_outputs, long_test.to(device)).mean()
+        if epoch % 5 == 0 or epoch == max_epochs-1:
+            print(f"Epoch {epoch+1}/{max_epochs} Train Loss: {epoch_loss:.6f}, Test Loss: {test_loss:.6f}")
 
-        # Early stop: very low test loss means model is good
-        if test_loss < early_stop_threshold:
-            print("🎯 Early stop: test loss below threshold. Model is ready!")
+        # Auto stop if test_loss is very small
+        if test_loss < 0.01:
+            print("🎯 Test loss very low, early stopping!")
             break
 
-    if push_to_hub:
+    print("✅ Training finished!")
+
+    # Push to HF if training went well
+    if push_to_hub and test_loss < 0.05:
         logs = push_flashpack_model_to_hf(model, hf_repo)
         for log in logs:
             print(log)
+    else:
+        print("⚠️ Model not pushed — test loss not low enough.")
 
-    return model, dataset, None, None, long_train  # embed_model and tokenizer lazy-loaded
+    # Return
+    return model, dataset, embed_model, tokenizer, long_train
 
 # ============================================================
-# 5️⃣ Lazy load or train
+# 5️⃣ Lazy-load model
 # ============================================================
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     try:
-        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
+        print(f"🔁 Loading FlashPack model from {hf_repo}...")
         model = GemmaTrainer.from_flashpack(hf_repo)
         model.eval()
-        print("✅ Loaded model from HF")
-        return model
+        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
+        return model, tokenizer, embed_model
     except Exception as e:
         print(f"⚠️ Load failed: {e}")
-        print("⏬ Training new FlashPack model locally...")
-        model, dataset, _, _, long_embeddings = train_flashpack_model()
-        return model, dataset, long_embeddings
+        print("⏬ Training new model locally...")
+        model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
+        return model, tokenizer, embed_model, dataset, long_embeddings
+
+# ============================================================
+# 6️⃣ Initialize model
+# ============================================================
+model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
 
 # ============================================================
-# 6️⃣ Inference helpers
+# 7️⃣ Inference helpers (closure for Gradio)
 # ============================================================
 @torch.no_grad()
-def enhance_prompt(user_prompt: str, chat_history, model, long_embeddings, dataset):
-    chat_history = chat_history or []
-    short_emb = encode_prompt(user_prompt)
-    mapped = model(short_emb.to(device)).cpu()
-    sims = (long_embeddings @ mapped.t()).squeeze(1)
-    long_norms = long_embeddings.norm(dim=1)
-    mapped_norm = mapped.norm()
-    sims = sims / (long_norms * (mapped_norm + 1e-12))
-    best_idx = int(sims.argmax().item())
-    enhanced_prompt = dataset[best_idx]["long_prompt"]
+def encode_prompt(prompt: str) -> torch.Tensor:
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
+                       padding="max_length", max_length=128).to(device)
+    last_hidden = embed_model(**inputs).last_hidden_state
+    mean_pool = last_hidden.mean(dim=1)
+    max_pool, _ = last_hidden.max(dim=1)
+    return torch.cat([mean_pool, max_pool], dim=1).cpu()
+
+def make_enhance_fn(model, long_embeddings, dataset):
+    @torch.no_grad()
+    def fn(user_prompt, chat_history):
+        chat_history = chat_history or []
+        short_emb = encode_prompt(user_prompt)
+        mapped = model(short_emb.to(device)).cpu()
+        sims = (long_embeddings @ mapped.t()).squeeze(1)
+        long_norms = long_embeddings.norm(dim=1)
+        mapped_norm = mapped.norm()
+        sims = sims / (long_norms * (mapped_norm + 1e-12))
+        best_idx = int(sims.argmax().item())
+        enhanced_prompt = dataset[best_idx]["long_prompt"]
+
+        chat_history.append({"role": "user", "content": user_prompt})
+        chat_history.append({"role": "assistant", "content": enhanced_prompt})
+        return chat_history
+    return fn
 
-    chat_history.append({"role": "user", "content": user_prompt})
-    chat_history.append({"role": "assistant", "content": enhanced_prompt})
-    return chat_history
+enhance_fn = make_enhance_fn(model, long_embeddings, dataset)
 
 # ============================================================
-# 7️⃣ Launch Gradio app
+# 8️⃣ Gradio UI
 # ============================================================
-model, dataset, long_embeddings = get_flashpack_model()
-
 with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# ✨ Prompt Enhancer (FlashPack mapper)\nEnter a short prompt, and it will expand it.")
+    gr.Markdown(
+        """
+        # ✨ Prompt Enhancer (FlashPack mapper)
+        Enter a short prompt, and the model will **expand it with details and creative context**.
+        (CPU-only mode.)
+        """
+    )
+
     with gr.Row():
         chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
         with gr.Column(scale=1):
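Note: make_enhance_fn closes over the model, the candidate embeddings, and the dataset, so the Gradio callback receives only UI components as inputs; the old wiring passed model, long_embeddings, and dataset in the inputs list, which Gradio cannot accept as event inputs. The hand-rolled normalization inside fn is ordinary cosine similarity; a quick equivalence check under the assumed shapes ([N, 1536] candidates, [1, 1536] query):

    import torch
    import torch.nn.functional as F

    long_embeddings = torch.randn(10, 1536)  # candidate long-prompt embeddings
    mapped = torch.randn(1, 1536)            # mapper output for one short prompt

    sims = (long_embeddings @ mapped.t()).squeeze(1)
    sims = sims / (long_embeddings.norm(dim=1) * (mapped.norm() + 1e-12))

    ref = F.cosine_similarity(long_embeddings, mapped, dim=1)
    print(torch.allclose(sims, ref, atol=1e-5))  # True
    print(int(sims.argmax()))                    # index of the best candidate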
@@ -213,9 +247,12 @@ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft
             send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
             clear_btn = gr.Button("🧹 Clear Chat")
 
-    send_btn.click(enhance_prompt, [user_prompt, chatbot, model, long_embeddings, dataset], chatbot)
-    user_prompt.submit(enhance_prompt, [user_prompt, chatbot, model, long_embeddings, dataset], chatbot)
+    send_btn.click(enhance_fn, [user_prompt, chatbot], chatbot)
+    user_prompt.submit(enhance_fn, [user_prompt, chatbot], chatbot)
     clear_btn.click(lambda: [], None, chatbot)
 
+# ============================================================
+# 9️⃣ Launch
+# ============================================================
 if __name__ == "__main__":
     demo.launch(show_error=True)
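Note: as committed, the two branches of get_flashpack_model return different shapes: the try branch returns three values (model, tokenizer, embed_model) while the except branch and the module-level unpacking expect five, so a successful Hub load raises a ValueError and dataset/long_embeddings are never populated. A sketch of one way to make the success path consistent, using only names already defined in the file; re-encoding the long prompts on load is an assumption, since the commit stores no embeddings alongside the model (and the dataset slice should presumably match max_encode from training):

    def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
        try:
            print(f"🔁 Loading FlashPack model from {hf_repo}...")
            model = GemmaTrainer.from_flashpack(hf_repo)
            model.eval()
            tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
            # Rebuild the retrieval index by re-encoding the long prompts.
            dataset = load_dataset("rahul7star/prompt-enhancer-dataset", split="train")
            long_embeddings = torch.vstack(
                [encode_fn(item["long_prompt"]) for item in dataset])
            return model, tokenizer, embed_model, dataset, long_embeddings
        except Exception as e:
            print(f"⚠️ Load failed: {e}")
            print("⏬ Training new model locally...")
            model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
            return model, tokenizer, embed_model, dataset, long_embeddings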
 