rahul7star committed
Commit 588725c · verified · 1 Parent(s): 5d713c7

Update app_flash1.py

Files changed (1):
  1. app_flash1.py (+19, −54)

app_flash1.py CHANGED
@@ -133,75 +133,40 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
 # Lazy Load / Get Model
 # ===========================
 # ===========================
-# Lazy Load / Get Model (Fixed)
-# ===========================
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
-    """
-    Loads the FlashPack model + dataset + long embeddings from HF repo if available,
-    otherwise trains a new model locally.
-    Returns:
-        model, tokenizer, embed_model, enhance_fn, dataset, long_embeddings
-    """
     local_model_path = "model.flashpack"

-    try:
-        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
-        # 1️⃣ Download model from HF
-        files = list_repo_files(hf_repo)
-        if "model.flashpack" in files:
-            local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
-            print("✅ Model downloaded from HF")
-        else:
-            print("🚫 No pretrained model found in HF, will train locally")
-            raise FileNotFoundError
-
-        # 2️⃣ Load FlashPack model
-        model = GemmaTrainer().from_flashpack(local_model_path)
-        model.eval()
-
-        # 3️⃣ Load encoder
-        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
-
-        # 4️⃣ Try loading dataset + long embeddings from HF
+    if os.path.exists(local_model_path):
+        print(" Loading local model")
+    else:
         try:
-            dataset = load_dataset("rahul7star/prompt-enhancer-dataset", split="train").select(range(1000))
-            # Encode long embeddings
-            long_embeddings_list = []
-            for item in dataset:
-                long_embeddings_list.append(encode_fn(item["long_prompt"]))
-            long_embeddings = torch.vstack(long_embeddings_list)
+            files = list_repo_files(hf_repo)
+            if "model.flashpack" in files:
+                print("✅ Downloading model from HF")
+                local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
+            else:
+                print("🚫 No pretrained model found")
+                return None, None, None, None
         except Exception as e:
-            print(f"⚠️ Could not load dataset/embeddings from HF: {e}")
-            dataset = None
-            long_embeddings = None
+            print(f"⚠️ Error accessing HF: {e}")
+            return None, None, None, None

-    except Exception:
-        # If anything fails, train locally
-        print("⏬ Training a new FlashPack model locally...")
-        model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
-        push_flashpack_model_to_hf(model, hf_repo)
+    # ⚡ Use input_dim=1536 (default)
+    model = GemmaTrainer(input_dim=1536).from_flashpack(local_model_path)
+    model.eval()
+    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

-    # 5️⃣ Enhance function using embeddings to select best long prompt
     @torch.no_grad()
     def enhance_fn(prompt, chat):
         chat = chat or []
         short_emb = encode_fn(prompt).to(device)
         mapped = model(short_emb).cpu()
-
-        if dataset is not None and long_embeddings is not None:
-            # Cosine similarity
-            sims = (long_embeddings @ mapped.t()).squeeze(1)
-            sims = sims / (long_embeddings.norm(dim=1) * (mapped.norm() + 1e-12))
-            best_idx = int(sims.argmax().item())
-            enhanced_prompt = dataset[best_idx]["long_prompt"]
-        else:
-            enhanced_prompt = f"🌟 Enhanced prompt (embedding-based) for: {prompt}"
-
+        long_prompt = f"🌟 Enhanced prompt (embedding-based) for: {prompt}"
         chat.append({"role": "user", "content": prompt})
-        chat.append({"role": "assistant", "content": enhanced_prompt})
+        chat.append({"role": "assistant", "content": long_prompt})
         return chat

-    return model, tokenizer, embed_model, enhance_fn, dataset, long_embeddings
+    return model, tokenizer, embed_model, enhance_fn

 # ===========================
 # Gradio UI
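
After this commit, get_flashpack_model() no longer falls back to training a model: it returns (None, None, None, None) when model.flashpack is missing both locally and on the Hub, so callers must handle that failure themselves. Below is a minimal sketch of how the "Gradio UI" section (outside this hunk, so not shown in the diff) might consume the new return contract; the Blocks layout and component names are assumptions, not the app's actual code.

import gradio as gr

# Hypothetical caller-side wiring; the real "Gradio UI" section of
# app_flash1.py is outside this hunk, so the names below are illustrative.
model, tokenizer, embed_model, enhance_fn = get_flashpack_model("rahul7star/FlashPack")
if model is None:
    # New failure contract after this commit: four Nones instead of a
    # locally trained fallback model.
    raise SystemExit("No pretrained model.flashpack found locally or on HF")

with gr.Blocks() as demo:
    # enhance_fn returns a list of {"role", "content"} dicts,
    # which matches a Chatbot in "messages" mode.
    chatbot = gr.Chatbot(type="messages")
    prompt = gr.Textbox(label="Short prompt")
    # enhance_fn(prompt, chat) appends the user turn plus the enhanced
    # prompt and returns the updated history for the Chatbot to render.
    prompt.submit(enhance_fn, inputs=[prompt, chatbot], outputs=chatbot)

demo.launch()

Note that with the cosine-similarity lookup over long_embeddings removed along with the training fallback, the new enhance_fn always emits the placeholder f-string rather than retrieving a long prompt from the dataset.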