Veronica / veronica /modeling_components.py
MhaWay's picture
HF alignment
342a5c0
from typing import Optional, Tuple, List
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Positional Embeddings (RoPE) to queries and keys.

    Channel i is paired with channel i + hd/2 (the "rotate-half" layout), and
    each pair (x1, x2) is rotated to (x1*cos - x2*sin, x1*sin + x2*cos).

    Args:
        q: Query tensor of shape (B, nh, T, hd); hd must be even.
        k: Key tensor of shape (B, nh, T, hd).
        cos: Cosine values of shape (1, 1, T, hd).
        sin: Sine values of shape (1, 1, T, hd).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (q_rotated, k_rotated), same shapes as inputs.

    Note:
        Only the first hd/2 entries of ``cos``/``sin`` are read. Presumably they
        are built as cat(freqs, freqs) so both halves match — TODO confirm with
        the code that builds cos/sin.
    """
    head_dim = q.shape[-1]
    assert head_dim % 2 == 0, "head_dim deve essere pari per RoPE"
    half = head_dim // 2
    c = cos[..., :half]
    s = sin[..., :half]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Split on the q-derived half so q and k are treated identically.
        lo, hi = t[..., :half], t[..., half:]
        return torch.cat((lo * c - hi * s, lo * s + hi * c), dim=-1)

    return _rotate(q), _rotate(k)
def router_aux_loss(alpha: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Mean normalized entropy of the routing distribution ``alpha`` over K branches.

    Args:
        alpha: Routing weights of shape (B, T, K) whose last dim is a probability
            distribution (e.g. a softmax output), or None.

    Returns:
        0-dim tensor: per-token entropy divided by log(K), averaged over all
        tokens. For probability inputs this lies in [0, 1] (1 = uniform,
        0 = one-hot). Returns a CPU zero tensor when ``alpha`` is None. Returns
        a zero tensor on alpha's device/dtype when K <= 1, where the
        log(K) normalizer would be 0 and produce NaN.
    """
    if alpha is None:
        return torch.tensor(0.0, device="cpu")
    k = alpha.size(-1)
    if k <= 1:
        # A single branch has zero entropy and log(1) == 0: avoid 0/0 = NaN.
        return alpha.new_zeros(())
    eps = 1e-9
    # clamp_min keeps log() finite for exact-zero probabilities (0 * log(eps) == 0).
    ent = -(alpha * alpha.clamp_min(eps).log()).sum(dim=-1)  # (B, T)
    return (ent / math.log(k)).mean()
class DepthwiseCausalConv1d(nn.Module):
    """
    Depthwise causal 1D convolution along the sequence dimension.

    Input: (B, T, H) -> output: (B, T, H). ``groups=channels`` gives each channel
    its own filter. The input is left-padded with ``kernel_size - 1`` zeros, so
    output position t depends only on inputs at positions <= t. Because the
    padding is one-sided, any ``kernel_size >= 1`` (odd or even) preserves T.

    Raises:
        ValueError: if ``kernel_size < 1``.
    """

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        if kernel_size < 1:
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
        self.kernel_size = kernel_size
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=0,
            groups=channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H) -> (B, H, T): Conv1d expects channels before length.
        x_c = x.transpose(1, 2)
        # Left-pad only (past side) so no future token leaks into the output.
        x_c = F.pad(x_c, (self.pad, 0))
        y = self.conv(x_c)
        return y.transpose(1, 2)
