rahul7star committed
Commit cf0b63e · verified · 1 Parent(s): 67d08a5

Create app_gpu.py

Files changed (1)
app_gpu.py +307 -0
app_gpu.py ADDED
# universal_lora_trainer_accelerate_singlefile_dynamic.py
"""
Universal Dynamic LoRA Trainer (Accelerate + PEFT + Gradio)
- Gemma LLM default
- Robust batch handling (fixes KeyError: 0)
- Streams logs to Gradio (includes progress %)
- Supports CSV/Parquet datasets from the Hugging Face Hub or local folders
"""
import os
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import spaces
import torch
from accelerate import Accelerator
from huggingface_hub import hf_hub_download, create_repo, upload_folder
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset, DataLoader

# transformers optional
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except Exception:
    TRANSFORMERS_AVAILABLE = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

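# Rough dependency sketch (an assumption; the commit pins no versions):
#   pip install torch transformers peft accelerate gradio pandas pyarrow huggingface_hub spaces
# `spaces` is the Hugging Face ZeroGPU helper package; `pyarrow` backs pd.read_parquet.
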
# ---------------- Helpers ----------------
def is_hub_repo_like(s):
    # treat "user/repo"-style strings that don't exist locally as Hub repos
    return "/" in s and not Path(s).exists()

def download_from_hf(repo_id, filename, token=None):
    token = token or os.environ.get("HF_TOKEN")
    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)

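# Behavior sketch (hypothetical inputs):
#   is_hub_repo_like("rahul7star/prompt-enhancer-dataset-01") -> True (fetched via hf_hub_download)
#   is_hub_repo_like("./my_data")                             -> True if ./my_data does not exist,
# so a local dataset folder must already exist or it will be mistaken for a Hub repo.
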
# ---------------- Dataset ----------------
class MediaTextDataset(Dataset):
    def __init__(self, source, csv_name="dataset.csv", text_columns=None, max_records=None):
        self.is_hub = is_hub_repo_like(source)
        token = os.environ.get("HF_TOKEN")
        if self.is_hub:
            file_path = download_from_hf(source, csv_name, token)
        else:
            file_path = Path(source) / csv_name

        # fall back to Parquet if the CSV is missing
        if not Path(file_path).exists():
            alt = Path(str(file_path).replace(".csv", ".parquet"))
            if alt.exists():
                file_path = alt
            else:
                raise FileNotFoundError(f"Dataset file not found: {file_path}")

        self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
        if max_records:
            self.df = self.df.head(max_records)

        self.text_columns = text_columns or ["short_prompt", "long_prompt"]

        print(f"[DEBUG] Loaded dataset: {file_path}, columns: {list(self.df.columns)}")
        print(f"[DEBUG] Sample rows:\n{self.df.head(3)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        rec = self.df.iloc[i]
        out = {"text": {}}
        for col in self.text_columns:
            out["text"][col] = rec[col] if col in rec else ""
        return out

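# Item shape produced by __getitem__ (with the default columns):
#   {"text": {"short_prompt": "<short text>", "long_prompt": "<long text>"}}
# The default DataLoader collate wraps each string value in a list, e.g.
#   {"text": {"short_prompt": ["<short text>"], "long_prompt": ["<long text>"]}},
# which is why unwrap_batch below has to normalize several possible batch shapes.
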
# ---------------- Model loader ----------------
def load_pipeline_auto(base_model, dtype=torch.float16):
    if "gemma" in base_model.lower():
        if not TRANSFORMERS_AVAILABLE:
            raise RuntimeError("Transformers not installed for LLM support.")
        print(f"[INFO] Using Gemma LLM for {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(base_model)
        model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
        return {"model": model, "tokenizer": tokenizer}
    else:
        raise NotImplementedError("Only Gemma LLMs are supported in this script.")

def find_target_modules(model):
    candidates = ["q_proj", "k_proj", "v_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    names = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    # deduplicate leaf names so PEFT receives each target module once
    targets = sorted({n.split(".")[-1] for n in names if any(c in n for c in candidates)})
    if not targets:
        targets = sorted({n.split(".")[-1] for n in names})
        print("[WARNING] No standard attention modules found, using all Linear layers for LoRA.")
    else:
        print(f"[INFO] LoRA target modules detected: {targets[:40]}{'...' if len(targets) > 40 else ''}")
    return targets

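# For Gemma-family checkpoints the detected targets are typically the attention and MLP
# projections (q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj); exact
# names vary by checkpoint, hence the all-Linear fallback above.
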
# ---------------- Batch unwrapping ----------------
def unwrap_batch(batch, short_col, long_col):
    """Normalize a DataLoader batch into {"text": {short_col: ..., long_col: ...}}."""
    if isinstance(batch, (list, tuple)):
        ex = batch[0]
        if "text" in ex:
            return ex
        if "short" in ex and "long" in ex:
            return {"text": {short_col: ex.get("short", ""), long_col: ex.get("long", "")}}
        return {"text": ex}

    if isinstance(batch, dict):
        first_elem = {}
        is_batched = any(isinstance(v, (list, tuple, np.ndarray, torch.Tensor)) for v in batch.values())
        if is_batched:
            # peel off the first element of each batched field
            for k, v in batch.items():
                try:
                    first = v[0]
                except Exception:
                    first = v
                first_elem[k] = first
            if "text" in first_elem:
                t = first_elem["text"]
                if isinstance(t, (list, tuple)) and len(t) > 0:
                    return {"text": t[0] if isinstance(t[0], dict) else {short_col: t[0], long_col: ""}}
                if isinstance(t, dict):
                    return {"text": t}
                return {"text": {short_col: str(t), long_col: ""}}
            if ("short" in first_elem and "long" in first_elem) or (short_col in first_elem and long_col in first_elem):
                s = first_elem.get(short_col, first_elem.get("short", ""))
                l = first_elem.get(long_col, first_elem.get("long", ""))
                return {"text": {short_col: str(s), long_col: str(l)}}
            return {"text": {short_col: str(first_elem)}}
        if "text" in batch and isinstance(batch["text"], dict):
            # the default collate wraps each string in a list; take the first element
            inner = {k: (v[0] if isinstance(v, (list, tuple)) and len(v) > 0 else v)
                     for k, v in batch["text"].items()}
            return {"text": inner}
        s = batch.get(short_col, batch.get("short", ""))
        l = batch.get(long_col, batch.get("long", ""))
        return {"text": {short_col: str(s), long_col: str(l)}}
    return {"text": {short_col: str(batch), long_col: ""}}

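# Shapes this normalizes (sketch, with short_col="short_prompt", long_col="long_prompt"):
#   {"text": {"short_prompt": ["a"], "long_prompt": ["b"]}}  (default collate, batch_size=1)
#       -> {"text": {"short_prompt": "a", "long_prompt": "b"}}
#   {"short_prompt": ["a"], "long_prompt": ["b"]}
#       -> {"text": {"short_prompt": "a", "long_prompt": "b"}}
#   "raw string"
#       -> {"text": {"short_prompt": "raw string", "long_prompt": ""}}
# Only the first element of each batch is kept, so with batch_size > 1 every step still
# trains on a single example.
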
# ---------------- Training (forward + backward + logs) ----------------
@spaces.GPU(duration=120)
def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
                      epochs=1, lr=1e-4, r=8, alpha=16, batch_size=1, num_workers=0,
                      max_train_records=None, repo_id=None):
    """LoRA training loop with GPU + auto-upload support. Yields (log_text, progress) tuples."""

    # --- Device setup ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gpu_name = torch.cuda.get_device_name(0) if device == "cuda" else "CPU"
    print(f"[INFO] 🚀 Using device: {device.upper()} ({gpu_name})")

    # Adjust precision / batch size based on available VRAM
    if device == "cuda":
        vram = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        print(f"[INFO] VRAM: {vram:.2f} GB")
        dtype = torch.bfloat16 if "A100" in gpu_name or vram > 20 else torch.float16
        if vram < 10:
            batch_size = max(1, batch_size // 2)
            print(f"[WARN] Low VRAM, using batch_size={batch_size}")
    else:
        dtype = torch.float32

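    # Heuristic sketch: bf16 needs hardware support (roughly Ampere or newer), so gating
    # on the "A100" device name / VRAM size is an approximation; torch.cuda.is_bf16_supported()
    # would be the stricter check.
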
    # --- Model & tokenizer ---
    accelerator = Accelerator()
    pipe = load_pipeline_auto(base_model, dtype=dtype)
    model_obj = pipe["model"]
    tokenizer = pipe["tokenizer"]

    model_obj.train()
    target_modules = find_target_modules(model_obj)
    lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
    lora_module = get_peft_model(model_obj, lcfg)

    # --- Dataset ---
    dataset = MediaTextDataset(dataset_src, csv_name, text_columns=text_cols, max_records=max_train_records)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    optimizer = torch.optim.AdamW(lora_module.parameters(), lr=lr)
    lora_module, optimizer, loader = accelerator.prepare(lora_module, optimizer, loader)

    total_steps = max(1, epochs * len(loader))
    step_counter = 0
    logs = []

    yield f"[INFO] Starting LoRA training on {gpu_name}...\n", 0.0

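    # LoRA note: with r=8, lora_alpha=16 each targeted weight W is adapted as
    # W + (alpha / r) * B @ A, where A is (r x in_features) and B is (out_features x r);
    # only A and B receive gradients while the base weights stay frozen.
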
    # --- Training loop ---
    for ep in range(epochs):
        yield f"[DEBUG] Epoch {ep + 1}/{epochs}\n", step_counter / total_steps
        for i, batch in enumerate(loader):
            ex = unwrap_batch(batch, text_cols[0], text_cols[1])
            texts = ex.get("text", {})
            short_text = str(texts.get(text_cols[0], "") or "")
            long_text = str(texts.get(text_cols[1], "") or "")

            enc = tokenizer(
                short_text,
                text_pair=long_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=512,
            )
            enc = {k: v.to(accelerator.device) for k, v in enc.items()}
            enc["labels"] = enc["input_ids"].clone()
            # mask padding positions so they don't contribute to the loss
            enc["labels"][enc["attention_mask"] == 0] = -100

            outputs = lora_module(**enc)
            loss = getattr(outputs, "loss", None)
            if loss is None:
                logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
                loss = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    enc["labels"].view(-1),
                    ignore_index=-100,
                )

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()

            logs.append(f"[DEBUG] Step {step_counter}, Loss: {loss.item():.6f}")
            step_counter += 1
            yield "\n".join(logs[-10:]), step_counter / total_steps

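    # Objective note: tokenizer(short, text_pair=long) packs both prompts into a single
    # sequence whose labels mirror the inputs, i.e. plain language modeling over the pair;
    # an alternative sketch would mask the short-prompt tokens to -100 as well, so the
    # loss covers only the long prompt and targets the short-to-long rewrite directly.
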
    # --- Save LoRA adapter ---
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    lora_module.save_pretrained(output_dir)
    yield f"[INFO] ✅ LoRA saved to {output_dir}\n", 0.95

    # --- Auto-upload to the Hugging Face Hub ---
    repo_id = repo_id or os.environ.get("HF_UPLOAD_REPO")
    token = os.environ.get("HF_TOKEN")

    if repo_id and token:
        yield f"[INFO] Uploading adapter to {repo_id}...\n", 0.97
        try:
            create_repo(repo_id, repo_type="model", exist_ok=True, token=token)
            upload_folder(folder_path=output_dir, repo_id=repo_id, repo_type="model", token=token)
            yield f"[INFO] ✅ Uploaded successfully: https://huggingface.co/{repo_id}\n", 1.0
        except Exception as e:
            yield f"[ERROR] Upload failed: {e}\n", 1.0
    else:
        yield "[INFO] Skipping upload: repo_id or token not provided.\n", 1.0

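# Consuming the stream outside Gradio (sketch; arguments mirror the UI defaults):
#   for log_text, progress in train_lora_stream(
#           "google/gemma-3-4b-it", "rahul7star/prompt-enhancer-dataset-01",
#           "train-00000-of-00001.csv", ["short_prompt", "long_prompt"], "./adapter_out"):
#       print(f"{progress:.0%} {log_text}")
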
def upload_adapter(local, repo_id):
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN missing")
    create_repo(repo_id, repo_type="model", exist_ok=True, token=token)
    # upload_folder's arguments are keyword-only, so folder_path must be named
    upload_folder(folder_path=local, repo_id=repo_id, repo_type="model", token=token)
    return f"https://huggingface.co/{repo_id}"

# ---------------- Gradio UI ----------------
def run_ui():
    with gr.Blocks() as demo:
        gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer (Gemma LLM)")

        with gr.Row():
            base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
            dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
            csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
            short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
            long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
            out = gr.Textbox(label="Output dir", value="./adapter_out")
            repo = gr.Textbox(label="Upload HF repo (optional)", value="rahul7star/gemma-3-270m-ccebc0")

        with gr.Row():
            batch_size = gr.Number(value=1, label="Batch size")
            num_workers = gr.Number(value=0, label="DataLoader num_workers")
            r = gr.Number(value=8, label="LoRA rank")
            a = gr.Number(value=16, label="LoRA alpha")
            ep = gr.Number(value=1, label="Epochs")
            lr = gr.Number(value=1e-4, label="Learning rate")
            max_records = gr.Number(value=1000, label="Max training records")

        logs = gr.Textbox(label="Logs (streaming)", lines=25)

        def launch(bm, ds, csv, sc, lc, out_dir, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
            gen = train_lora_stream(
                bm, ds, csv, [sc, lc], out_dir,
                epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_),
                batch_size=int(batch), num_workers=int(num_w),
                max_train_records=int(max_rec),
            )
            for item in gen:
                # the trainer yields (log_text, progress); only the text is streamed to the UI
                text = item[0] if isinstance(item, tuple) else item
                yield text

            if repo_:
                link = upload_adapter(out_dir, repo_)
                yield f"[INFO] Uploaded to {link}\n"

        btn = gr.Button("🚀 Start Training")
        btn.click(fn=launch,
                  inputs=[base_model, dataset, csvname, short_col, long_col, out,
                          batch_size, num_workers, r, a, ep, lr, max_records, repo],
                  outputs=[logs],
                  queue=True)

    return demo

if __name__ == "__main__":
    run_ui().launch(server_name="0.0.0.0", server_port=7860, share=True)
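
# Launch note (sketch): on a Hugging Face Space the @spaces.GPU(duration=120) decorator
# requests a GPU for up to 120 s per call, which bounds each training run; run locally,
# the script falls back to CPU/float32 via the device checks above.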