rahul7star committed on
Commit 25b932e · verified · 1 Parent(s): 31737db

Update app_flash1.py

Files changed (1):
  1. app_flash1.py +202 -86

app_flash1.py CHANGED
@@ -1,139 +1,232 @@
-import gradio as gr
 import torch
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModel
 from huggingface_hub import Repository
-import tempfile, os, gc, torch.nn as nn, torch.optim as optim

 # ============================================================
-# GemmaTrainer simple model
 # ============================================================
-class GemmaTrainer(nn.Module):
-    def __init__(self, input_dim, hidden_dim, output_dim):
         super().__init__()
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
-        self.fc2 = nn.Linear(hidden_dim, output_dim)

-    def forward(self, x):
-        return self.fc2(self.relu(self.fc1(x)))
-
-    def save_flashpack(self, path, target_dtype=torch.float32):
-        torch.save(self.state_dict(), path)
-
-    @classmethod
-    def from_flashpack(cls, repo_path, model=None):
-        local_path = os.path.expanduser(
-            f"~/.cache/huggingface/hub/models--{repo_path.replace('/', '--')}/snapshots/"
-        )
-        # Find the newest snapshot
-        for root, dirs, files in os.walk(local_path):
-            if "model.flashpack" in files:
-                file_path = os.path.join(root, "model.flashpack")
-                model.load_state_dict(torch.load(file_path, map_location="cpu"))
-                return model
-        raise FileNotFoundError("model.flashpack not found in repo cache")

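The deleted loader walks the local Hub cache by hand and fails whenever the file was never downloaded. For reference, a minimal sketch of the same lookup through the Hub API (`hf_hub_download` returns a local path; `torch.load` matches the deleted `save_flashpack`, which was a plain `torch.save`):

```python
from huggingface_hub import hf_hub_download
import torch

# Download (or reuse from cache) the single pack file, then load it the way
# the deleted code did. Sketch only; "model.flashpack" is the filename this
# app already uses.
file_path = hf_hub_download(repo_id="rahul7star/FlashPack", filename="model.flashpack")
state_dict = torch.load(file_path, map_location="cpu")
```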
 # ============================================================
-# Build encoder (tokenizer + model)
 # ============================================================
-def build_encoder(model_name="gpt2", max_length: int = 32):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    embed_model = AutoModel.from_pretrained(model_name)
     embed_model.eval()

     @torch.no_grad()
     def encode(prompt: str) -> torch.Tensor:
-        inputs = tokenizer(
-            prompt,
-            return_tensors="pt",
-            truncation=True,
-            padding="max_length",
-            max_length=max_length
-        )
-        outputs = embed_model(**inputs).last_hidden_state.mean(dim=1)
-        return outputs.cpu()

     return tokenizer, embed_model, encode

 # ============================================================
-# Train FlashPack model (only if not found)
 # ============================================================
 def train_flashpack_model(
     dataset_name: str = "rahul7star/prompt-enhancer-dataset",
-    max_encode: int = 500,
-    device: str = "cpu"
-):
     print("📦 Loading dataset...")
     dataset = load_dataset(dataset_name, split="train")
-    dataset = dataset.select(range(min(max_encode, len(dataset))))

-    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)

-    short_list, long_list = [], []
-    for i, item in enumerate(dataset):
-        short_list.append(encode_fn(item["short_prompt"]))
-        long_list.append(encode_fn(item["long_prompt"]))
-        if (i + 1) % 50 == 0:
-            print(f"   → Encoded {i+1}/{len(dataset)}")

-    short_embeddings = torch.vstack(short_list)
-    long_embeddings = torch.vstack(long_list)

-    input_dim = short_embeddings.shape[1]
-    output_dim = long_embeddings.shape[1]
-    model = GemmaTrainer(input_dim, 512, output_dim)

-    criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=1e-3)

     print("🚀 Training model...")
-    for epoch in range(10):
         model.train()
-        optimizer.zero_grad()
-        outputs = model(short_embeddings)
-        loss = criterion(outputs, long_embeddings)
-        loss.backward()
-        optimizer.step()
-        print(f"Epoch {epoch+1}/10 - Loss: {loss.item():.6f}")

-    with tempfile.TemporaryDirectory() as tmp_dir:
-        save_path = os.path.join(tmp_dir, "model.flashpack")
-        model.save_flashpack(save_path)
-        repo = Repository(local_dir=tmp_dir, clone_from="rahul7star/FlashPack", use_auth_token=True)
-        repo.push_to_hub()
-        print("✅ Model pushed to Hugging Face Hub!")

-    return model, dataset, embed_model, tokenizer, long_embeddings


 # ============================================================
-# Try loading from HF first
 # ============================================================
-def get_flashpack_model(hf_repo="rahul7star/FlashPack", input_dim=768, output_dim=768):
     try:
         print(f"🔍 Loading FlashPack model from {hf_repo}...")
-        dummy_model = GemmaTrainer(input_dim, 512, output_dim)
         model = GemmaTrainer.from_flashpack(hf_repo, model=dummy_model)
         print("✅ Model loaded successfully.")
-        tokenizer, embed_model, _ = build_encoder("gpt2", max_length=32)
         return model, tokenizer, embed_model, None, None
     except Exception as e:
         print(f"⚠️ Load failed: {e}")
         print("⏬ Training new model...")
         return train_flashpack_model()

 # ============================================================
-# Gradio UI
 # ============================================================
-with gr.Blocks(title="FlashPack Trainer") as demo:
-    gr.Markdown("## ⚡ FlashPack Trainer")

-    status = gr.Textbox(label="Status", value="Ready", interactive=False)
-    train_btn = gr.Button("Train")

     model_state = gr.State(None)
     tokenizer_state = gr.State(None)
@@ -142,15 +235,10 @@ with gr.Blocks(title="FlashPack Trainer") as demo:
     long_embeddings_state = gr.State(None)

     def train_model():
         model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
-        return (
-            model,
-            tokenizer,
-            embed_model,
-            dataset,
-            long_embeddings,
-            "✅ Model ready for use",
-        )

     train_btn.click(
         train_model,
@@ -158,4 +246,32 @@ with gr.Blocks(title="FlashPack Trainer") as demo:
         outputs=[model_state, tokenizer_state, embed_model_state, dataset_state, long_embeddings_state, status],
     )

-    demo.launch()
+import gc
+import os
 import torch
+import torch.nn as nn
+import torch.optim as optim
+import tempfile
+import gradio as gr
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModel
+from flashpack import FlashPackMixin
 from huggingface_hub import Repository
+from typing import Tuple
+
+# ============================================================
+# 🖥 Device setup (CPU-only)
+# ============================================================
+device = torch.device("cpu")
+torch.set_num_threads(4)
+print(f"🔧 Using device: {device} (CPU-only mode)")

 # ============================================================
+# 1️⃣ FlashPack model
 # ============================================================
+class GemmaTrainer(nn.Module, FlashPackMixin):
+    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
         super().__init__()
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+        self.fc3 = nn.Linear(hidden_dim, output_dim)

+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
+        return x

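With `FlashPackMixin`, serialization now comes from the flashpack package instead of the hand-rolled methods that were deleted above. A minimal roundtrip sketch using only the two calls that appear in this file; it assumes `GemmaTrainer` is importable from this module and that `from_flashpack` accepts a local path as well as the Hub repo id passed in `get_flashpack_model()` below:

```python
import torch

# 1536 = 2 * 768, the concatenated mean+max pooling of GPT-2 hidden states
# (see build_encoder below).
model = GemmaTrainer(input_dim=1536, hidden_dim=1024, output_dim=1536)
model.save_flashpack("model.flashpack", target_dtype=torch.float32)

# from_flashpack is given a template instance, mirroring get_flashpack_model().
template = GemmaTrainer(input_dim=1536, hidden_dim=1024, output_dim=1536)
restored = GemmaTrainer.from_flashpack("model.flashpack", model=template)
```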
 # ============================================================
+# 2️⃣ Encoder (mean + max pooling)
 # ============================================================
+def build_encoder(model_name="gpt2", max_length: int = 128):
+    print(f"📦 Loading tokenizer and model for {model_name}...")
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+
+    embed_model = AutoModel.from_pretrained(model_name).to(device)
     embed_model.eval()
+    print(f"✅ Encoder ready: {model_name}")

     @torch.no_grad()
     def encode(prompt: str) -> torch.Tensor:
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
+                           padding="max_length", max_length=max_length).to(device)
+        last_hidden = embed_model(**inputs).last_hidden_state
+        mean_pool = last_hidden.mean(dim=1)
+        max_pool, _ = last_hidden.max(dim=1)
+        return torch.cat([mean_pool, max_pool], dim=1).cpu()  # double dimension

     return tokenizer, embed_model, encode

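A standalone shape check of the pooling trick: concatenating mean- and max-pooled token states doubles GPT-2's 768-dim hidden size to the 1536 used as the model's default width:

```python
import torch

# Shapes only; no transformer needed for the check.
last_hidden = torch.randn(1, 128, 768)            # (batch, seq_len, hidden)
mean_pool = last_hidden.mean(dim=1)               # (1, 768)
max_pool, _ = last_hidden.max(dim=1)              # (1, 768)
pooled = torch.cat([mean_pool, max_pool], dim=1)  # (1, 1536)
assert pooled.shape == (1, 1536)
```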
 # ============================================================
+# 3️⃣ Push FlashPack model to HF
+# ============================================================
+def push_flashpack_model_to_hf(model, hf_repo: str):
+    logs = []
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        logs.append(f"📂 Temporary directory: {tmp_dir}")
+        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
+        pack_path = os.path.join(tmp_dir, "model.flashpack")
+        model.save_flashpack(pack_path, target_dtype=torch.float32)
+        readme_path = os.path.join(tmp_dir, "README.md")
+        with open(readme_path, "w") as f:
+            f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
+        repo.push_to_hub()
+        logs.append(f"✅ Model pushed to HF: {hf_repo}")
+    return logs
+
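`Repository` is deprecated in recent huggingface_hub releases. A hedged alternative for the same push using `HfApi.upload_folder`, offered as a suggestion rather than part of the commit:

```python
from huggingface_hub import HfApi

# Upload everything written into the temp dir (model.flashpack + README.md).
# Uses the token cached by `huggingface-cli login`.
api = HfApi()
api.upload_folder(folder_path="/path/to/tmp_dir", repo_id="rahul7star/FlashPack")
```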
+# ============================================================
+# 4️⃣ Train FlashPack model
 # ============================================================
 def train_flashpack_model(
     dataset_name: str = "rahul7star/prompt-enhancer-dataset",
+    max_encode: int = 1000,
+    hidden_dim: int = 1024,
+    hf_repo: str = "rahul7star/FlashPack",
+    push_to_hub: bool = True
+) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
+
     print("📦 Loading dataset...")
     dataset = load_dataset(dataset_name, split="train")
+    limit = min(max_encode, len(dataset))
+    dataset = dataset.select(range(limit))
+    print(f"⚡ Using {len(dataset)} prompts for training")

+    # Split train/test
+    train_size = int(0.9 * len(dataset))
+    dataset_train = dataset.select(range(train_size))
+    dataset_test = dataset.select(range(train_size, len(dataset)))
+    print(f"🧪 Train/Test split: {len(dataset_train)} / {len(dataset_test)}")

+    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)

+    def encode_dataset(ds):
+        short_list, long_list = [], []
+        for i, item in enumerate(ds):
+            short_list.append(encode_fn(item["short_prompt"]))
+            long_list.append(encode_fn(item["long_prompt"]))
+            if (i + 1) % 50 == 0 or (i + 1) == len(ds):
+                print(f"   → Encoded {i + 1}/{len(ds)} prompts")
+                gc.collect()
+        return torch.vstack(short_list), torch.vstack(long_list)

+    short_train, long_train = encode_dataset(dataset_train)
+    short_test, long_test = encode_dataset(dataset_test)
+    print(f"✅ Embedding shapes: train {short_train.shape}/{long_train.shape}, test {short_test.shape}/{long_test.shape}")

+    input_dim = short_train.shape[1]
+    output_dim = long_train.shape[1]
+    model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
+
+    criterion = nn.CosineSimilarity(dim=1)
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
+    max_epochs = 50
+    batch_size = 32
+    n = short_train.shape[0]

     print("🚀 Training model...")
+    for epoch in range(max_epochs):
         model.train()
+        epoch_loss = 0.0
+        perm = torch.randperm(n)
+        for start in range(0, n, batch_size):
+            idx = perm[start:start + batch_size]
+            inputs = short_train[idx].to(device)
+            targets = long_train[idx].to(device)
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            loss = 1 - criterion(outputs, targets).mean()
+            loss.backward()
+            optimizer.step()
+            epoch_loss += loss.item() * inputs.size(0)
+        epoch_loss /= n

+        model.eval()
+        with torch.no_grad():
+            test_outputs = model(short_test.to(device))
+            test_loss = 1 - criterion(test_outputs, long_test.to(device)).mean()
+
+        if epoch % 5 == 0 or epoch == max_epochs - 1:
+            print(f"Epoch {epoch + 1}/{max_epochs} — Train Loss: {epoch_loss:.6f}, Test Loss: {test_loss:.6f}")

+        if test_loss < 0.01:
+            print("🎯 Early stopping — loss threshold reached.")
+            break
+
+    print("✅ Training complete!")
+
+    if push_to_hub and test_loss < 0.05:
+        push_flashpack_model_to_hf(model, hf_repo)
+    else:
+        print("⚠️ Model not pushed — test loss not low enough.")
+
+    # Return order matches the unpack in train_model() below:
+    # (model, tokenizer, embed_model, dataset, long_embeddings)
+    return model, tokenizer, embed_model, dataset, long_train

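The training objective is 1 minus cosine similarity, so it lies in [0, 2]: 0 for perfectly aligned embeddings, 1 for orthogonal ones. A toy check:

```python
import torch
import torch.nn as nn

criterion = nn.CosineSimilarity(dim=1)
a = torch.tensor([[1.0, 0.0]])
print(1 - criterion(a, torch.tensor([[2.0, 0.0]])).mean())  # tensor(0.), same direction
print(1 - criterion(a, torch.tensor([[0.0, 3.0]])).mean())  # tensor(1.), orthogonal
```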
 # ============================================================
+# 5️⃣ Load or train
 # ============================================================
+def get_flashpack_model(hf_repo="rahul7star/FlashPack", input_dim=1536, output_dim=1536):
     try:
         print(f"🔍 Loading FlashPack model from {hf_repo}...")
+        dummy_model = GemmaTrainer(input_dim=input_dim, hidden_dim=1024, output_dim=output_dim)
         model = GemmaTrainer.from_flashpack(hf_repo, model=dummy_model)
+        model.eval()
         print("✅ Model loaded successfully.")
+        tokenizer, embed_model, _ = build_encoder("gpt2", max_length=128)
         return model, tokenizer, embed_model, None, None
     except Exception as e:
         print(f"⚠️ Load failed: {e}")
         print("⏬ Training new model...")
         return train_flashpack_model()

+# ============================================================
+# 6️⃣ Encode and Enhance
+# ============================================================
+@torch.no_grad()
+def encode_prompt(prompt, tokenizer, embed_model):
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
+                       padding="max_length", max_length=128).to(device)
+    last_hidden = embed_model(**inputs).last_hidden_state
+    mean_pool = last_hidden.mean(dim=1)
+    max_pool, _ = last_hidden.max(dim=1)
+    return torch.cat([mean_pool, max_pool], dim=1).cpu()
+
+def make_enhance_fn(model, long_embeddings, dataset, tokenizer, embed_model):
+    @torch.no_grad()
+    def fn(user_prompt, chat_history):
+        if model is None or dataset is None or long_embeddings is None:
+            return [{"role": "system", "content": "⚠️ Model not loaded. Please train or reload first."}]
+        chat_history = chat_history or []
+        short_emb = encode_prompt(user_prompt, tokenizer, embed_model)
+        mapped = model(short_emb.to(device)).cpu()
+        sims = (long_embeddings @ mapped.t()).squeeze(1)
+        long_norms = long_embeddings.norm(dim=1)
+        mapped_norm = mapped.norm()
+        sims = sims / (long_norms * (mapped_norm + 1e-12))
+        best_idx = int(sims.argmax().item())
+        enhanced_prompt = dataset[best_idx]["long_prompt"]
+        chat_history.append({"role": "user", "content": user_prompt})
+        chat_history.append({"role": "assistant", "content": enhanced_prompt})
+        return chat_history
+    return fn

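The dot-product and norm arithmetic in `fn()` is cosine similarity written out by hand. For reference, the equivalent call via torch.nn.functional, a readability suggestion rather than part of the commit:

```python
import torch
import torch.nn.functional as F

long_embeddings = torch.randn(10, 1536)  # dataset long-prompt embeddings
mapped = torch.randn(1, 1536)            # model output for the user prompt
sims = F.cosine_similarity(long_embeddings, mapped, dim=1)  # shape (10,)
best_idx = int(sims.argmax().item())
```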
 # ============================================================
+# 7️⃣ Gradio UI with Train Button
 # ============================================================
+with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# ✨ Prompt Enhancer (FlashPack mapper)")

+    status = gr.Textbox(value="Model loading...", label="Status", interactive=False)
+
+    chatbot = gr.Chatbot(label="Enhanced Prompts", type="messages", height=400)
+    user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
+    send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
+    train_btn = gr.Button("🧠 Train Model")
+    clear_btn = gr.Button("🧹 Clear Chat")

     model_state = gr.State(None)
     tokenizer_state = gr.State(None)
     long_embeddings_state = gr.State(None)

     def train_model():
+        status_text = "🔄 Training or loading model..."
         model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
+        status_text = "✅ Model ready!"
+        return model, tokenizer, embed_model, dataset, long_embeddings, status_text

     train_btn.click(
         train_model,
         outputs=[model_state, tokenizer_state, embed_model_state, dataset_state, long_embeddings_state, status],
     )

+    def enhance(user_prompt, chat_history, model, tokenizer, embed_model, dataset, long_embeddings):
+        fn = make_enhance_fn(model, long_embeddings, dataset, tokenizer, embed_model)
+        return fn(user_prompt, chat_history)
+
+    send_btn.click(
+        enhance,
+        inputs=[user_prompt, chatbot, model_state, tokenizer_state, embed_model_state, dataset_state, long_embeddings_state],
+        outputs=chatbot,
+    )
+    user_prompt.submit(
+        enhance,
+        inputs=[user_prompt, chatbot, model_state, tokenizer_state, embed_model_state, dataset_state, long_embeddings_state],
+        outputs=chatbot,
+    )
+
+    clear_btn.click(lambda: [], None, chatbot)
+
+    # Auto-load model on startup
+    demo.load(
+        lambda: train_model(),
+        inputs=None,
+        outputs=[model_state, tokenizer_state, embed_model_state, dataset_state, long_embeddings_state, status],
+    )
+
+# ============================================================
+# 8️⃣ Launch
+# ============================================================
+if __name__ == "__main__":
+    demo.launch(show_error=True)
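A headless smoke test of the new pipeline, sketched to run in this file's context. Note that the Hub-load path of `get_flashpack_model()` returns `dataset=None` and `long_embeddings=None`, in which case `make_enhance_fn()` can only emit its warning message, hence the guard:

```python
# Sketch: exercise the load-or-train path and one retrieval without the UI.
model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
if dataset is not None:
    enhance = make_enhance_fn(model, long_embeddings, dataset, tokenizer, embed_model)
    history = enhance("a cat on a chair", [])
    print(history[-1]["content"])  # nearest long_prompt from the dataset
```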