#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Pretrain Veronica-Polymorphic from scratch (clean mixture: FinePDFs / DCLM / FineWeb-Edu).

Basic example:

python veronica-polymorphic/scripts/train_veronica.py \
    --config veronica-polymorphic/configs/veronica-pretrain-12L.json \
    --dataset_paths data/mix_optimal_50_30_20_2048 \
    --output_dir veronica-polymorphic/runs/veronica-pretrain-vMix-2048 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --learning_rate 2e-4 \
    --label_smoothing 0.01 \
    --rep_alpha 0.0 \
    --max_steps 60000 \
    --max_seq_len 2048

You can use different datasets (e.g., 512 / 1024 / 2048) in separate runs for a length curriculum.
"""

import os
import re
import glob
import json
import math
import argparse
import random
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from datasets import load_from_disk
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    TrainerCallback,
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    LogitsProcessorList,
    NoRepeatNGramLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
)

# --- Veronica bindings ---
from veronica.configuration_veronica import VeronicaConfig
from veronica.modeling_veronica import VeronicaForCausalLM
from veronica.modeling_components import Fp32LayerNorm

CONFIG_MAPPING.register("veronica", VeronicaConfig)
MODEL_FOR_CAUSAL_LM_MAPPING.register(VeronicaConfig, VeronicaForCausalLM)

# Disable CUDA Graphs (HF Trainer + torch.compile may sometimes conflict)
os.environ.setdefault("TORCH_COMPILE_USE_CUDAGRAPHS", "0")
os.environ.setdefault("TORCHINDUCTOR_DISABLE_CUDAGRAPHS", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")


# ===========================
# Utility
# ===========================
def find_latest_checkpoint(run_dir: str) -> Optional[str]:
    ckpts = glob.glob(os.path.join(run_dir, "checkpoint-*"))
    if not ckpts:
        return None
    ckpts.sort(key=lambda p: int(re.search(r"checkpoint-(\d+)", p).group(1)))
    return ckpts[-1]


def build_tokenizer(candidates: List[str], save_dir: str) -> AutoTokenizer:
    """
    Try to load an existing tokenizer from the provided paths; otherwise
    fall back to gpt2 and add basic special tokens.
    """
    tok = None
    for p in candidates:
        if os.path.exists(p):
            try:
                tok = AutoTokenizer.from_pretrained(p, use_fast=True)
                print(f"[tokenizer] loaded from {p}")
                break
            except Exception:
                pass
    if tok is None:
        print("[tokenizer] fallback: gpt2")
        tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

    specials: Dict[str, str] = {}
    if tok.eos_token is None:
        specials["eos_token"] = "<|eos|>"
    if tok.pad_token is None:
        specials["pad_token"] = "<|pad|>"
    if tok.bos_token is None:
        specials["bos_token"] = "<|bos|>"
    if specials:
        tok.add_special_tokens(specials)

    tok.save_pretrained(save_dir)
    tok = AutoTokenizer.from_pretrained(save_dir, use_fast=True)

    base_vocab = tok.vocab_size
    effective_vocab = len(tok)
    print(
        f"[tokenizer] base_vocab={base_vocab} added={effective_vocab - base_vocab} "
        f"effective_vocab={effective_vocab} eos={tok.eos_token_id} "
        f"pad={tok.pad_token_id} bos={tok.bos_token_id}"
    )
    return tok


def load_cfg_with_vocab(cfg_path: str, tok: AutoTokenizer) -> VeronicaConfig:
    """
    Load the config and adapt it to the tokenizer vocabulary.
    The model is designed as UN-TIED (lm_head != wte).
""" with open(cfg_path, "r", encoding="utf-8") as f: d = json.load(f) cfg = VeronicaConfig(**d) cfg.model_type = "veronica" cfg.vocab_size = int(len(tok)) # untied model: no tie_word_embeddings return cfg def init_model_from_config(cfg: VeronicaConfig, tok: AutoTokenizer) -> VeronicaForCausalLM: model = VeronicaForCausalLM(cfg) use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported() dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(dtype=dtype, device=device) effective_vocab = len(tok) emb = model.get_input_embeddings().weight head = model.lm_head.weight # Align embedding/head to the effective vocab if emb.shape[0] != effective_vocab or head.shape[0] != effective_vocab: old_vocab = emb.shape[0] print(f"[model] resize_token_embeddings: {old_vocab} -> {effective_vocab}") model.resize_token_embeddings(effective_vocab) with torch.no_grad(): new_emb = model.get_input_embeddings().weight new_head = model.lm_head.weight mean_emb = new_emb[:old_vocab].mean(dim=0, keepdim=True) mean_head = new_head[:old_vocab].mean(dim=0, keepdim=True) if effective_vocab > old_vocab: new_emb[old_vocab:] = mean_emb new_head[old_vocab:] = mean_head # Keep LayerNorm params in float32 (after global cast) for m in model.modules(): if isinstance(m, Fp32LayerNorm): m.ln.to(dtype=torch.float32) model.config.use_cache = False n_params = sum(p.numel() for p in model.parameters()) print(f"[model] params={n_params:,} vocab={effective_vocab}") return model def load_mix_dataset(path: str): """ Load a packed dataset (train/validation) from disk. Expected HuggingFace formats: a DatasetDict with 'train' and 'validation', or a single Dataset that gets split 99/1. """ ds = load_from_disk(path) if isinstance(ds, dict) and "train" in ds and "validation" in ds: return ds["train"], ds["validation"] split = ds.train_test_split(test_size=0.01, seed=42) return split["train"], split["test"] # =========================== # Collator # =========================== @dataclass class CausalCollator: tokenizer: AutoTokenizer mask_runs: bool = False run_len: int = 4 max_seq_len: Optional[int] = None # target length (e.g., 512/1024/2048) def _mask_degenerate_runs(self, labels: torch.Tensor): """ Mask degenerate runs (e.g., '____', '....') with length >= run_len. Mostly legacy; can be left off with a clean dataset. """ try: id_us = self.tokenizer.encode("_", add_special_tokens=False)[0] id_dot = self.tokenizer.encode(".", add_special_tokens=False)[0] except Exception: return B, T = labels.size() for b in range(B): cnt_u = cnt_d = 0 for t in range(T): tok = int(labels[b, t].item()) if tok == id_us: cnt_u += 1 cnt_d = 0 elif tok == id_dot: cnt_d += 1 cnt_u = 0 else: cnt_u = cnt_d = 0 if cnt_u >= self.run_len or cnt_d >= self.run_len: labels[b, t] = -100 def _crop(self, ids: torch.Tensor) -> torch.Tensor: """ If max_seq_len is set and the sequence is longer, crop a random window of length max_seq_len. 
""" if self.max_seq_len is None: return ids L = ids.size(0) if L <= self.max_seq_len: return ids start = random.randint(0, L - self.max_seq_len) end = start + self.max_seq_len return ids[start:end] def __call__(self, features): ids_list = [] for f in features: ids = torch.tensor(f["input_ids"], dtype=torch.long) ids = self._crop(ids) ids_list.append(ids) pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id ids = torch.nn.utils.rnn.pad_sequence(ids_list, batch_first=True, padding_value=pad_id) attn = torch.where(ids == pad_id, 0, 1) labels = ids.clone() labels[labels == pad_id] = -100 if self.mask_runs: self._mask_degenerate_runs(labels) return {"input_ids": ids, "attention_mask": attn, "labels": labels} # =========================== # Callback Router + Smoke eval # =========================== SMOKE_PROMPTS = [ "The world we live in today is", "Understanding complex ideas requires", "Human intelligence differs from artificial intelligence because", "A good system design is based on", "In the middle of every difficulty lies", "Once upon a time, there was a scientist who", ] class RouterAndSmokeCallback(TrainerCallback): def __init__(self, tok: AutoTokenizer): self.tok = tok def on_log(self, args, state, control, **kwargs): model = kwargs.get("model", None) if model is None: return try: if hasattr(model, "router_alpha_mean") and model.router_alpha_mean is not None: alpha = model.router_alpha_mean.detach().float().cpu() p = alpha / alpha.sum() ent = -(p * (p.clamp_min(1e-9)).log()).sum() ent_norm = float(ent / math.log(len(p))) print(f"[router] alpha={alpha.tolist()} entropy_norm={ent_norm:.4f}") except Exception: pass def on_evaluate(self, args, state, control, **kwargs): model = kwargs.get("model", None) if model is None: return model.eval() dev = next(model.parameters()).device prompt = random.choice(SMOKE_PROMPTS) ids = self.tok(prompt, return_tensors="pt").to(dev) processors = LogitsProcessorList([ NoRepeatNGramLogitsProcessor(3), RepetitionPenaltyLogitsProcessor(1.1), ]) with torch.no_grad(): out = model.generate( **ids, max_new_tokens=64, do_sample=False, logits_processor=processors, eos_token_id=self.tok.eos_token_id, pad_token_id=(self.tok.pad_token_id or self.tok.eos_token_id), use_cache=True, ) txt = self.tok.decode(out[0], skip_special_tokens=True) completion = txt[len(prompt):].strip() if txt.startswith(prompt) else txt print(f"\n[SMOKE] {prompt} → {completion}\n") model.train() # =========================== # Callback schedule router_tau / aux_weight # =========================== class RouterScheduleCallback(TrainerCallback): """ Linearly schedule router_tau and router_aux_weight between start and end of training. 
""" def __init__( self, tau_start: float, tau_end: float, aux_start: float, aux_end: float, total_steps: int, tau_freeze_steps: int = 0, force_prob: float = 0.0, force_warmup_steps: int = 0, ): self.tau_start = float(tau_start) self.tau_end = float(tau_end) self.aux_start = float(aux_start) self.aux_end = float(aux_end) self.total_steps = max(int(total_steps), 1) self.tau_freeze_steps = max(int(tau_freeze_steps), 0) self.force_prob = float(force_prob) self.force_warmup_steps = max(int(force_warmup_steps), 0) def _interp(self, start: float, end: float, step: int, span: int) -> float: t = min(max(step, 0), span) alpha = t / float(max(span, 1)) return (1.0 - alpha) * start + alpha * end def on_step_begin(self, args, state, control, **kwargs): model = kwargs.get("model", None) if model is None: return step = state.global_step # Tau: keep frozen for tau_freeze_steps, then interpolate over the remaining span if step < self.tau_freeze_steps: new_tau = self.tau_start else: rem_step = step - self.tau_freeze_steps rem_span = max(self.total_steps - self.tau_freeze_steps, 1) new_tau = self._interp(self.tau_start, self.tau_end, rem_step, rem_span) # Aux always interpolates across total training steps new_aux = self._interp(self.aux_start, self.aux_end, step, self.total_steps) # update global config if hasattr(model, "config"): model.config.router_tau = new_tau model.config.router_aux_weight = new_aux # update all block.mlp (PolymorphicMLP must use router_tau in forward) for block in getattr(model, "blocks", []): if hasattr(block, "mlp"): # default: no forcing unless scheduled below block.mlp.router_tau = new_tau block.mlp.force_func = -1 # During early warmup, occasionally force a single branch so all get gradients if step < self.force_warmup_steps and self.force_prob > 0.0: if random.random() < self.force_prob: for block in getattr(model, "blocks", []): if hasattr(block, "mlp") and hasattr(block.mlp, "num_funcs"): k = block.mlp.num_funcs block.mlp.force_func = random.randint(0, max(k - 1, 0)) if step % 1000 == 0: print( f"[router-sched] step={step} tau={new_tau:.4f} aux_w={new_aux:.5f} " f"freeze<= {self.tau_freeze_steps} force_p={self.force_prob:.3f} warmup<= {self.force_warmup_steps}" ) # =========================== # Custom Trainer with rep_loss # =========================== class VeronicaTrainer(Trainer): def __init__(self, *args, label_smoothing: float = 0.0, rep_alpha: float = 0.0, **kwargs): super().__init__(*args, **kwargs) self.label_smoothing = float(label_smoothing) self.rep_alpha = float(rep_alpha) def compute_loss(self, model, inputs, return_outputs=False, **kwargs): labels = inputs.get("labels") if labels is None: raise ValueError("compute_loss called without labels") model_inputs = {k: v for k, v in inputs.items() if k != "labels"} outputs = model(**model_inputs) logits = outputs.logits # [B, T, V] ignore_index = -100 # SHIFT: predict x_{t+1} shift_logits = logits[:, :-1, :].contiguous() shift_labels = labels[:, 1:].contiguous() valid_mask = (shift_labels != ignore_index) safe_labels = shift_labels.clone() safe_labels[~valid_mask] = 0 log_probs = F.log_softmax(shift_logits, dim=-1) # [B, T-1, V] nll_full = -log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1) nll_loss = nll_full[valid_mask].mean() if self.label_smoothing > 0.0: smooth_full = -log_probs.mean(dim=-1) smooth_loss = smooth_full[valid_mask].mean() ce_loss = (1.0 - self.label_smoothing) * nll_loss + self.label_smoothing * smooth_loss else: ce_loss = nll_loss total_loss = ce_loss # rep_loss on x_{t+1} when x_{t+1} == 

        # rep_loss: penalize probability assigned to the target at x_{t+1} when x_{t+1} == x_t
        if self.rep_alpha > 0.0:
            labels_prev = labels[:, :-1]   # x_t
            labels_next = shift_labels     # x_{t+1}
            valid_prev = (labels_prev != ignore_index)
            same_mask = valid_prev & valid_mask & (labels_prev == labels_next)
            if same_mask.any():
                rep_logp = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
                rep_p = rep_logp[same_mask].exp()
                total_loss = total_loss + self.rep_alpha * rep_p.mean()

        # Router aux_loss: SUBTRACT it to MAXIMIZE entropy (prevents collapse)
        aux_loss = getattr(model, "_last_router_aux", None)
        if aux_loss is not None and hasattr(model, "config"):
            aux_w = float(getattr(model.config, "router_aux_weight", 0.0))
            if aux_w > 0:
                if not torch.is_tensor(aux_loss):
                    aux_loss = torch.as_tensor(aux_loss, device=logits.device, dtype=logits.dtype)
                # Subtract aux (entropy) so that minimizing the loss => maximizing entropy => soft router
                total_loss = total_loss - aux_w * aux_loss.clamp_min(0.0)

        return (total_loss, outputs) if return_outputs else total_loss


# ===========================
# Main
# ===========================
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--dataset_paths", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--tokenizer_candidates",
        type=str,
        nargs="*",
        default=["veronica-polymorphic/tokenizer", "gpt2"],
    )
    parser.add_argument(
        "--tokenizer_out",
        type=str,
        default="veronica-polymorphic/tokenizer_vmix",
    )
    parser.add_argument("--per_device_train_batch_size", type=int, default=4)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--max_steps", type=int, default=60000)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--warmup_ratio", type=float, default=0.02)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--eval_steps", type=int, default=1000)
    parser.add_argument("--save_steps", type=int, default=1000)
    parser.add_argument("--logging_steps", type=int, default=100)
    parser.add_argument("--label_smoothing", type=float, default=0.01)
    parser.add_argument("--rep_alpha", type=float, default=0.0)
    parser.add_argument("--mask_degenerate_runs", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Checkpoint to resume from (e.g., .../checkpoint-22000)",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=None,
        help="Maximum window length (e.g., 512, 1024, 2048). If None, uses the full dataset sequence.",
    )

    # Router schedule
    parser.add_argument("--router_tau_start", type=float, default=1.6)
    parser.add_argument("--router_tau_end", type=float, default=1.1)
    parser.add_argument("--router_aux_start", type=float, default=0.005)
    parser.add_argument("--router_aux_end", type=float, default=0.012)
    parser.add_argument(
        "--router_tau_freeze_steps",
        type=int,
        default=4000,
        help="Keep tau constant for the first N steps to avoid early specialization.",
    )
    parser.add_argument(
        "--router_force_prob",
        type=float,
        default=0.05,
        help="Per-step probability to force a single branch during warmup.",
    )
    parser.add_argument(
        "--router_force_warmup_steps",
        type=int,
        default=3000,
        help="Apply random branch forcing only within these initial steps.",
    )
    args = parser.parse_args()

    # Tokenizer
    tok = build_tokenizer(args.tokenizer_candidates, args.tokenizer_out)

    # Config & Model
    cfg = load_cfg_with_vocab(args.config, tok)
    cfg.router_tau = args.router_tau_start
    cfg.router_aux_weight = args.router_aux_start
    model = init_model_from_config(cfg, tok)

    # Diagnostics: verify the model's forward loss against a manual shifted CE
    model.eval()
    with torch.no_grad():
        dummy = torch.randint(0, model.config.vocab_size, (1, 32), device=model.device)
        out = model(input_ids=dummy, labels=dummy)
        loss_model = out.loss.item()
        logits = out.logits  # [1, 32, V]
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = dummy[:, 1:].contiguous()
        loss_manual = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        ).item()
    print(f"[diag] loss_model_forward={loss_model:.4f} loss_manual_shift={loss_manual:.4f}")
    model.train()

    # Dataset
    train_ds, val_ds = load_mix_dataset(args.dataset_paths)
    collator = CausalCollator(
        tokenizer=tok,
        mask_runs=args.mask_degenerate_runs,
        max_seq_len=args.max_seq_len,
    )

    # Resume
    resume_ckpt = args.resume_from or find_latest_checkpoint(args.output_dir)
    if resume_ckpt:
        print(f"🟢 Resuming from: {resume_ckpt}")
    else:
        print("⚪ No checkpoint: training from scratch.")

    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    train_args = TrainingArguments(
        output_dir=args.output_dir,
        run_name=os.path.basename(args.output_dir.rstrip("/")),
        num_train_epochs=1_000,  # effectively driven by max_steps
        max_steps=args.max_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        lr_scheduler_type="cosine",
        logging_steps=args.logging_steps,
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        eval_strategy="steps",
        save_total_limit=5,
        bf16=use_bf16,
        fp16=(torch.cuda.is_available() and not use_bf16),
        gradient_checkpointing=True,
        report_to=["tensorboard"],
        dataloader_num_workers=2,
        seed=args.seed,
        label_smoothing_factor=0.0,  # smoothing is handled in the custom compute_loss
        max_grad_norm=1.0,
        save_safetensors=False,
    )

    callbacks: List[TrainerCallback] = [
        RouterAndSmokeCallback(tok),
        RouterScheduleCallback(
            tau_start=args.router_tau_start,
            tau_end=args.router_tau_end,
            aux_start=args.router_aux_start,
            aux_end=args.router_aux_end,
            total_steps=args.max_steps,
            tau_freeze_steps=args.router_tau_freeze_steps,
            force_prob=args.router_force_prob,
            force_warmup_steps=args.router_force_warmup_steps,
        ),
    ]
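
    # Quick visibility into the effective batch size implied by the arguments above
    # (assumes a single process; on multi-GPU the effective batch also scales with
    # the world size). Purely informational, added for illustration.
    effective_batch = (
        train_args.per_device_train_batch_size * train_args.gradient_accumulation_steps
    )
    print(f"[diag] effective train batch per optimizer step (per process): {effective_batch}")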

    trainer = VeronicaTrainer(
        model=model,
        args=train_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tok,  # ✅ instead of processing_class
        data_collator=collator,
        callbacks=callbacks,
        label_smoothing=args.label_smoothing,
        rep_alpha=args.rep_alpha,
    )

    # Sanity check: vocab/emb/head
    effective_vocab = len(tok)
    emb = model.get_input_embeddings().weight
    head = model.lm_head.weight
    assert emb.shape[0] == effective_vocab == head.shape[0], "Mismatch vocab/emb/lm_head"

    # Train
    trainer.train(resume_from_checkpoint=resume_ckpt)
    trainer.save_state()
    trainer.save_model(args.output_dir)
    tok.save_pretrained(args.output_dir)

    with open(os.path.join(args.output_dir, "config.final.json"), "w", encoding="utf-8") as f:
        json.dump(model.config.to_dict(), f, indent=2)

    print("✅ Pretraining completed/saved.")


if __name__ == "__main__":
    main()