rahul7star committed on
Commit
9490127
·
verified ·
1 Parent(s): a69d346

Update app_flash1.py

Browse files
Files changed (1) hide show
  1. app_flash1.py +34 -26
app_flash1.py CHANGED
@@ -60,14 +60,16 @@ def build_encoder(model_name="gpt2", max_length=128):
60
  # ===========================
61
  # Push model to HF
62
  # ===========================
63
- def push_flashpack_model_to_hf(model, hf_repo):
64
  with tempfile.TemporaryDirectory() as tmp_dir:
 
65
  repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
66
  model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
67
  with open(os.path.join(tmp_dir, "README.md"), "w") as f:
68
  f.write("# FlashPack Model\nTrained locally and pushed to HF.")
 
69
  repo.push_to_hub()
70
- print(f"βœ… Model pushed to {hf_repo}")
71
 
72
  # ===========================
73
  # Training
@@ -75,8 +77,16 @@ def push_flashpack_model_to_hf(model, hf_repo):
75
  def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
76
  hf_repo="rahul7star/FlashPack",
77
  max_encode=1000):
78
- status = "πŸ“¦ Loading dataset..."
 
 
 
 
 
 
79
  dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
 
 
80
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
81
 
82
  def encode_dataset(ds):
@@ -85,7 +95,7 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
85
  s_list.append(encode_fn(item["short_prompt"]))
86
  l_list.append(encode_fn(item["long_prompt"]))
87
  if (i + 1) % 50 == 0:
88
- print(f" β†’ Encoded {i + 1}/{len(ds)}")
89
  gc.collect()
90
  return torch.vstack(s_list), torch.vstack(l_list)
91
 
@@ -94,7 +104,7 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
94
  optimizer = optim.Adam(model.parameters(), lr=1e-3)
95
  loss_fn = nn.CosineSimilarity(dim=1)
96
 
97
- print("πŸš€ Training model...")
98
  for epoch in range(20):
99
  model.train()
100
  optimizer.zero_grad()
@@ -102,12 +112,12 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
102
  loss = 1 - loss_fn(preds, long_emb).mean()
103
  loss.backward()
104
  optimizer.step()
