import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Positional Embeddings (RoPE) to queries and keys.

    Args:
        q: Query tensor of shape (B, nh, T, hd)
        k: Key tensor of shape (B, nh, T, hd)
        cos: Cosine values of shape (1, 1, T, hd)
        sin: Sine values of shape (1, 1, T, hd)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (q_rotated, k_rotated)
    """
    hd = q.shape[-1]
    assert hd % 2 == 0, "head_dim must be even for RoPE"

    # Split the head dimension into two halves; each pair (x1, x2) is rotated by the
    # same position-dependent angle.
    q1, q2 = q[..., :hd // 2], q[..., hd // 2:]
    k1, k2 = k[..., :hd // 2], k[..., hd // 2:]

    # Only the first half of the cos/sin tables is needed, since the angle is shared
    # by the paired components.
    cos_half = cos[..., :hd // 2]
    sin_half = sin[..., :hd // 2]

    q_rot = torch.cat([
        q1 * cos_half - q2 * sin_half,
        q1 * sin_half + q2 * cos_half,
    ], dim=-1)

    k_rot = torch.cat([
        k1 * cos_half - k2 * sin_half,
        k1 * sin_half + k2 * cos_half,
    ], dim=-1)

    return q_rot, k_rot
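# Hedged sketch, not part of the original module: apply_rotary_pos_emb assumes cos/sin
# tables of shape (1, 1, T, hd) are precomputed elsewhere. The helper name
# `build_rope_cache` and the frequency base of 10000.0 are assumptions for illustration.
def build_rope_cache(
    seq_len: int,
    head_dim: int,
    base: float = 10000.0,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    half = head_dim // 2
    # Standard RoPE frequencies: theta_i = base^(-2i/d) for i in [0, d/2).
    inv_freq = 1.0 / (base ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)          # (T, hd/2) angles per position
    emb = torch.cat([freqs, freqs], dim=-1)   # (T, hd); only the first half is read above
    return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]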
def router_aux_loss(alpha: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Mean entropy of the alpha distribution over the K branches.

    alpha: (B, T, K)
    Returns the normalized entropy, roughly in [0, 1].
    """
    if alpha is None:
        return torch.tensor(0.0, device="cpu")
    eps = 1e-9
    k = alpha.size(-1)
    if k <= 1:
        # A single branch gives a degenerate distribution; its entropy is zero.
        return alpha.new_zeros(())
    # Per-token entropy, normalized by log(K) so a uniform router scores ~1.0.
    ent = -(alpha * alpha.clamp_min(eps).log()).sum(dim=-1)
    norm_ent = ent / math.log(float(k))
    return norm_ent.mean()
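# Hedged sanity check, illustrative only (not part of the original module): a uniform
# router should score close to 1.0 and a one-hot router close to 0.0, e.g.
#   router_aux_loss(torch.full((2, 4, 3), 1.0 / 3))                              # ~1.0
#   router_aux_loss(F.one_hot(torch.zeros(2, 4, dtype=torch.long), 3).float())   # ~0.0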
class DepthwiseCausalConv1d(nn.Module):
    """
    Depthwise 1D causal convolution over the sequence dimension.

    Input: (B, T, H) -> output: (B, T, H)
    groups=H gives one filter per channel.
    """

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        assert kernel_size >= 1 and kernel_size % 2 == 1, "kernel_size should be odd"
        self.kernel_size = kernel_size
        # Left-only padding of (kernel_size - 1) makes the convolution causal.
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=0,
            groups=channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H) -> (B, H, T) for Conv1d, pad on the left only, convolve, transpose back.
        x_c = x.transpose(1, 2)
        x_c = F.pad(x_c, (self.pad, 0))
        y = self.conv(x_c)
        y = y.transpose(1, 2)
        return y
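# Hedged sanity check, illustrative only (not part of the original module): causality
# means that perturbing future positions must leave earlier outputs unchanged, e.g.
#   conv = DepthwiseCausalConv1d(channels=8, kernel_size=3)
#   x = torch.randn(1, 10, 8); x2 = x.clone(); x2[:, 5:] += 1.0
#   torch.allclose(conv(x)[:, :5], conv(x2)[:, :5])   # True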
class ChannelAttention(nn.Module):
    """
    Per-channel (SE-style) attention, applied independently at each token.
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        hidden = max(channels // reduction, 1)
        self.ln = nn.LayerNorm(channels)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Squeeze-and-excite style gate in (0, 1), multiplied channel-wise onto the input.
        g = self.ln(x)
        g = F.gelu(self.fc1(g))
        g = torch.sigmoid(self.fc2(g))
        return x * g
class Fp32LayerNorm(nn.Module):
    """
    LayerNorm computed in float32 for numerical stability, casting in and out.

    Parameters are kept in float32.
    """

    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape, eps=eps)
        self.ln.to(dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Disable autocast so the normalization itself always runs in float32.
        device_type = "cuda" if x.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            y = self.ln(x.to(torch.float32))
        return y.to(orig_dtype)
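# Hedged usage note, illustrative only: under mixed precision the input typically arrives
# in bf16/fp16; statistics are still computed in fp32 and the result is cast back, e.g.
#   ln = Fp32LayerNorm(256)
#   with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#       y = ln(x)   # y.dtype == x.dtype, normalization ran in float32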
class SwigluMLP(nn.Module):
    """SwiGLU feed-forward branch: down(silu(a) * b), with a and b from one fused up-projection."""

    def __init__(self, hidden_size: int, mlp_mult: float):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.mlp_dim = mlp_dim
        self.up = nn.Linear(hidden_size, 2 * mlp_dim)
        self.down = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up = self.up(x)
        a, b = up.split(self.mlp_dim, dim=-1)
        y = F.silu(a) * b
        return self.down(y)
class GluMLP(nn.Module):
    """GLU feed-forward branch: down(sigmoid(a) * b), with a and b from one fused up-projection."""

    def __init__(self, hidden_size: int, mlp_mult: float):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.mlp_dim = mlp_dim
        self.up = nn.Linear(hidden_size, 2 * mlp_dim)
        self.down = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up = self.up(x)
        a, b = up.split(self.mlp_dim, dim=-1)
        y = torch.sigmoid(a) * b
        return self.down(y)
class DepthwiseConvBranch(nn.Module):
    """Local-mixing branch: depthwise causal convolution followed by a pointwise GELU MLP."""

    def __init__(self, hidden_size: int, mlp_mult: float = 4.0):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.dw = DepthwiseCausalConv1d(hidden_size, kernel_size=3)
        self.expand = nn.Linear(hidden_size, mlp_dim)
        self.act = nn.GELU()
        self.contract = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.dw(x)
        y = self.expand(y)
        y = self.act(y)
        return self.contract(y)
class PolymorphicMLP(nn.Module):
    """
    Polymorphic MLP:

    - Router: produces alpha (B, T, K)
    - K base branches in a ModuleList (e.g. SwiGLU, GLU, depthwise-conv)
    - Output: weighted sum of the branches
    - Optional ChannelAttention
    - Exposes:
        - last_alpha (B, T, K) for logging
        - last_aux (mean normalized entropy) for the aux loss
        - force_func: if >= 0, force a single branch (debugging / per-branch training)
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_mult: float = 4.0,
        num_funcs: int = 3,
        router_dim: Optional[int] = None,
        dropout: float = 0.0,
        use_channel_attention: bool = False,
        router_tau: float = 1.0,
    ):
        super().__init__()
        assert num_funcs >= 1, "PolymorphicMLP requires at least 1 base function"
        self.hidden_size = hidden_size
        self.mlp_mult = mlp_mult
        self.num_funcs = num_funcs

        # Token-wise router: hidden -> r_dim -> K logits.
        r_dim = router_dim or hidden_size
        self.router = nn.Sequential(
            nn.Linear(hidden_size, r_dim),
            nn.GELU(),
            nn.Linear(r_dim, num_funcs),
        )
        self.router_tau = router_tau

        # Small-std init keeps the initial routing close to uniform.
        for m in self.router.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # Base branches, added in order: SwiGLU, GLU, depthwise-conv.
        funcs: List[nn.Module] = []
        if num_funcs >= 1:
            funcs.append(SwigluMLP(hidden_size, mlp_mult))
        if num_funcs >= 2:
            funcs.append(GluMLP(hidden_size, mlp_mult))
        if num_funcs >= 3:
            funcs.append(DepthwiseConvBranch(hidden_size, mlp_mult))

        # If more branches are requested than branch types, pad with extra SwiGLU branches.
        while len(funcs) < num_funcs:
            funcs.append(SwigluMLP(hidden_size, mlp_mult))

        self.funcs = nn.ModuleList(funcs)

        self.dropout = nn.Dropout(dropout)
        self.use_channel_attention = use_channel_attention
        self.chan_attn = ChannelAttention(hidden_size) if use_channel_attention else None

        # Filled in on every forward pass (logging / aux-loss hooks).
        self.last_alpha: Optional[torch.Tensor] = None
        self.last_aux: Optional[torch.Tensor] = None

        # If >= 0, bypass the router and force this branch.
        self.force_func: int = -1

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Router weights over the K branches, with temperature.
        logits = self.router(x)
        tau = float(self.router_tau) if self.router_tau is not None and self.router_tau > 0.0 else 1.0
        alpha = F.softmax(logits / tau, dim=-1)

        # Optionally force a single branch (debugging / per-branch training).
        if self.force_func is not None and 0 <= self.force_func < self.num_funcs:
            one_hot = torch.zeros_like(alpha)
            one_hot[..., self.force_func] = 1.0
            alpha = one_hot

        # Evaluate all branches and blend them: (B, T, K, H) weighted by (B, T, K, 1).
        ys = [f(x) for f in self.funcs]
        y_stack = torch.stack(ys, dim=2)
        alpha_exp = alpha.unsqueeze(-1)
        y = (alpha_exp * y_stack).sum(dim=2)

        if self.use_channel_attention and self.chan_attn is not None:
            y = self.chan_attn(y)

        y = self.dropout(y)

        # Keep a detached copy of the routing weights for logging.
        self.last_alpha = alpha.detach()

        # Aux loss: mix of per-token entropy and entropy of the global branch usage.
        self.last_aux = None
        if self.training:
            token_ent = router_aux_loss(alpha)
            p = alpha.mean(dim=(0, 1))
            ent = -(p * (p + 1e-9).log()).sum()
            k = alpha.size(-1)
            # Guard k == 1: log(1) == 0 would otherwise divide by zero.
            global_ent = ent / math.log(float(k)) if k > 1 else torch.zeros_like(ent)
            self.last_aux = 0.5 * token_ent + 0.5 * global_ent

        return y, alpha
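# Hedged usage sketch, illustrative only (shapes and hyperparameters below are made up):
# run a forward pass and inspect the routing statistics the module exposes.
if __name__ == "__main__":
    torch.manual_seed(0)
    mlp = PolymorphicMLP(hidden_size=64, mlp_mult=4.0, num_funcs=3, use_channel_attention=True)
    mlp.train()
    x = torch.randn(2, 16, 64)                # (B, T, H)
    y, alpha = mlp(x)
    print(y.shape, alpha.shape)               # torch.Size([2, 16, 64]) torch.Size([2, 16, 3])
    print("aux loss:", float(mlp.last_aux))   # entropy mix, populated only in training mode
    mlp.force_func = 1                        # force the GLU branch (debug / per-branch training)
    _, alpha_forced = mlp(x)
    print(alpha_forced[0, 0])                 # one-hot on branch 1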