class ChannelAttention(nn.Module):
    """
    Per-token, per-channel gating in the spirit of squeeze-and-excitation.

    Each token vector is layer-normalized and passed through a bottleneck MLP
    (channels -> max(channels // reduction, 1) -> channels, GELU in between).
    The sigmoid of the result rescales the *un-normalized* input channel-wise.
    Operates on the last dimension: (..., channels) -> (..., channels).
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        # Never let the bottleneck collapse to zero width.
        bottleneck = max(channels // reduction, 1)
        self.ln = nn.LayerNorm(channels)
        self.fc1 = nn.Linear(channels, bottleneck)
        self.fc2 = nn.Linear(bottleneck, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        squeezed = F.gelu(self.fc1(self.ln(x)))
        gate = torch.sigmoid(self.fc2(squeezed))
        return x * gate
class Fp32LayerNorm(nn.Module):
    """
    LayerNorm computed in float32 for numerical stability.

    The input is upcast to float32 and normalized with autocast disabled. The
    result is then cast back to the input's dtype. The affine parameters are
    created in float32 and are also upcast to float32 at call time. This keeps
    the module working if the surrounding model is later cast to reduced
    precision (e.g. ``model.to(torch.bfloat16)``), which would otherwise pair
    float32 activations with bf16 weights and make ``layer_norm`` fail.
    The parameters still live at ``self.ln.weight`` / ``self.ln.bias``.
    """

    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape, eps=eps)
        self.ln.to(dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # .float() is a no-op when the params are already float32.
        weight = self.ln.weight.float() if self.ln.weight is not None else None
        bias = self.ln.bias.float() if self.ln.bias is not None else None
        # Disable autocast so BF16/FP16 is not injected into the normalization.
        device_type = "cuda" if x.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            y = F.layer_norm(
                x.to(torch.float32),
                self.ln.normalized_shape,
                weight,
                bias,
                self.ln.eps,
            )
        return y.to(orig_dtype)
# --- Base branches for PolymorphicMLP ---
class SwigluMLP(nn.Module):
    """
    SwiGLU feed-forward branch: ``down(silu(gate) * value)``.

    ``up`` projects hidden_size -> 2 * mlp_dim in a single matmul. The first
    mlp_dim outputs are the gate (through SiLU) and the rest are the value.
    mlp_dim = round(mlp_mult * hidden_size).
    """

    def __init__(self, hidden_size: int, mlp_mult: float):
        super().__init__()
        self.mlp_dim = int(round(mlp_mult * hidden_size))
        self.up = nn.Linear(hidden_size, 2 * self.mlp_dim)
        self.down = nn.Linear(self.mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.up(x).split(self.mlp_dim, dim=-1)
        return self.down(F.silu(gate) * value)
class GluMLP(nn.Module):
    """
    GLU feed-forward branch with a sigmoid gate: ``down(sigmoid(gate) * value)``.

    ``up`` projects hidden_size -> 2 * mlp_dim in a single matmul. The first
    mlp_dim outputs are the gate and the rest are the value.
    mlp_dim = round(mlp_mult * hidden_size).
    """

    def __init__(self, hidden_size: int, mlp_mult: float):
        super().__init__()
        self.mlp_dim = int(round(mlp_mult * hidden_size))
        self.up = nn.Linear(hidden_size, 2 * self.mlp_dim)
        self.down = nn.Linear(self.mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.up(x).split(self.mlp_dim, dim=-1)
        return self.down(torch.sigmoid(gate) * value)
class DepthwiseConvBranch(nn.Module):
    """
    Convolutional branch: a depthwise causal conv over the sequence, followed
    by an expand -> GELU -> contract MLP.

    Input/output: (B, T, H).

    Args:
        hidden_size: Channel width H.
        mlp_mult: Expansion factor; inner width is round(mlp_mult * hidden_size).
        kernel_size: Temporal kernel size of the causal depthwise conv
            (default 3, the previously hard-coded value). Validation is left
            to DepthwiseCausalConv1d.
    """

    def __init__(self, hidden_size: int, mlp_mult: float = 4.0, *, kernel_size: int = 3):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.dw = DepthwiseCausalConv1d(hidden_size, kernel_size=kernel_size)
        self.expand = nn.Linear(hidden_size, mlp_dim)
        self.act = nn.GELU()
        self.contract = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.dw(x)
        y = self.expand(y)
        y = self.act(y)
        return self.contract(y)
class PolymorphicMLP(nn.Module):
    """
    Polymorphic MLP: a token-wise soft mixture of K base branches.

    - Router: an MLP mapping each token (B, T, H) to logits (B, T, K). A
      temperature-scaled softmax turns them into mixing weights ``alpha``.
      ``router_tau`` that is None or <= 0 falls back to 1.0.
    - K base branches in a ModuleList: SwiGLU, GLU, depthwise causal conv.
      Slots beyond the third are filled with extra SwiGLU branches.
    - Output: alpha-weighted sum of the branch outputs, optionally gated by
      ChannelAttention, then dropout. ``forward`` returns ``(y, alpha)``.
    - Exposes:
        - ``last_alpha`` (B, T, K): detached routing weights, for logging.
        - ``last_aux``: in training mode, 0.5 * mean normalized per-token
          entropy + 0.5 * normalized entropy of the mean branch usage. It is
          None in eval mode, and a zero tensor when K == 1 (no entropy to
          normalize).
        - ``force_func``: if in [0, K), every token is routed to that single
          branch (debug / per-branch training).
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_mult: float = 4.0,
        num_funcs: int = 3,
        router_dim: Optional[int] = None,
        dropout: float = 0.0,
        use_channel_attention: bool = False,
        router_tau: float = 1.0,
    ):
        super().__init__()
        assert num_funcs >= 1, "PolymorphicMLP richiede almeno 1 funzione di base"
        self.hidden_size = hidden_size
        self.mlp_mult = mlp_mult
        self.num_funcs = num_funcs

        # Router: token -> K logits. A falsy router_dim (None or 0) means hidden_size.
        r_dim = router_dim or hidden_size
        self.router = nn.Sequential(
            nn.Linear(hidden_size, r_dim),
            nn.GELU(),
            nn.Linear(r_dim, num_funcs),
        )
        self.router_tau = router_tau
        # Small weights and zero biases, so the initial routing is closer to uniform.
        for m in self.router.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # Base branches (the first 3 are kept compatible with v1).
        funcs: List[nn.Module] = []
        if num_funcs >= 1:
            funcs.append(SwigluMLP(hidden_size, mlp_mult))
        if num_funcs >= 2:
            funcs.append(GluMLP(hidden_size, mlp_mult))
        if num_funcs >= 3:
            funcs.append(DepthwiseConvBranch(hidden_size, mlp_mult))
        # If num_funcs > 3 should mean new branch types (e.g. a deeper MLP or a
        # more conv-heavy branch), add them here; for now pad with SwiGLU.
        while len(funcs) < num_funcs:
            funcs.append(SwigluMLP(hidden_size, mlp_mult))
        self.funcs = nn.ModuleList(funcs)

        self.dropout = nn.Dropout(dropout)
        self.use_channel_attention = use_channel_attention
        self.chan_attn = ChannelAttention(hidden_size) if use_channel_attention else None

        # Monitoring state, refreshed on every forward.
        self.last_alpha: Optional[torch.Tensor] = None
        self.last_aux: Optional[torch.Tensor] = None
        # Force a single branch (e.g. for debugging / special phases); -1 disables.
        self.force_func: int = -1

    def _router_entropy_aux(self, alpha: torch.Tensor) -> torch.Tensor:
        """Mix of token-level and global-usage normalized entropy of ``alpha`` (B, T, K)."""
        k = alpha.size(-1)
        if k <= 1:
            # One branch: entropy is 0 and log(1) == 0, so normalizing would be 0/0 = NaN.
            return alpha.new_zeros(())
        # Token-level entropy encourages mixing at each position.
        token_ent = router_aux_loss(alpha)
        # Entropy of the mean usage encourages balanced branch usage overall.
        p = alpha.mean(dim=(0, 1))  # (K,)
        ent = -(p * (p + 1e-9).log()).sum()
        global_ent = ent / math.log(float(k))
        # Combine both to stabilize early training and avoid collapse.
        return 0.5 * token_ent + 0.5 * global_ent

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Router: (B, T, H) -> (B, T, K)
        logits = self.router(x)
        tau = float(self.router_tau) if self.router_tau is not None and self.router_tau > 0.0 else 1.0
        alpha = F.softmax(logits / tau, dim=-1)

        # Force a single branch if requested (out-of-range values are ignored).
        if self.force_func is not None and 0 <= self.force_func < self.num_funcs:
            one_hot = torch.zeros_like(alpha)
            one_hot[..., self.force_func] = 1.0
            alpha = one_hot

        # Branches: each returns (B, T, H); stacked to (B, T, K, H).
        y_stack = torch.stack([f(x) for f in self.funcs], dim=2)
        y = (alpha.unsqueeze(-1) * y_stack).sum(dim=2)  # (B, T, H)

        if self.use_channel_attention and self.chan_attn is not None:
            y = self.chan_attn(y)
        y = self.dropout(y)

        # Monitoring
        self.last_alpha = alpha.detach()
        self.last_aux = self._router_entropy_aux(alpha) if self.training else None
        return y, alpha