105
- print(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
106
  if loss.item() < 0.01:
107
- print("🎯 Early stopping.")
108
  break
109
 
110
- push_flashpack_model_to_hf(model, hf_repo)
111
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
112
 
113
  @torch.no_grad()
@@ -120,7 +130,12 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
120
  chat.append({"role": "assistant", "content": long_prompt})
121
  return chat
122
 
123
- return model, tokenizer, embed_model, enhance_fn
 
 
 
 
 
124
 
125
  # ===========================
126
  # Lazy Load / Get Model
@@ -128,11 +143,9 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
128
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
129
  local_model_path = "model.flashpack"
130
 
131
- # 1. Try local
132
  if os.path.exists(local_model_path):
133
  print("βœ… Loading local model")
134
  else:
135
- # 2. Try HF
136
  try:
137
  files = list_repo_files(hf_repo)
138
  if "model.flashpack" in files:
@@ -140,7 +153,6 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
140
  local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
141
  else:
142
  print("🚫 No pretrained model found")
143
- # Return None to indicate missing model
144
  return None, None, None, None
145
  except Exception as e:
146
  print(f"⚠️ Error accessing HF: {e}")
@@ -173,12 +185,11 @@ with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
173
  send_btn = gr.Button("πŸš€ Enhance Prompt", variant="primary")
174
  clear_btn = gr.Button("🧹 Clear")
175
  train_btn = gr.Button("🧩 Train Model", variant="secondary")
176
- status = gr.Markdown("Status: Loading model...")
177
 
178
- # ===========================
179
  # Lazy load model
180
- # ===========================
181
  model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
 
182
 
183
  if enhance_fn is None:
184
  def enhance_fn(prompt, chat):
@@ -186,26 +197,23 @@ with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
186
  chat.append({"role": "assistant", "content":
187
  "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
188
  return chat
189
-
190
- status.update("⚠️ No pretrained model found. Ready to train.")
191
  else:
192
- status.update("βœ… Model loaded β€” ready to enhance.")
193
 
194
- # ===========================
195
  # Button callbacks
196
- # ===========================
197
  send_btn.click(enhance_fn, [user_input, chatbot], chatbot)
198
  user_input.submit(enhance_fn, [user_input, chatbot], chatbot)
199
  clear_btn.click(lambda: [], None, chatbot)
200
 
201
  def retrain():
202
- global model, tokenizer, embed_model, enhance_fn
203
- status.update("πŸš€ Training model, please wait...")
204
- model, tokenizer, embed_model, enhance_fn = train_flashpack_model()
205
- status.update("βœ… Model retrained and pushed to HF!")
206
- return "βœ… Model retrained and ready!"
207
 
208
- train_btn.click(retrain, None, status)
209
 
210
  if __name__ == "__main__":
211
  demo.launch(show_error=True)
 
60
  # ===========================
61
  # Push model to HF
62
  # ===========================
63
def push_flashpack_model_to_hf(model, hf_repo, log_fn=print):
    """Save ``model`` as ``model.flashpack`` and push it to the Hub repo ``hf_repo``.

    The target repo is cloned into a temporary directory, the model file and
    a minimal ``README.md`` are written into the clone, and the clone is
    pushed back. The temporary directory is removed when the function exits.

    Args:
        model: Object exposing ``save_flashpack(path, target_dtype=...)``.
        hf_repo: Repo id handed to ``Repository(clone_from=...)``.
        log_fn: Callable taking a single message string, invoked for each
            progress step. Defaults to ``print`` so callers that still use
            the two-argument form ``(model, hf_repo)`` keep working.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        log_fn(f"📦 Preparing repository {hf_repo}...")
        # NOTE(review): `Repository` / `use_auth_token` are deprecated in
        # recent huggingface_hub releases in favour of the HTTP helpers
        # (`HfApi.upload_file` / `upload_folder`) — confirm against the
        # pinned huggingface_hub version before upgrading.
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
        # Explicit UTF-8 so the README bytes do not depend on the platform's
        # default text encoding.
        with open(os.path.join(tmp_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
        log_fn("⏳ Pushing model to Hugging Face...")
        repo.push_to_hub()
        log_fn(f"✅ Model pushed to {hf_repo}")
73
 
74
  # ===========================
75
  # Training
 
77
  def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
78
  hf_repo="rahul7star/FlashPack",
79
  max_encode=1000):
80
+ logs = []
81
+
82
+ def log_fn(msg):
83
+ logs.append(msg)
84
+ print(msg)
85
+
86
+ log_fn("πŸ“¦ Loading dataset...")
87
  dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
88
+ log_fn(f"βœ… Loaded {len(dataset)} samples")
89
+
90
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
91
 
92
  def encode_dataset(ds):
 
95
  s_list.append(encode_fn(item["short_prompt"]))
96
  l_list.append(encode_fn(item["long_prompt"]))
97
  if (i + 1) % 50 == 0:
98
+ log_fn(f" β†’ Encoded {i + 1}/{len(ds)}")
99
  gc.collect()
100
  return torch.vstack(s_list), torch.vstack(l_list)
101
 
 
104
  optimizer = optim.Adam(model.parameters(), lr=1e-3)
105
  loss_fn = nn.CosineSimilarity(dim=1)
106
 
107
+ log_fn("πŸš€ Training model...")
108
  for epoch in range(20):
109
  model.train()
110
  optimizer.zero_grad()
 
112
  loss = 1 - loss_fn(preds, long_emb).mean()
113
  loss.backward()
114
  optimizer.step()
115
+ log_fn(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
116
  if loss.item() < 0.01:
117
+ log_fn("🎯 Early stopping.")
118
  break
119
 
120
+ push_flashpack_model_to_hf(model, hf_repo, log_fn)
121
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
122
 
123
  @torch.no_grad()
 
130
  chat.append({"role": "assistant", "content": long_prompt})
131
  return chat
132
 
133
+ # Test model on sample prompt
134
+ test_prompt = "Hello world"
135
+ enhance_fn(test_prompt, [])
136
+ log_fn(f"βœ… Model test complete: '{test_prompt}' -> Enhanced prompt available")
137
+
138
+ return model, tokenizer, embed_model, enhance_fn, logs
139
 
140
  # ===========================
141
  # Lazy Load / Get Model
 
143
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
144
  local_model_path = "model.flashpack"
145
 
 
146
  if os.path.exists(local_model_path):
147
  print("βœ… Loading local model")
148
  else:
 
149
  try:
150
  files = list_repo_files(hf_repo)
151
  if "model.flashpack" in files:
 
153
  local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
154
  else:
155
  print("🚫 No pretrained model found")
 
156
  return None, None, None, None
157
  except Exception as e:
158
  print(f"⚠️ Error accessing HF: {e}")
 
185
  send_btn = gr.Button("πŸš€ Enhance Prompt", variant="primary")
186
  clear_btn = gr.Button("🧹 Clear")
187
  train_btn = gr.Button("🧩 Train Model", variant="secondary")
188
+ log_output = gr.Textbox(label="Logs", lines=15)
189
 
 
190
  # Lazy load model
 
191
  model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
192
+ logs = []
193
 
194
  if enhance_fn is None:
195
  def enhance_fn(prompt, chat):
 
197
  chat.append({"role": "assistant", "content":
198
  "⚠️ No pretrained model found. Please click 'Train Model' to create one."})
199
  return chat
200
+ logs.append("⚠️ No pretrained model found. Ready to train.")
 
201
  else:
202
+ logs.append("βœ… Model loaded β€” ready to enhance.")
203
 
 
204
  # Button callbacks
 
205
  send_btn.click(enhance_fn, [user_input, chatbot], chatbot)
206
  user_input.submit(enhance_fn, [user_input, chatbot], chatbot)
207
  clear_btn.click(lambda: [], None, chatbot)
208
 
209
  def retrain():
210
+ global model, tokenizer, embed_model, enhance_fn, logs
211
+ logs = ["πŸš€ Training model, please wait..."]
212
+ model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model()
213
+ logs.extend(train_logs)
214
+ return gr.Textbox.update(value="\n".join(logs))
215
 
216
+ train_btn.click(retrain, None, log_output)
217
 
218
  if __name__ == "__main__":
219
  demo.launch(show_error=True)