Keeby-smilyai committed
Commit 4912b20 · verified · 1 Parent(s): acbcf43

Create app.py

Files changed (1):
app.py +483 -0
app.py ADDED
# -------------------------------
# app.py (CPU-COMPATIBLE VERSION)
#
# This file contains the backend logic and Gradio UI for the chatbot.
#
# Key design decisions:
# - Specifies target_modules in LoraConfig so PEFT works with the custom Sam2 model.
# - Uses a pure PyTorch fine-tuning loop for maximum control and stability.
# - Sam2Config inherits from PretrainedConfig, so it behaves like a standard
#   Hugging Face config (and avoids "not subscriptable" errors).
# - UI polling is backward-compatible with older Gradio versions.
# -------------------------------
import math
import os
import shutil
import threading
import time
import traceback
from datetime import datetime

import gradio as gr
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

# --- RLHF & Training Imports ---
from huggingface_hub import HfApi, login
from datasets import Dataset, load_dataset, concatenate_datasets
from peft import LoraConfig, get_peft_model

# -------------------------------
# 0) RLHF & TUNING CONFIGURATION
# -------------------------------
FEEDBACK_DATASET_REPO = "Smilyai-labs/Open-Sam-2.5-chat"
TUNED_MODEL_REPO_OWNER = "Smilyai-labs"
BASE_MODEL_REPO = "Smilyai-labs/Sam-2.5-PRO-SOLVER-V2"
FINETUNE_TRIGGER_LIKES = 2
MIN_LIKES_FOR_TRAINING = 2

# --- PyTorch Training Config ---
LEARNING_RATE = 2e-4
NUM_EPOCHS = 1
BATCH_SIZE = 1

# --- Login to Hugging Face Hub ---
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("WARNING: Hugging Face token not found. Feedback will not be saved and tuning will not run.")
else:
    login(token=HF_TOKEN)
    print("Hugging Face token found. Feedback logging and model tuning are enabled.")

# --- Global state ---
LIKE_COUNTER = 0
like_counter_lock = threading.Lock()
training_lock = threading.Lock()
model_lock = threading.Lock()
TRAINING_STATUS = ""

# -------------------------------
# 1) Local Sam-2 architecture
# -------------------------------
class Sam2Config(PretrainedConfig):
    model_type = "sam2"

    def __init__(
        self,
        vocab_size=32000,
        d_model=384,
        n_layers=6,
        n_heads=6,
        ff_mult=4.0,
        dropout=0.1,
        input_modality="text",
        head_type="causal_lm",
        version="0.1",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        self.input_modality = input_modality
        self.head_type = head_type
        self.version = version
        super().__init__(**kwargs)

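# Because Sam2Config subclasses PretrainedConfig, it inherits the standard config
# behaviors (to_dict/from_dict, save_pretrained, attribute access) for free.
# Illustrative sketch only, not executed here:
#
#   cfg_demo = Sam2Config(d_model=384, n_layers=6)
#   cfg_demo.save_pretrained("./sam2-config")            # writes config.json
#   cfg_back = Sam2Config.from_pretrained("./sam2-config")
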
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: weight * x / sqrt(mean(x^2) + eps)."""

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        return self.weight * x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()

class MHA(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Causal mask: position t may only attend to positions <= t.
        causal = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        # Padding mask: block attention to padded key positions.
        if attn_mask is not None:
            scores = scores.masked_fill(~attn_mask.unsqueeze(1).unsqueeze(2).bool(), float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(self.dropout(attn), v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # SwiGLU: silu(W1 x) gated elementwise by W2 x, then projected back down.
        return self.w3(self.dropout(F.silu(self.w1(x)) * self.w2(x)))

class Block(nn.Module):
    def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MHA(d_model, n_heads, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        self.ff = SwiGLU(d_model, int(ff_mult * d_model), dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        x = x + self.drop(self.attn(self.norm1(x), attn_mask=attn_mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x

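# Each Block is a pre-norm transformer layer: the residual stream is updated as
#   x <- x + Dropout(MHA(RMSNorm(x)))
#   x <- x + Dropout(SwiGLU(RMSNorm(x)))
# Normalizing before each sublayer (pre-norm) is generally more stable to train
# than the original post-norm arrangement, especially at small scale.
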
class Sam2(PreTrainedModel):
    # Inheriting from PreTrainedModel (with config_class set) lets this custom
    # architecture plug into Hugging Face tooling, including PEFT/LoRA.
    config_class = Sam2Config

    def __init__(self, config: Sam2Config):
        super().__init__(config)  # PreTrainedModel stores the config as self.config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList(
            [Block(config.d_model, config.n_heads, config.ff_mult, dropout=config.dropout)
             for _ in range(config.n_layers)]
        )
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying with the embedding

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def forward(self, input_ids=None, inputs_embeds=None, attention_mask=None, labels=None, **kwargs):
        if inputs_embeds is not None:
            x = inputs_embeds
        else:
            if input_ids is None:
                raise ValueError("You must provide either input_ids or inputs_embeds")
            x = self.embed(input_ids)

        for blk in self.blocks:
            x = blk(x, attn_mask=attention_mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        if loss is not None:
            return (loss, logits)
        return (logits,)

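# Sam2.forward returns plain tuples -- (loss, logits) when labels are provided,
# (logits,) otherwise -- rather than a transformers ModelOutput, so callers below
# index positionally. Illustrative shape check (not executed):
#
#   ids = torch.randint(0, cfg.vocab_size, (1, 8))
#   assert model(ids)[0].shape == (1, 8, cfg.vocab_size)
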
# -------------------------------
# 2) Load initial resources
# -------------------------------
weights_filename = "model.safetensors"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_REPO)
tokenizer.pad_token = tokenizer.eos_token

config_url = f"https://huggingface.co/{BASE_MODEL_REPO}/raw/main/config.json"
config_data = requests.get(config_url).json()
cfg = Sam2Config(**config_data)

weights_url = f"https://huggingface.co/{BASE_MODEL_REPO}/resolve/main/{weights_filename}"
weights_content = requests.get(weights_url).content
with open(weights_filename, "wb") as f:
    f.write(weights_content)

model = Sam2(cfg)
state_dict = load_file(weights_filename)
model.load_state_dict(state_dict)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
print(f"Inference will run on: {device}")

# Fall back to eos_token_id if the checkpoint has no dedicated <|eot|> token.
EOT_ID = tokenizer.convert_tokens_to_ids("<|eot|>") or tokenizer.eos_token_id
SPECIAL_TOKENS = {"bos": "<|bos|>", "eot": "<|eot|>", "user": "<|user|>", "assistant": "<|assistant|>", "system": "<|system|>"}
SYSTEM_PROMPT = "You are Sam-2, a friendly and concise chatbot. Always give short, direct answers and avoid medical or legal advice."

AutoModelForCausalLM.register(Sam2Config, Sam2)

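# Registering (Sam2Config, Sam2) with AutoModelForCausalLM means checkpoints can
# be loaded generically, assuming the repo's config.json declares model_type
# "sam2". Illustrative only, not executed here:
#
#   m = AutoModelForCausalLM.from_pretrained(BASE_MODEL_REPO)
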
# -------------------------------
# 3) Inference and Feedback Functions
# -------------------------------
def sample_next_token(logits, past_tokens, temperature=0.8, top_k=40, top_p=0.9,
                      repetition_penalty=1.1, max_repeat=5, no_repeat_ngram_size=3):
    if logits.dim() == 3:
        logits = logits[:, -1, :].clone()
    else:
        logits = logits.clone()
    batch_size, vocab_size = logits.size(0), logits.size(1)
    orig_logits = logits.clone()
    if temperature != 1.0:
        logits = logits / float(temperature)
    past_list = past_tokens.tolist() if isinstance(past_tokens, torch.Tensor) else list(past_tokens)
    # Repetition penalty: down-weight every token already generated.
    for token_id in set(past_list):
        if 0 <= token_id < vocab_size:
            logits[:, token_id] /= repetition_penalty
    # Hard-block a token that has already repeated max_repeat times in a row.
    if len(past_list) >= max_repeat:
        last_token, count = past_list[-1], 1
        for i in reversed(past_list[:-1]):
            if i == last_token:
                count += 1
            else:
                break
        if count >= max_repeat:
            logits[:, last_token] = -float("inf")
    # No-repeat n-gram: ban any token that would recreate an n-gram already seen
    # in the context (matching the standard Hugging Face semantics).
    if no_repeat_ngram_size > 0 and len(past_list) >= no_repeat_ngram_size:
        prefix = tuple(past_list[-(no_repeat_ngram_size - 1):])
        for j in range(len(past_list) - no_repeat_ngram_size + 1):
            if tuple(past_list[j:j + no_repeat_ngram_size - 1]) == prefix:
                logits[:, past_list[j + no_repeat_ngram_size - 1]] = -float("inf")
    # Top-k: keep only the k highest-scoring tokens.
    if top_k is not None and top_k > 0:
        tk = min(max(1, int(top_k)), vocab_size)
        topk_vals, _ = torch.topk(logits, tk, dim=-1)
        min_topk = topk_vals[:, -1].unsqueeze(-1)
        logits[logits < min_topk] = -float("inf")
    # Top-p (nucleus): keep the smallest set of tokens whose cumulative
    # probability mass exceeds p, always retaining the single most likely token.
    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        for b in range(batch_size):
            sorted_mask = cumulative_probs[b] > top_p
            if sorted_mask.numel() > 0:
                sorted_mask[0] = False
            tokens_to_remove = sorted_indices[b][sorted_mask]
            logits[b, tokens_to_remove] = -float("inf")
    # Fallbacks: if filtering removed everything, restore the unfiltered logits.
    for b in range(batch_size):
        if torch.isneginf(logits[b]).all():
            logits[b] = orig_logits[b]
    probs = F.softmax(logits, dim=-1)
    if torch.isnan(probs).any():
        probs = torch.ones_like(logits) / logits.size(1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token.to(device)

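# Standalone usage sketch for the sampler (illustrative only): given logits of
# shape (1, T, vocab) from a forward pass and the 1-D tensor of token ids
# generated so far, it returns a (1, 1) tensor holding the next token id:
#
#   logits = model(input_ids)[0]
#   next_id = sample_next_token(logits, input_ids[0], temperature=0.7)
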
def predict(message, history):
    chat_history = []
    for human, assistant in history:
        chat_history.append(f"{SPECIAL_TOKENS['user']} {human} {SPECIAL_TOKENS['eot']}")
        if assistant:
            chat_history.append(f"{SPECIAL_TOKENS['assistant']} {assistant} {SPECIAL_TOKENS['eot']}")
    chat_history.append(f"{SPECIAL_TOKENS['user']} {message} {SPECIAL_TOKENS['eot']}")
    prompt = f"{SPECIAL_TOKENS['system']} {SYSTEM_PROMPT} {SPECIAL_TOKENS['eot']}\n" + "\n".join(chat_history) + f"\n{SPECIAL_TOKENS['assistant']}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    generated_text = ""
    for _ in range(256):
        # Hold model_lock so generation never overlaps a hot-swap of the weights.
        with torch.no_grad(), model_lock:
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs[0]
        next_token = sample_next_token(logits, input_ids[0], temperature=0.4, top_k=50, top_p=0.9, repetition_penalty=1.1)
        token_id = int(next_token.squeeze().item())
        if token_id == EOT_ID:
            break
        token_str = tokenizer.decode([token_id], skip_special_tokens=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), device=device, dtype=attention_mask.dtype)], dim=1)
        generated_text += token_str
        yield generated_text

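# The prompt assembled by predict() follows this template (one turn shown):
#
#   <|system|> {SYSTEM_PROMPT} <|eot|>
#   <|user|> Hello! <|eot|>
#   <|assistant|>
#
# Generation then continues token by token until <|eot|> (EOT_ID) is sampled
# or the 256-token cap is reached.
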
def log_feedback(data: gr.LikeData, history: list):
    global LIKE_COUNTER
    if not HF_TOKEN:
        print("Feedback not logged. HF_TOKEN not set.")
        return
    feedback_entry = {
        "prompt": history[data.index[0]][0],
        "response": data.value,
        "feedback": 1 if data.liked else 0,
        "timestamp": datetime.utcnow().isoformat(),
    }
    new_feedback_dataset = Dataset.from_dict({k: [v] for k, v in feedback_entry.items()})
    try:
        existing_dataset = load_dataset(FEEDBACK_DATASET_REPO, split="train", cache_dir="./cache")
        combined_dataset = concatenate_datasets([existing_dataset, new_feedback_dataset])
    except Exception as e:
        print(f"Could not load existing dataset: {e}. Creating a new one.")
        combined_dataset = new_feedback_dataset
    try:
        combined_dataset.push_to_hub(FEEDBACK_DATASET_REPO, private=False)
        feedback_icon = "👍" if data.liked else "👎"
        print(f"Successfully logged {feedback_icon} feedback. Dataset now has {len(combined_dataset)} entries.")
        if data.liked:
            with like_counter_lock:
                LIKE_COUNTER += 1
                current_likes = LIKE_COUNTER
            print(f"Like recorded. Total likes since start: {current_likes}.")
            if current_likes > 0 and current_likes % FINETUNE_TRIGGER_LIKES == 0:
                print(f"--- Like threshold of {FINETUNE_TRIGGER_LIKES} reached! Triggering fine-tuning. ---")
                tuning_thread = threading.Thread(target=run_tuning_task, daemon=True)
                tuning_thread.start()
    except Exception as e:
        print(f"Error logging feedback to Hub: {e}")

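# Note on concurrency: log_feedback does a read-modify-write against the Hub
# dataset (load_dataset -> concatenate -> push_to_hub). If two users rate at
# nearly the same moment, the last push wins and the other entry can be lost.
# A queue or commit-per-entry scheme would be more robust; this simple approach
# is kept for clarity.
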
# -------------------------------
# 4) Background Fine-Tuning Logic (PyTorch Loop)
# -------------------------------
def run_tuning_task():
    global model, TRAINING_STATUS

    # Non-blocking acquire: if a tuning run is already active, skip this trigger.
    if not training_lock.acquire(blocking=False):
        print("Tuning is already in progress. Skipping this trigger.")
        return

    print("\n--- Starting PyTorch Fine-Tuning Task ---")
    try:
        TRAINING_STATUS = "🔧 Preparing to improve Sam-2.5..."

        if not HF_TOKEN:
            TRAINING_STATUS = "Error: HF_TOKEN not set. Cannot run tuning."
            time.sleep(10)
            return

        feedback_data = load_dataset(FEEDBACK_DATASET_REPO, split="train", cache_dir="./cache")
        liked_data = feedback_data.filter(lambda x: x["feedback"] == 1)
        print(f"Found {len(liked_data)} total liked responses for training.")

        if len(liked_data) < MIN_LIKES_FOR_TRAINING:
            TRAINING_STATUS = "✅ Improvement complete! (Not enough new data to train; will try again later.)"
            time.sleep(5)
            return

        def format_for_training(example):
            return {
                "text": f"{SPECIAL_TOKENS['system']} {SYSTEM_PROMPT} {SPECIAL_TOKENS['eot']}\n"
                        f"{SPECIAL_TOKENS['user']} {example['prompt']} {SPECIAL_TOKENS['eot']}\n"
                        f"{SPECIAL_TOKENS['assistant']} {example['response']} {SPECIAL_TOKENS['eot']}"
            }

        train_dataset = liked_data.map(format_for_training)

        print("Loading base model for tuning...")
        # Tune a fresh copy of the base weights so the live model stays untouched
        # until the hot-swap at the end.
        model_to_tune = Sam2(cfg)
        state_dict_to_tune = load_file(weights_filename)
        model_to_tune.load_state_dict(state_dict_to_tune)

        # Explicitly tell PEFT which linear layers in the MHA block to adapt;
        # automatic discovery does not know the custom Sam2 layer names.
        peft_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj"],
        )
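
        # target_modules=["q_proj", "v_proj"] matches the attribute names of the
        # nn.Linear layers inside MHA, so PEFT can locate them by name. Adapting
        # "k_proj" and "out_proj" as well would trade a few more trainable
        # parameters for potentially better adaptation (untested assumption for
        # this model).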

        peft_model = get_peft_model(model_to_tune, peft_config)
        peft_model.to(device)
        peft_model.print_trainable_parameters()

        tokenized_dataset = train_dataset.map(
            lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512),
            batched=True,
        )
        # Drop the raw 'text' column; only token tensors are needed from here on.
        tokenized_dataset = tokenized_dataset.remove_columns(["text"])
        tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
        train_dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE)

        optimizer = torch.optim.AdamW(peft_model.parameters(), lr=LEARNING_RATE)

        TRAINING_STATUS = (
            f"🔧 Sam-2.5 is starting training on {len(liked_data)} examples... "
            "Thank you all for contributing to the dataset. The model will train and hot-swap shortly. "
            "(This can be slow on CPU.)"
        )
        print("Starting model tuning on CPU...")
        peft_model.train()
        for epoch in range(NUM_EPOCHS):
            for i, batch in enumerate(train_dataloader):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                # Mask padded positions with -100 so CrossEntropyLoss ignores them
                # (pad == eos here, so training on padding would teach early stops).
                labels = input_ids.masked_fill(attention_mask == 0, -100)
                outputs = peft_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs[0]
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                current_loss = loss.item()
                print(f"Epoch {epoch+1}, Batch {i+1}/{len(train_dataloader)}, Loss: {current_loss:.4f}")
                # Surface the live loss in the UI status banner.
                TRAINING_STATUS = f"🔧 You are witnessing the training of Sam-2.5. Training... Batch {i+1}/{len(train_dataloader)}, Loss: {current_loss:.4f}"

        print("Tuning complete.")
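
        # merge_and_unload() folds each LoRA update back into the frozen base
        # weight, W <- W + (alpha / r) * B @ A, and returns a plain Sam2 model
        # with no PEFT wrappers, so inference code stays unchanged.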
        TRAINING_STATUS = "✨ Finishing up... Merging improvements."
        merged_model = peft_model.merge_and_unload()

        # Safe model swap: hold model_lock so no generation step sees half-loaded weights.
        with model_lock:
            print("Hot-swapping live model...")
            # Copy the merged state into the existing model object, preserving
            # the global reference used by predict().
            new_state_dict = merged_model.state_dict()
            model.load_state_dict(new_state_dict)
            model.to(device).eval()

        date_str = datetime.now().strftime("%Y%m%d-%H%M")
        new_repo_id = f"{TUNED_MODEL_REPO_OWNER}/Sam-2.5-PUBLIC-RLHF-{date_str}"

        print(f"Saving and uploading tuned model to {new_repo_id}...")

        # Save the merged model locally in Hugging Face format.
        local_dir = f"./{new_repo_id.split('/')[-1]}"
        os.makedirs(local_dir, exist_ok=True)
        merged_model.save_pretrained(local_dir, safe_serialization=False)
        tokenizer.save_pretrained(local_dir)

        # Push to the Hub.
        api = HfApi()
        api.create_repo(repo_id=new_repo_id, repo_type="model", exist_ok=True)
        api.upload_folder(
            folder_path=local_dir,
            repo_id=new_repo_id,
            repo_type="model",
        )

        # Clean up local files.
        shutil.rmtree(local_dir)

        print("Upload and hot-swap complete!")
        TRAINING_STATUS = "✅ Sam-2.5 has been successfully upgraded! Thank you: you have helped shape the newest generation of Sam-2.5 PRO SOLVER."
        time.sleep(5)

    except Exception as e:
        print(f"An error occurred during the tuning process: {e}")
        traceback.print_exc()
        TRAINING_STATUS = f"An error occurred during training: {e}"
        time.sleep(10)
    finally:
        TRAINING_STATUS = ""
        training_lock.release()
        print("--- PyTorch Fine-Tuning Task Finished ---")

# -------------------------------
# 5) UI Functions & Gradio Interface
# -------------------------------
def check_training_status():
    global TRAINING_STATUS
    if TRAINING_STATUS:
        return gr.update(value=TRAINING_STATUS, visible=True)
    else:
        return gr.update(value="", visible=False)

def poll_status_updater():
    # Infinite generator: streaming yields keep the status banner fresh even on
    # older Gradio versions that lack a polling argument on event listeners.
    while True:
        yield check_training_status()
        time.sleep(1)

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue")) as demo:
    gr.Markdown("""
    # Sam-2.5-PRO-SOLVER-V2 Chat
    A self-improving chatbot powered by Sam-2. Use the thumb icons to rate responses!
    The model automatically fine-tunes on your positive feedback and gets smarter live.
    """)

    training_status_md = gr.Markdown(value="", visible=False)
    chatbot = gr.Chatbot(label="Sam-2", bubble_full_width=False)
    chat_interface = gr.ChatInterface(fn=predict, chatbot=chatbot)
    chatbot.like(log_feedback, inputs=[chatbot], outputs=None)

    demo.load(poll_status_updater, None, training_status_md)

TRAINING_STATUS = "Not training yet. Waiting for more examples. Sam-2.5 is ready."

if __name__ == "__main__":
    print("Starting Gradio app. Tuning will be triggered by user feedback.")
    demo.launch(show_api=True)