from typing import Optional, Tuple, List

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applica Rotary Positional Embeddings (RoPE) a query e key.
    
    Args:
        q: Query tensor di shape (B, nh, T, hd)
        k: Key tensor di shape (B, nh, T, hd)
        cos: Cosine values di shape (1, 1, T, hd)
        sin: Sine values di shape (1, 1, T, hd)
    
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (q_rotated, k_rotated)
    """
    # Dividi le dimensioni in metà per rotazione complessa
    # q, k: (B, nh, T, hd) -> split in (B, nh, T, hd/2) pairs
    hd = q.shape[-1]
    assert hd % 2 == 0, "head_dim deve essere pari per RoPE"
    
    q1, q2 = q[..., :hd//2], q[..., hd//2:]
    k1, k2 = k[..., :hd//2], k[..., hd//2:]
    
    # Apply the rotation: [cos*q1 - sin*q2, sin*q1 + cos*q2]
    cos_half = cos[..., :hd//2]
    sin_half = sin[..., :hd//2]
    
    q_rot = torch.cat([
        q1 * cos_half - q2 * sin_half,
        q1 * sin_half + q2 * cos_half
    ], dim=-1)
    
    k_rot = torch.cat([
        k1 * cos_half - k2 * sin_half,
        k1 * sin_half + k2 * cos_half
    ], dim=-1)
    
    return q_rot, k_rot
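
# Illustrative helper (a sketch, not part of the original module): one common way
# to build the (1, 1, T, hd) cos/sin tables consumed by apply_rotary_pos_emb.
# The base of 10000 and the [freqs, freqs] duplication convention are assumptions.
def _example_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (hd/2,)
    t = torch.arange(seq_len).float()                                             # (T,)
    freqs = torch.outer(t, inv_freq)                                              # (T, hd/2)
    emb = torch.cat([freqs, freqs], dim=-1)                                       # (T, hd)
    return emb.cos()[None, None], emb.sin()[None, None]                           # (1, 1, T, hd) each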


def router_aux_loss(alpha: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Mean entropy of the alpha distribution over the K branches.
    alpha: (B, T, K)
    Returns the normalized entropy, roughly in [0, 1].
    """
    if alpha is None:
        return torch.tensor(0.0, device="cpu")
    eps = 1e-9
    k = alpha.size(-1)
    ent = -(alpha * (alpha.clamp_min(eps)).log()).sum(dim=-1)  # (B, T)
    norm_ent = ent / (torch.log(torch.tensor(float(k), device=alpha.device)))
    return norm_ent.mean()
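
# Illustrative values (a sketch, not part of the original module): uniform routing
# over K branches yields a normalized entropy of ~1.0, one-hot routing yields ~0.0.
#
#   uniform = torch.full((2, 4, 3), 1.0 / 3.0)
#   one_hot = F.one_hot(torch.zeros(2, 4, dtype=torch.long), num_classes=3).float()
#   router_aux_loss(uniform)   # ~1.0
#   router_aux_loss(one_hot)   # ~0.0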


class DepthwiseCausalConv1d(nn.Module):
    """
    Depthwise 1D causal convolution sulla dimensione di sequenza.

    Input: (B, T, H) -> output: (B, T, H)
    groups=H per avere un filtro per canale.
    """

    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        assert kernel_size >= 1 and kernel_size % 2 == 1, "kernel_size should be odd"
        self.kernel_size = kernel_size
        self.pad = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=0,
            groups=channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, H) -> (B, H, T)
        x_c = x.transpose(1, 2)
        # left-pad with zeros to preserve causality
        x_c = F.pad(x_c, (self.pad, 0))
        y = self.conv(x_c)
        y = y.transpose(1, 2)
        return y
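
# Causality sanity check (illustrative sketch, not part of the original module):
# the output at position t must depend only on inputs at positions <= t.
#
#   conv = DepthwiseCausalConv1d(channels=8, kernel_size=3)
#   x = torch.randn(2, 16, 8)
#   x_future = x.clone()
#   x_future[:, 10:, :] = 0.0                     # perturb the future only
#   assert torch.allclose(conv(x)[:, :10], conv(x_future)[:, :10])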


class ChannelAttention(nn.Module):
    """
    Attenzione per-canale (tipo SE) per token.
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        hidden = max(channels // reduction, 1)
        self.ln = nn.LayerNorm(channels)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        g = self.ln(x)
        g = F.gelu(self.fc1(g))
        g = torch.sigmoid(self.fc2(g))
        return x * g


class Fp32LayerNorm(nn.Module):
    """
    LayerNorm in float32 per stabilità numerica, castando avanti/indietro.
    I parametri rimangono in float32.
    """

    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape, eps=eps)
        self.ln.to(dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Disable autocast to prevent BF16/FP16 from being injected into LayerNorm
        if x.is_cuda:
            with torch.autocast(device_type="cuda", enabled=False):
                y = self.ln(x.to(torch.float32))
        else:
            with torch.autocast(device_type="cpu", enabled=False):
                y = self.ln(x.to(torch.float32))
        return y.to(orig_dtype)
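
# Usage sketch (illustrative, not part of the original module): a bf16 input is
# normalized in float32 internally and returned in its original dtype.
#
#   ln = Fp32LayerNorm(64)
#   x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
#   y = ln(x)                 # y.dtype == torch.bfloat16, math done in float32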


# --- Base branches for the PolymorphicMLP ---


class SwigluMLP(nn.Module):
    def __init__(self, hidden_size: int, mlp_mult: float):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.mlp_dim = mlp_dim
        self.up = nn.Linear(hidden_size, 2 * mlp_dim)
        self.down = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up = self.up(x)
        a, b = up.split(self.mlp_dim, dim=-1)
        y = F.silu(a) * b
        return self.down(y)


class GluMLP(nn.Module):
    def __init__(self, hidden_size: int, mlp_mult: float):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.mlp_dim = mlp_dim
        self.up = nn.Linear(hidden_size, 2 * mlp_dim)
        self.down = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        up = self.up(x)
        a, b = up.split(self.mlp_dim, dim=-1)
        y = torch.sigmoid(a) * b
        return self.down(y)
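
# Shape note (illustrative, not part of the original module): both gated branches
# map (B, T, H) -> (B, T, H); SwigluMLP gates with SiLU, GluMLP with a sigmoid.
#
#   x = torch.randn(2, 16, 64)
#   SwigluMLP(64, 4.0)(x).shape   # torch.Size([2, 16, 64])
#   GluMLP(64, 4.0)(x).shape      # torch.Size([2, 16, 64])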


class DepthwiseConvBranch(nn.Module):
    def __init__(self, hidden_size: int, mlp_mult: float = 4.0):
        super().__init__()
        mlp_dim = int(round(mlp_mult * hidden_size))
        self.dw = DepthwiseCausalConv1d(hidden_size, kernel_size=3)
        self.expand = nn.Linear(hidden_size, mlp_dim)
        self.act = nn.GELU()
        self.contract = nn.Linear(mlp_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.dw(x)
        y = self.expand(y)
        y = self.act(y)
        return self.contract(y)


class PolymorphicMLP(nn.Module):
    """
    MLP polimorfico:

    - Router: produce alpha (B, T, K)
    - K rami base in una ModuleList (es. SwiGLU, GLU, depthwise-conv)
    - Output: somma pesata dei rami
    - Opzionale ChannelAttention
    - Espone:
        - last_alpha (B, T, K) per logging
        - last_aux (entropia normalizzata media) per aux-loss
        - force_func: se >= 0, forza un solo ramo (debug / training per ramo)
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_mult: float = 4.0,
        num_funcs: int = 3,
        router_dim: Optional[int] = None,
        dropout: float = 0.0,
        use_channel_attention: bool = False,
        router_tau: float = 1.0,
    ):
        super().__init__()
        assert num_funcs >= 1, "PolymorphicMLP requires at least 1 base function"
        self.hidden_size = hidden_size
        self.mlp_mult = mlp_mult
        self.num_funcs = num_funcs

        # Router
        r_dim = router_dim or hidden_size
        self.router = nn.Sequential(
            nn.Linear(hidden_size, r_dim),
            nn.GELU(),
            nn.Linear(r_dim, num_funcs),
        )
        self.router_tau = router_tau
        
        # Initialize the router with small weights so early routing distributions stay close to uniform
        for m in self.router.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # Base branches (the first 3 are compatible with your v1)
        funcs: List[nn.Module] = []
        if num_funcs >= 1:
            funcs.append(SwigluMLP(hidden_size, mlp_mult))
        if num_funcs >= 2:
            funcs.append(GluMLP(hidden_size, mlp_mult))
        if num_funcs >= 3:
            funcs.append(DepthwiseConvBranch(hidden_size, mlp_mult))
        # If num_funcs is raised above 3 in the future, add the new branches here
        # (e.g. a deeper MLP, a more conv-heavy branch, etc.)
        while len(funcs) < num_funcs:
            funcs.append(SwigluMLP(hidden_size, mlp_mult))  # fallback: extra-swiglu

        self.funcs = nn.ModuleList(funcs)

        self.dropout = nn.Dropout(dropout)
        self.use_channel_attention = use_channel_attention
        self.chan_attn = ChannelAttention(hidden_size) if use_channel_attention else None

        # Monitoring
        self.last_alpha: Optional[torch.Tensor] = None
        self.last_aux: Optional[torch.Tensor] = None

        # Force a single branch (e.g. for debugging / dedicated training phases)
        self.force_func: int = -1

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Router: (B, T, H) -> (B, T, K)
        logits = self.router(x)
        tau = float(self.router_tau) if self.router_tau is not None and self.router_tau > 0.0 else 1.0
        alpha = F.softmax(logits / tau, dim=-1)

        # Force a single branch if requested
        if self.force_func is not None and 0 <= self.force_func < self.num_funcs:
            one_hot = torch.zeros_like(alpha)
            one_hot[..., self.force_func] = 1.0
            alpha = one_hot

        # Branches
        ys = [f(x) for f in self.funcs]  # list of (B, T, H) tensors
        y_stack = torch.stack(ys, dim=2)  # (B, T, K, H)

        alpha_exp = alpha.unsqueeze(-1)  # (B, T, K, 1)
        y = (alpha_exp * y_stack).sum(dim=2)  # (B, T, H)

        if self.use_channel_attention and self.chan_attn is not None:
            y = self.chan_attn(y)

        y = self.dropout(y)

        # Monitoring
        self.last_alpha = alpha.detach()

        self.last_aux = None
        if self.training:
            # Token-level entropy encourages mixing at each position
            token_ent = router_aux_loss(alpha)
            # Global entropy over mean usage encourages balanced branch usage overall
            p = alpha.mean(dim=(0, 1))  # (K,)
            ent = -(p * (p + 1e-9).log()).sum()
            k = alpha.size(-1)
            global_ent = ent / math.log(float(k))
            # Combine both to stabilize early training and avoid collapse
            self.last_aux = 0.5 * token_ent + 0.5 * global_ent

        return y, alpha
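

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module):
    # exercises the router-weighted mixture, the aux-loss bookkeeping and branch
    # forcing on random data. The sizes below are arbitrary.
    torch.manual_seed(0)
    mlp = PolymorphicMLP(hidden_size=64, mlp_mult=4.0, num_funcs=3, use_channel_attention=True)
    mlp.train()
    x = torch.randn(2, 16, 64)
    y, alpha = mlp(x)
    assert y.shape == x.shape
    # The router output is a distribution over the K branches at every position.
    assert torch.allclose(alpha.sum(dim=-1), torch.ones(2, 16), atol=1e-5)
    # In training mode the normalized-entropy aux term is populated.
    assert mlp.last_aux is not None and 0.0 <= float(mlp.last_aux) <= 1.0 + 1e-5
    # Forcing a branch routes every token through funcs[0].
    mlp.force_func = 0
    _, alpha_forced = mlp(x)
    assert torch.all(alpha_forced[..., 0] == 1.0)
    print("PolymorphicMLP smoke test passed")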