MhaWay committed on
Commit b69b7b7 · verified · 1 Parent(s): ae57fc0

Create modeling_components.py

Files changed (1)
  1. src/veronica/modeling_components.py +295 -0
src/veronica/modeling_components.py ADDED
@@ -0,0 +1,295 @@
+ from typing import Optional, Tuple, List
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Applies Rotary Positional Embeddings (RoPE) to queries and keys.
+
+     Args:
+         q: Query tensor of shape (B, nh, T, hd)
+         k: Key tensor of shape (B, nh, T, hd)
+         cos: Cosine values of shape (1, 1, T, hd)
+         sin: Sine values of shape (1, 1, T, hd)
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: (q_rotated, k_rotated)
+     """
+     # Split the head dimension in half for the complex-style rotation
+     # q, k: (B, nh, T, hd) -> split into (B, nh, T, hd/2) pairs
+     hd = q.shape[-1]
+     assert hd % 2 == 0, "head_dim must be even for RoPE"
+
+     q1, q2 = q[..., :hd//2], q[..., hd//2:]
+     k1, k2 = k[..., :hd//2], k[..., hd//2:]
+
+     # Apply the rotation: [cos*q1 - sin*q2, sin*q1 + cos*q2]
+     cos_half = cos[..., :hd//2]
+     sin_half = sin[..., :hd//2]
+
+     q_rot = torch.cat([
+         q1 * cos_half - q2 * sin_half,
+         q1 * sin_half + q2 * cos_half
+     ], dim=-1)
+
+     k_rot = torch.cat([
+         k1 * cos_half - k2 * sin_half,
+         k1 * sin_half + k2 * cos_half
+     ], dim=-1)
+
+     return q_rot, k_rot
+
+
+ def router_aux_loss(alpha: torch.Tensor) -> torch.Tensor:
+     """
+     Mean entropy of the alpha distribution over the K branches.
+     alpha: (B, T, K)
+     Returns an entropy roughly normalized to [0, 1].
+     """
+     if alpha is None:
+         return torch.tensor(0.0, device="cpu")
+     eps = 1e-9
+     k = alpha.size(-1)
+     ent = -(alpha * (alpha.clamp_min(eps)).log()).sum(dim=-1)  # (B, T)
+     norm_ent = ent / (torch.log(torch.tensor(float(k), device=alpha.device)))
+     return norm_ent.mean()
+
+
+ class DepthwiseCausalConv1d(nn.Module):
+     """
+     Depthwise 1D causal convolution along the sequence dimension.
+
+     Input: (B, T, H) -> output: (B, T, H)
+     groups=H gives one filter per channel.
+     """
+
+     def __init__(self, channels: int, kernel_size: int = 3):
+         super().__init__()
+         assert kernel_size >= 1 and kernel_size % 2 == 1, "kernel_size should be odd"
+         self.kernel_size = kernel_size
+         self.pad = kernel_size - 1
+         self.conv = nn.Conv1d(
+             in_channels=channels,
+             out_channels=channels,
+             kernel_size=kernel_size,
+             padding=0,
+             groups=channels,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, T, H) -> (B, H, T)
+         x_c = x.transpose(1, 2)
+         # Left-pad with zeros for causality
+         x_c = F.pad(x_c, (self.pad, 0))
+         y = self.conv(x_c)
+         y = y.transpose(1, 2)
+         return y
+
+
+ class ChannelAttention(nn.Module):
+     """
+     Per-channel (SE-style) attention, applied per token.
+     """
+
+     def __init__(self, channels: int, reduction: int = 4):
+         super().__init__()
+         hidden = max(channels // reduction, 1)
+         self.ln = nn.LayerNorm(channels)
+         self.fc1 = nn.Linear(channels, hidden)
+         self.fc2 = nn.Linear(hidden, channels)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         g = self.ln(x)
+         g = F.gelu(self.fc1(g))
+         g = torch.sigmoid(self.fc2(g))
+         return x * g
+
+
+ class Fp32LayerNorm(nn.Module):
+     """
+     LayerNorm computed in float32 for numerical stability, casting in and out.
+     The parameters are kept in float32.
+     """
+
+     def __init__(self, normalized_shape: int, eps: float = 1e-5):
+         super().__init__()
+         self.ln = nn.LayerNorm(normalized_shape, eps=eps)
+         self.ln.to(dtype=torch.float32)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         orig_dtype = x.dtype
+         # Disable autocast to prevent BF16/FP16 from being injected into LayerNorm
+         if x.is_cuda:
+             with torch.autocast(device_type="cuda", enabled=False):
+                 y = self.ln(x.to(torch.float32))
+         else:
+             with torch.autocast(device_type="cpu", enabled=False):
+                 y = self.ln(x.to(torch.float32))
+         return y.to(orig_dtype)
+
+
+ # --- Base branches for the PolymorphicMLP ---
+
+
+ class SwigluMLP(nn.Module):
+     def __init__(self, hidden_size: int, mlp_mult: float):
+         super().__init__()
+         mlp_dim = int(round(mlp_mult * hidden_size))
+         self.mlp_dim = mlp_dim
+         self.up = nn.Linear(hidden_size, 2 * mlp_dim)
+         self.down = nn.Linear(mlp_dim, hidden_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         up = self.up(x)
+         a, b = up.split(self.mlp_dim, dim=-1)
+         y = F.silu(a) * b
+         return self.down(y)
+
+
+ class GluMLP(nn.Module):
+     def __init__(self, hidden_size: int, mlp_mult: float):
+         super().__init__()
+         mlp_dim = int(round(mlp_mult * hidden_size))
+         self.mlp_dim = mlp_dim
+         self.up = nn.Linear(hidden_size, 2 * mlp_dim)
+         self.down = nn.Linear(mlp_dim, hidden_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         up = self.up(x)
+         a, b = up.split(self.mlp_dim, dim=-1)
+         y = torch.sigmoid(a) * b
+         return self.down(y)
+
+
+ class DepthwiseConvBranch(nn.Module):
+     def __init__(self, hidden_size: int, mlp_mult: float = 4.0):
+         super().__init__()
+         mlp_dim = int(round(mlp_mult * hidden_size))
+         self.dw = DepthwiseCausalConv1d(hidden_size, kernel_size=3)
+         self.expand = nn.Linear(hidden_size, mlp_dim)
+         self.act = nn.GELU()
+         self.contract = nn.Linear(mlp_dim, hidden_size)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         y = self.dw(x)
+         y = self.expand(y)
+         y = self.act(y)
+         return self.contract(y)
+
+
+ class PolymorphicMLP(nn.Module):
+     """
+     Polymorphic MLP:
+
+     - Router: produces alpha (B, T, K)
+     - K base branches in a ModuleList (e.g. SwiGLU, GLU, depthwise-conv)
+     - Output: weighted sum of the branches
+     - Optional ChannelAttention
+     - Exposes:
+         - last_alpha (B, T, K) for logging
+         - last_aux (mean normalized entropy) for the aux-loss
+     - force_func: if >= 0, forces a single branch (debug / per-branch training)
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         mlp_mult: float = 4.0,
+         num_funcs: int = 3,
+         router_dim: Optional[int] = None,
+         dropout: float = 0.0,
+         use_channel_attention: bool = False,
+         router_tau: float = 1.0,
+     ):
+         super().__init__()
+         assert num_funcs >= 1, "PolymorphicMLP requires at least 1 base function"
+         self.hidden_size = hidden_size
+         self.mlp_mult = mlp_mult
+         self.num_funcs = num_funcs
+
+         # Router
+         r_dim = router_dim or hidden_size
+         self.router = nn.Sequential(
+             nn.Linear(hidden_size, r_dim),
+             nn.GELU(),
+             nn.Linear(r_dim, num_funcs),
+         )
+         self.router_tau = router_tau
+
+         # Initialize the router with small weights for more uniform distributions
+         for m in self.router.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.normal_(m.weight, mean=0.0, std=0.02)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+
+         # Base branches (the first 3 are compatible with v1)
+         funcs: List[nn.Module] = []
+         if num_funcs >= 1:
+             funcs.append(SwigluMLP(hidden_size, mlp_mult))
+         if num_funcs >= 2:
+             funcs.append(GluMLP(hidden_size, mlp_mult))
+         if num_funcs >= 3:
+             funcs.append(DepthwiseConvBranch(hidden_size, mlp_mult))
+         # If num_funcs is ever raised above 3, add new branches here
+         # (e.g. a deeper MLP, a more conv-heavy branch, etc.)
+         while len(funcs) < num_funcs:
+             funcs.append(SwigluMLP(hidden_size, mlp_mult))  # fallback: extra SwiGLU
+
+         self.funcs = nn.ModuleList(funcs)
+
+         self.dropout = nn.Dropout(dropout)
+         self.use_channel_attention = use_channel_attention
+         self.chan_attn = ChannelAttention(hidden_size) if use_channel_attention else None
+
+         # Monitoring
+         self.last_alpha: Optional[torch.Tensor] = None
+         self.last_aux: Optional[torch.Tensor] = None
+
+         # Force a single branch (e.g. for debugging / special training phases)
+         self.force_func: int = -1
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Router: (B, T, H) -> (B, T, K)
+         logits = self.router(x)
+         tau = float(self.router_tau) if self.router_tau is not None and self.router_tau > 0.0 else 1.0
+         alpha = F.softmax(logits / tau, dim=-1)
+
+         # Force a single branch if requested
+         if self.force_func is not None and self.force_func >= 0 and self.force_func < self.num_funcs:
+             one_hot = torch.zeros_like(alpha)
+             one_hot[..., self.force_func] = 1.0
+             alpha = one_hot
+
+         # Branches
+         ys = [f(x) for f in self.funcs]  # list of (B, T, H)
+         y_stack = torch.stack(ys, dim=2)  # (B, T, K, H)
+
+         alpha_exp = alpha.unsqueeze(-1)  # (B, T, K, 1)
+         y = (alpha_exp * y_stack).sum(dim=2)  # (B, T, H)
+
+         if self.use_channel_attention and self.chan_attn is not None:
+             y = self.chan_attn(y)
+
+         y = self.dropout(y)
+
+         # Monitoring
+         self.last_alpha = alpha.detach()
+
+         self.last_aux = None
+         if self.training:
+             # Token-level entropy encourages mixing at each position
+             token_ent = router_aux_loss(alpha)
+             # Global entropy over mean usage encourages balanced branch usage overall
+             p = alpha.mean(dim=(0, 1))  # (K,)
+             ent = -(p * (p + 1e-9).log()).sum()
+             k = alpha.size(-1)
+             global_ent = ent / math.log(float(k))
+             # Combine both to stabilize early training and avoid collapse
+             self.last_aux = 0.5 * token_ent + 0.5 * global_ent
+
+         return y, alpha
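
For reference, a minimal usage sketch of the new components follows. The import path, the tensor sizes, and the RoPE frequency-table construction are assumptions made for illustration only; the commit itself only adds the module above.

# Usage sketch; import path and sizes are assumed for illustration.
import torch
from src.veronica.modeling_components import PolymorphicMLP, apply_rotary_pos_emb

B, T, H, n_heads = 2, 16, 64, 4      # hypothetical batch, seq len, hidden size, heads
hd = H // n_heads                    # head_dim; must be even for RoPE

# Routed mixture of SwiGLU / GLU / depthwise-conv branches.
mlp = PolymorphicMLP(hidden_size=H, mlp_mult=4.0, num_funcs=3, router_tau=1.0)
mlp.train()
x = torch.randn(B, T, H)
y, alpha = mlp(x)                    # y: (B, T, H), alpha: (B, T, 3)
aux = mlp.last_aux                   # normalized router entropy, usable as an aux-loss term

# Build (1, 1, T, hd) cos/sin tables and rotate q/k with the RoPE helper.
inv_freq = 1.0 / (10000 ** (torch.arange(0, hd, 2).float() / hd))
angles = torch.arange(T).float()[:, None] * inv_freq[None, :]   # (T, hd/2)
angles = torch.cat([angles, angles], dim=-1)                    # (T, hd)
cos = angles.cos()[None, None, :, :]
sin = angles.sin()[None, None, :, :]
q = torch.randn(B, n_heads, T, hd)
k = torch.randn(B, n_heads, T, hd)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

Setting mlp.force_func to a valid branch index replaces alpha with a one-hot distribution, which is the per-branch debug path described in the PolymorphicMLP docstring.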