from typing import Optional, Tuple, List import math import torch import torch.nn as nn import torch.nn.functional as F def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Applica Rotary Positional Embeddings (RoPE) a query e key. Args: q: Query tensor di shape (B, nh, T, hd) k: Key tensor di shape (B, nh, T, hd) cos: Cosine values di shape (1, 1, T, hd) sin: Sine values di shape (1, 1, T, hd) Returns: Tuple[torch.Tensor, torch.Tensor]: (q_rotated, k_rotated) """ # Dividi le dimensioni in metà per rotazione complessa # q, k: (B, nh, T, hd) -> split in (B, nh, T, hd/2) pairs hd = q.shape[-1] assert hd % 2 == 0, "head_dim deve essere pari per RoPE" q1, q2 = q[..., :hd//2], q[..., hd//2:] k1, k2 = k[..., :hd//2], k[..., hd//2:] # Applica rotazione: [cos*q1 - sin*q2, sin*q1 + cos*q2] cos_half = cos[..., :hd//2] sin_half = sin[..., :hd//2] q_rot = torch.cat([ q1 * cos_half - q2 * sin_half, q1 * sin_half + q2 * cos_half ], dim=-1) k_rot = torch.cat([ k1 * cos_half - k2 * sin_half, k1 * sin_half + k2 * cos_half ], dim=-1) return q_rot, k_rot def router_aux_loss(alpha: torch.Tensor) -> torch.Tensor: """ Entropia media della distribuzione alpha sui K rami. alpha: (B, T, K) Ritorna entropia normalizzata in [0, 1] circa. """ if alpha is None: return torch.tensor(0.0, device="cpu") eps = 1e-9 k = alpha.size(-1) ent = -(alpha * (alpha.clamp_min(eps)).log()).sum(dim=-1) # (B, T) norm_ent = ent / (torch.log(torch.tensor(float(k), device=alpha.device))) return norm_ent.mean() class DepthwiseCausalConv1d(nn.Module): """ Depthwise 1D causal convolution sulla dimensione di sequenza. Input: (B, T, H) -> output: (B, T, H) groups=H per avere un filtro per canale. """ def __init__(self, channels: int, kernel_size: int = 3): super().__init__() assert kernel_size >= 1 and kernel_size % 2 == 1, "kernel_size should be odd" self.kernel_size = kernel_size self.pad = kernel_size - 1 self.conv = nn.Conv1d( in_channels=channels, out_channels=channels, kernel_size=kernel_size, padding=0, groups=channels, ) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, T, H) -> (B, H, T) x_c = x.transpose(1, 2) # left-pad con zeri per causalità x_c = F.pad(x_c, (self.pad, 0)) y = self.conv(x_c) y = y.transpose(1, 2) return y class ChannelAttention(nn.Module): """ Attenzione per-canale (tipo SE) per token. """ def __init__(self, channels: int, reduction: int = 4): super().__init__() hidden = max(channels // reduction, 1) self.ln = nn.LayerNorm(channels) self.fc1 = nn.Linear(channels, hidden) self.fc2 = nn.Linear(hidden, channels) def forward(self, x: torch.Tensor) -> torch.Tensor: g = self.ln(x) g = F.gelu(self.fc1(g)) g = torch.sigmoid(self.fc2(g)) return x * g class Fp32LayerNorm(nn.Module): """ LayerNorm in float32 per stabilità numerica, castando avanti/indietro. I parametri rimangono in float32. """ def __init__(self, normalized_shape: int, eps: float = 1e-5): super().__init__() self.ln = nn.LayerNorm(normalized_shape, eps=eps) self.ln.to(dtype=torch.float32) def forward(self, x: torch.Tensor) -> torch.Tensor: orig_dtype = x.dtype # Disable autocast to prevent BF16/FP16 from being injected into LayerNorm if x.is_cuda: with torch.autocast(device_type="cuda", enabled=False): y = self.ln(x.to(torch.float32)) else: with torch.autocast(device_type="cpu", enabled=False): y = self.ln(x.to(torch.float32)) return y.to(orig_dtype) # --- Rami base per il PolymorphicMLP --- class SwigluMLP(nn.Module): def __init__(self, hidden_size: int, mlp_mult: float): super().__init__() mlp_dim = int(round(mlp_mult * hidden_size)) self.mlp_dim = mlp_dim self.up = nn.Linear(hidden_size, 2 * mlp_dim) self.down = nn.Linear(mlp_dim, hidden_size) def forward(self, x: torch.Tensor) -> torch.Tensor: up = self.up(x) a, b = up.split(self.mlp_dim, dim=-1) y = F.silu(a) * b return self.down(y) class GluMLP(nn.Module): def __init__(self, hidden_size: int, mlp_mult: float): super().__init__() mlp_dim = int(round(mlp_mult * hidden_size)) self.mlp_dim = mlp_dim self.up = nn.Linear(hidden_size, 2 * mlp_dim) self.down = nn.Linear(mlp_dim, hidden_size) def forward(self, x: torch.Tensor) -> torch.Tensor: up = self.up(x) a, b = up.split(self.mlp_dim, dim=-1) y = torch.sigmoid(a) * b return self.down(y) class DepthwiseConvBranch(nn.Module): def __init__(self, hidden_size: int, mlp_mult: float = 4.0): super().__init__() mlp_dim = int(round(mlp_mult * hidden_size)) self.dw = DepthwiseCausalConv1d(hidden_size, kernel_size=3) self.expand = nn.Linear(hidden_size, mlp_dim) self.act = nn.GELU() self.contract = nn.Linear(mlp_dim, hidden_size) def forward(self, x: torch.Tensor) -> torch.Tensor: y = self.dw(x) y = self.expand(y) y = self.act(y) return self.contract(y) class PolymorphicMLP(nn.Module): """ MLP polimorfico: - Router: produce alpha (B, T, K) - K rami base in una ModuleList (es. SwiGLU, GLU, depthwise-conv) - Output: somma pesata dei rami - Opzionale ChannelAttention - Espone: - last_alpha (B, T, K) per logging - last_aux (entropia normalizzata media) per aux-loss - force_func: se >= 0, forza un solo ramo (debug / training per ramo) """ def __init__( self, hidden_size: int, mlp_mult: float = 4.0, num_funcs: int = 3, router_dim: Optional[int] = None, dropout: float = 0.0, use_channel_attention: bool = False, router_tau: float = 1.0, ): super().__init__() assert num_funcs >= 1, "PolymorphicMLP richiede almeno 1 funzione di base" self.hidden_size = hidden_size self.mlp_mult = mlp_mult self.num_funcs = num_funcs # Router r_dim = router_dim or hidden_size self.router = nn.Sequential( nn.Linear(hidden_size, r_dim), nn.GELU(), nn.Linear(r_dim, num_funcs), ) self.router_tau = router_tau # Inizializza router con pesi piccoli per distribuzioni più uniformi for m in self.router.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, mean=0.0, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) # Rami base (primi 3: compatibili con la tua v1) funcs: List[nn.Module] = [] if num_funcs >= 1: funcs.append(SwigluMLP(hidden_size, mlp_mult)) if num_funcs >= 2: funcs.append(GluMLP(hidden_size, mlp_mult)) if num_funcs >= 3: funcs.append(DepthwiseConvBranch(hidden_size, mlp_mult)) # Se in futuro alzi num_funcs > 3, dovrai aggiungere nuovi rami qui # (es. un MLP più profondo, un branch più conv-heavy, ecc.) while len(funcs) < num_funcs: funcs.append(SwigluMLP(hidden_size, mlp_mult)) # fallback: extra-swiglu self.funcs = nn.ModuleList(funcs) self.dropout = nn.Dropout(dropout) self.use_channel_attention = use_channel_attention self.chan_attn = ChannelAttention(hidden_size) if use_channel_attention else None # Monitoring self.last_alpha: Optional[torch.Tensor] = None self.last_aux: Optional[torch.Tensor] = None # Forzatura di un singolo ramo (es. per debug / fasi speciali) self.force_func: int = -1 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # Router: (B, T, H) -> (B, T, K) logits = self.router(x) tau = float(self.router_tau) if self.router_tau is not None and self.router_tau > 0.0 else 1.0 alpha = F.softmax(logits / tau, dim=-1) # Forza un solo ramo se richiesto if self.force_func is not None and self.force_func >= 0 and self.force_func < self.num_funcs: one_hot = torch.zeros_like(alpha) one_hot[..., self.force_func] = 1.0 alpha = one_hot # Rami ys = [f(x) for f in self.funcs] # lista di (B, T, H) y_stack = torch.stack(ys, dim=2) # (B, T, K, H) alpha_exp = alpha.unsqueeze(-1) # (B, T, K, 1) y = (alpha_exp * y_stack).sum(dim=2) # (B, T, H) if self.use_channel_attention and self.chan_attn is not None: y = self.chan_attn(y) y = self.dropout(y) # Monitoring self.last_alpha = alpha.detach() self.last_aux = None if self.training: # Token-level entropy encourages mixing at each position token_ent = router_aux_loss(alpha) # Global entropy over mean usage encourages balanced branch usage overall p = alpha.mean(dim=(0, 1)) # (K,) ent = -(p * (p + 1e-9).log()).sum() k = alpha.size(-1) global_ent = ent / math.log(float(k)) # Combine both to stabilize early training and avoid collapse self.last_aux = 0.5 * token_ent + 0.5 * global_ent return y, alpha