rahul7star committed on
Commit
a69d346
·
verified Β·
1 Parent(s): 4c22dd6

Update app_flash1.py

Browse files
Files changed (1) hide show
  1. app_flash1.py +43 -14
app_flash1.py CHANGED
@@ -63,7 +63,7 @@ def build_encoder(model_name="gpt2", max_length=128):
63
  def push_flashpack_model_to_hf(model, hf_repo):
64
  with tempfile.TemporaryDirectory() as tmp_dir:
65
  repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
66
- model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"),target_dtype=torch.float32)
67
  with open(os.path.join(tmp_dir, "README.md"), "w") as f:
68
  f.write("# FlashPack Model\nTrained locally and pushed to HF.")
69
  repo.push_to_hub()
@@ -75,7 +75,7 @@ def push_flashpack_model_to_hf(model, hf_repo):
75
  def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
76
  hf_repo="rahul7star/FlashPack",
77
  max_encode=1000):
78
- print("πŸ“¦ Loading dataset...")
79
  dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
80
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
81
 
@@ -108,10 +108,22 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
108
  break
109
 
110
  push_flashpack_model_to_hf(model, hf_repo)
111
- return model, tokenizer, embed_model
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  # ===========================
114
- # Load or Train
115
  # ===========================
116
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
117
  local_model_path = "model.flashpack"
@@ -127,11 +139,12 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
127
  print("βœ… Downloading model from HF")
128
  local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
129
  else:
130
- print("🚫 Model not found on HF β€” will train a new model")
131
- return train_flashpack_model(hf_repo=hf_repo)
 
132
  except Exception as e:
133
- print(f"⚠️ Error accessing HF: {e}. Training new model instead.")
134
- return train_flashpack_model(hf_repo=hf_repo)
135
 
136
  model = GemmaTrainer().from_flashpack(local_model_path)
137
  model.eval()
@@ -142,7 +155,6 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
142
  chat = chat or []
143
  short_emb = encode_fn(prompt)
144
  mapped = model(short_emb.to(device)).cpu()
145
- # Simply return a placeholder text for demonstration
146
  long_prompt = f"βœ… Enhanced long prompt for: {prompt}"
147
  chat.append({"role": "user", "content": prompt})
148
  chat.append({"role": "assistant", "content": long_prompt})
@@ -161,20 +173,37 @@ with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
161
  send_btn = gr.Button("πŸš€ Enhance Prompt", variant="primary")
162
  clear_btn = gr.Button("🧹 Clear")
163
  train_btn = gr.Button("🧩 Train Model", variant="secondary")
164
- status = gr.Markdown("Status: Ready")
165
 
166
- # Load or train model
 
 
167
  model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  send_btn.click(enhance_fn, [user_input, chatbot], chatbot)
170
  user_input.submit(enhance_fn, [user_input, chatbot], chatbot)
171
  clear_btn.click(lambda: [], None, chatbot)
172
 
173
  def retrain():
174
  global model, tokenizer, embed_model, enhance_fn
175
- model, tokenizer, embed_model = train_flashpack_model()
176
- enhance_fn = get_flashpack_model()[3]
177
- return "βœ… Model retrained and pushed to HF!"
 
178
 
179
  train_btn.click(retrain, None, status)
180
 
 
63
def push_flashpack_model_to_hf(model, hf_repo):
    """Serialize *model* in FlashPack format and push it to a Hugging Face repo.

    Parameters
    ----------
    model : object exposing ``save_flashpack(path, target_dtype=...)``
    hf_repo : str
        Target repository id, e.g. ``"rahul7star/FlashPack"``.

    Side effects: clones ``hf_repo`` into a temp dir, writes the weights and a
    README there, and pushes the commit. Requires an authenticated HF token.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # NOTE(review): huggingface_hub.Repository (and use_auth_token) are
        # deprecated upstream; HfApi.upload_folder is the modern replacement.
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        # Weights are stored as float32 per the explicit target_dtype.
        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
        # Fix: explicit encoding — the platform-dependent default can mangle
        # the non-ASCII README content on some hosts.
        with open(os.path.join(tmp_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
        repo.push_to_hub()
 
75
  def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
76
  hf_repo="rahul7star/FlashPack",
77
  max_encode=1000):
78
+ status = "πŸ“¦ Loading dataset..."
79
  dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
80
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
81
 
 
108
  break
109
 
110
  push_flashpack_model_to_hf(model, hf_repo)
111
+ tokenizer, embed_model, encode_fn = build_encoder("gpt2")
112
+
113
+ @torch.no_grad()
114
+ def enhance_fn(prompt, chat):
115
+ chat = chat or []
116
+ short_emb = encode_fn(prompt)
117
+ mapped = model(short_emb.to(device)).cpu()
118
+ long_prompt = f"βœ… Enhanced long prompt for: {prompt}"
119
+ chat.append({"role": "user", "content": prompt})
120
+ chat.append({"role": "assistant", "content": long_prompt})
121
+ return chat
122
+
123
+ return model, tokenizer, embed_model, enhance_fn
124
 
125
  # ===========================
126
+ # Lazy Load / Get Model
127
  # ===========================
128
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
129
  local_model_path = "model.flashpack"
 
139
  print("βœ… Downloading model from HF")
140
  local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
141
  else:
142
+ print("🚫 No pretrained model found")
143
+ # Return None to indicate missing model
144
+ return None, None, None, None
145
  except Exception as e:
146
+ print(f"⚠️ Error accessing HF: {e}")
147
+ return None, None, None, None
148
 
149
  model = GemmaTrainer().from_flashpack(local_model_path)
150
  model.eval()
 
155
  chat = chat or []
156
  short_emb = encode_fn(prompt)
157
  mapped = model(short_emb.to(device)).cpu()
 
158
  long_prompt = f"βœ… Enhanced long prompt for: {prompt}"
159
  chat.append({"role": "user", "content": prompt})
160
  chat.append({"role": "assistant", "content": long_prompt})
 
173
  send_btn = gr.Button("πŸš€ Enhance Prompt", variant="primary")
174
  clear_btn = gr.Button("🧹 Clear")
175
  train_btn = gr.Button("🧩 Train Model", variant="secondary")
176
+ status = gr.Markdown("Status: Loading model...")
177
 
178
+ # ===========================
179
+ # Lazy load model
180
+ # ===========================
181
  model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
182
 
183
+ if enhance_fn is None:
184
+ def enhance_fn(prompt, chat):
185
+ chat = chat or []
186
+ chat.append({"role": "assistant", "content":
187
+ "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
188
+ return chat
189
+
190
+ status.update("⚠️ No pretrained model found. Ready to train.")
191
+ else:
192
+ status.update("βœ… Model loaded β€” ready to enhance.")
193
+
194
+ # ===========================
195
+ # Button callbacks
196
+ # ===========================
197
  send_btn.click(enhance_fn, [user_input, chatbot], chatbot)
198
  user_input.submit(enhance_fn, [user_input, chatbot], chatbot)
199
  clear_btn.click(lambda: [], None, chatbot)
200
 
201
  def retrain():
202
  global model, tokenizer, embed_model, enhance_fn
203
+ status.update("πŸš€ Training model, please wait...")
204
+ model, tokenizer, embed_model, enhance_fn = train_flashpack_model()
205
+ status.update("βœ… Model retrained and pushed to HF!")
206
+ return "βœ… Model retrained and ready!"
207
 
208
  train_btn.click(retrain, None, status)
209