pcunwa committed
Commit 76d35e4 · verified · 1 parent: 7f01628

Upload 3 files

v2_voc/bs_roformer.py ADDED
@@ -0,0 +1,1077 @@
+ from functools import partial
+
+ import torch
+ from torch import nn, einsum, Tensor
+ from torch.nn import Module, ModuleList
+ import torch.nn.functional as F
+
+ from models.bs_roformer.attend import Attend
+ try:
+     # the SageAttention backend is optional; fall back if it is not installed
+     from models.bs_roformer.attend_sage import Attend as AttendSage
+ except ImportError:
+     AttendSage = None
+ from torch.utils.checkpoint import checkpoint
+
+ from beartype.typing import Tuple, Optional, List, Callable
+ from beartype import beartype
+
+ from rotary_embedding_torch import RotaryEmbedding
+
+ from einops import rearrange, pack, unpack
+ from einops.layers.torch import Rearrange
+ import torchaudio
+ # helper functions
+
+ def exists(val):
+     return val is not None
+
+
+ def default(v, d):
+     return v if exists(v) else d
+
+
+ def pack_one(t, pattern):
+     return pack([t], pattern)
+
+
+ def unpack_one(t, ps, pattern):
+     return unpack(t, ps, pattern)[0]
+
+
+ # norm
+
+ def l2norm(t):
+     return F.normalize(t, dim=-1, p=2)
+
+
+ class RMSNorm(Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.scale = dim ** 0.5
+         self.gamma = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+ # attention
+
+ class FeedForward(Module):
+     def __init__(
+         self,
+         dim,
+         mult=4,
+         dropout=0.
+     ):
+         super().__init__()
+         dim_inner = int(dim * mult)
+         self.net = nn.Sequential(
+             RMSNorm(dim),
+             nn.Linear(dim, dim_inner),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(dim_inner, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class Attention(Module):
+     def __init__(
+         self,
+         dim,
+         heads=8,
+         dim_head=64,
+         dropout=0.,
+         rotary_embed=None,
+         flash=True,
+         sage_attention=False,
+     ):
+         super().__init__()
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+         dim_inner = heads * dim_head
+
+         self.rotary_embed = rotary_embed
+
+         if sage_attention:
+             self.attend = AttendSage(flash=flash, dropout=dropout)
+         else:
+             self.attend = Attend(flash=flash, dropout=dropout)
+
+         self.norm = RMSNorm(dim)
+         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+         self.to_gates = nn.Linear(dim, heads)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(dim_inner, dim, bias=False),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         x = self.norm(x)
+
+         q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+         if exists(self.rotary_embed):
+             q = self.rotary_embed.rotate_queries_or_keys(q)
+             k = self.rotary_embed.rotate_queries_or_keys(k)
+
+         out = self.attend(q, k, v)
+
+         # per-head sigmoid gating of the attention output
+         gates = self.to_gates(x)
+         out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+
+
+ class LinearAttention(Module):
+     """
+     This flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+     """
+
+     @beartype
+     def __init__(
+         self,
+         *,
+         dim,
+         dim_head=32,
+         heads=8,
+         scale=8,
+         flash=False,
+         dropout=0.,
+         sage_attention=False,
+     ):
+         super().__init__()
+         dim_inner = dim_head * heads
+         self.norm = RMSNorm(dim)
+
+         self.to_qkv = nn.Sequential(
+             nn.Linear(dim, dim_inner * 3, bias=False),
+             Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+         )
+
+         self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+         if sage_attention:
+             self.attend = AttendSage(
+                 scale=scale,
+                 dropout=dropout,
+                 flash=flash
+             )
+         else:
+             self.attend = Attend(
+                 scale=scale,
+                 dropout=dropout,
+                 flash=flash
+             )
+
+         self.to_out = nn.Sequential(
+             Rearrange('b h d n -> b n (h d)'),
+             nn.Linear(dim_inner, dim, bias=False)
+         )
+
+     def forward(
+         self,
+         x
+     ):
+         x = self.norm(x)
+
+         q, k, v = self.to_qkv(x)
+
+         q, k = map(l2norm, (q, k))
+         q = q * self.temperature.exp()
+
+         out = self.attend(q, k, v)
+
+         return self.to_out(out)
+
+
+ class Transformer(Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         depth,
+         dim_head=64,
+         heads=8,
+         attn_dropout=0.,
+         ff_dropout=0.,
+         ff_mult=4,
+         norm_output=True,
+         rotary_embed=None,
+         flash_attn=True,
+         linear_attn=False,
+         sage_attention=False,
+     ):
+         super().__init__()
+         self.layers = ModuleList([])
+
+         for _ in range(depth):
+             if linear_attn:
+                 attn = LinearAttention(
+                     dim=dim,
+                     dim_head=dim_head,
+                     heads=heads,
+                     dropout=attn_dropout,
+                     flash=flash_attn,
+                     sage_attention=sage_attention
+                 )
+             else:
+                 attn = Attention(
+                     dim=dim,
+                     dim_head=dim_head,
+                     heads=heads,
+                     dropout=attn_dropout,
+                     rotary_embed=rotary_embed,
+                     flash=flash_attn,
+                     sage_attention=sage_attention
+                 )
+
+             self.layers.append(ModuleList([
+                 attn,
+                 FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+             ]))
+
+         self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+     def forward(self, x):
+         for attn, ff in self.layers:
+             x = attn(x) + x
+             x = ff(x) + x
+
+         return self.norm(x)
+
+ # bandsplit module
+
+
+ class BandSplit(Module):
+     @beartype
+     def __init__(
+         self,
+         dim,
+         dim_inputs: Tuple[int, ...]
+     ):
+         super().__init__()
+         self.dim_inputs = dim_inputs
+         self.to_features = ModuleList([])
+
+         for dim_in in dim_inputs:
+             net = nn.Sequential(
+                 RMSNorm(dim_in),
+                 nn.Linear(dim_in, dim)
+             )
+
+             self.to_features.append(net)
+
+     def forward(self, x):
+         x = x.split(self.dim_inputs, dim=-1)
+
+         outs = []
+         for split_input, to_feature in zip(x, self.to_features):
+             split_output = to_feature(split_input)
+             outs.append(split_output)
+
+         x = torch.stack(outs, dim=-2)
+
+         return x
+
+
+ # convolutional / hypergraph modules for the mask estimator's segmentation branch
+
+ class Conv(nn.Module):
+     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
+         super().__init__()
+         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+         self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
+         self.act = nn.SiLU() if act else nn.Identity()
+
+     def forward(self, x):
+         return self.act(self.bn(self.conv(x)))
+
+
+ def autopad(k, p=None):
+     # pad to 'same' output size for odd kernel sizes
+     if p is None:
+         p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
+     return p
+
+
+ class DSConv(nn.Module):
+     # depthwise separable convolution
+     def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
+         super().__init__()
+         self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
+         self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
+         self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
+         self.act = nn.SiLU() if act else nn.Identity()
+
+     def forward(self, x):
+         return self.act(self.bn(self.pwconv(self.dwconv(x))))
+
+
+ class DS_Bottleneck(nn.Module):
+     def __init__(self, c1, c2, k=3, shortcut=True):
+         super().__init__()
+         c_ = c1
+         self.dsconv1 = DSConv(c1, c_, k=3, s=1)
+         self.dsconv2 = DSConv(c_, c2, k=k, s=1)
+         self.shortcut = shortcut and c1 == c2
+
+     def forward(self, x):
+         return x + self.dsconv2(self.dsconv1(x)) if self.shortcut else self.dsconv2(self.dsconv1(x))
+
+
+ class DS_C3k(nn.Module):
+     def __init__(self, c1, c2, n=1, k=3, e=0.5):
+         super().__init__()
+         c_ = int(c2 * e)
+         self.cv1 = Conv(c1, c_, 1, 1)
+         self.cv2 = Conv(c1, c_, 1, 1)
+         self.cv3 = Conv(2 * c_, c2, 1, 1)
+         self.m = nn.Sequential(*[DS_Bottleneck(c_, c_, k=k, shortcut=True) for _ in range(n)])
+
+     def forward(self, x):
+         return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
+
+
+ class DS_C3k2(nn.Module):
+     def __init__(self, c1, c2, n=1, k=3, e=0.5):
+         super().__init__()
+         c_ = int(c2 * e)
+         self.cv1 = Conv(c1, c_, 1, 1)
+         self.m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
+         self.cv2 = Conv(c_, c2, 1, 1)
+
+     def forward(self, x):
+         x_ = self.cv1(x)
+         x_ = self.m(x_)
+         return self.cv2(x_)
+
+
+ class AdaptiveHyperedgeGeneration(nn.Module):
+     def __init__(self, in_channels, num_hyperedges, num_heads=8):
+         super().__init__()
+         self.num_hyperedges = num_hyperedges
+         self.num_heads = num_heads
+         self.head_dim = in_channels // num_heads
+
+         self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
+
+         self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)
+
+         self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
+
+         self.scale = self.head_dim ** -0.5
+
+     def forward(self, x):
+         B, N, C = x.shape
+
+         # global context from average- and max-pooling over the node axis
+         f_avg = F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
+         f_max = F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
+         f_ctx = torch.cat((f_avg, f_max), dim=1)
+
+         # context-conditioned offsets added to the global hyperedge prototypes
+         delta_P = self.context_mapper(f_ctx).view(B, self.num_hyperedges, C)
+         P = self.global_proto.unsqueeze(0) + delta_P
+
+         z = self.query_proj(x)
+
+         z = z.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+         P = P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1)
+
+         # multi-head similarity between nodes and prototypes, averaged over heads
+         sim = (z @ P) * self.scale
+
+         s_bar = sim.mean(dim=1)
+
+         # soft participation matrix A with shape (B, num_hyperedges, N)
+         A = F.softmax(s_bar.permute(0, 2, 1), dim=-1)
+
+         return A
+
+
+ class HypergraphConvolution(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.W_e = nn.Linear(in_channels, in_channels, bias=False)
+         self.W_v = nn.Linear(in_channels, out_channels, bias=False)
+         self.act = nn.SiLU()
+
+     def forward(self, x, A):
+         # gather node features into hyperedge features, then scatter back to nodes
+         f_m = torch.bmm(A, x)
+         f_m = self.act(self.W_e(f_m))
+
+         x_out = torch.bmm(A.transpose(1, 2), f_m)
+         x_out = self.act(self.W_v(x_out))
+
+         return x + x_out
+
+
+ class AdaptiveHypergraphComputation(nn.Module):
+     def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
+         super().__init__()
+         self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
+             in_channels, num_hyperedges, num_heads
+         )
+         self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x_flat = x.flatten(2).permute(0, 2, 1)
+
+         A = self.adaptive_hyperedge_gen(x_flat)
+
+         x_out_flat = self.hypergraph_conv(x_flat, A)
+
+         x_out = x_out_flat.permute(0, 2, 1).view(B, -1, H, W)
+         return x_out
+
+
+ class C3AH(nn.Module):
+     def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
+         super().__init__()
+         c_ = int(c1 * e)
+         self.cv1 = Conv(c1, c_, 1, 1)
+         self.cv2 = Conv(c1, c_, 1, 1)
+         self.ahc = AdaptiveHypergraphComputation(
+             c_, c_, num_hyperedges, num_heads
+         )
+         self.cv3 = Conv(2 * c_, c2, 1, 1)
+
+     def forward(self, x):
+         x_lateral = self.cv1(x)
+         x_ahc = self.ahc(self.cv2(x))
+         return self.cv3(torch.cat((x_ahc, x_lateral), dim=1))
+
+
+ class HyperACE(nn.Module):
+     def __init__(self, in_channels: List[int], out_channels: int,
+                  num_hyperedges=8, num_heads=8, k=2, l=1, c_h=0.5, c_l=0.25):
+         super().__init__()
+
+         c2, c3, c4, c5 = in_channels
+         c_mid = c4
+
+         self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
+
+         # split fused channels into high-order, low-order and shortcut parts
+         self.c_h = int(c_mid * c_h)
+         self.c_l = int(c_mid * c_l)
+         self.c_s = c_mid - self.c_h - self.c_l
+         assert self.c_s > 0, "Channel split error"
+
+         self.high_order_branch = nn.ModuleList(
+             [C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0) for _ in range(k)]
+         )
+         self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
+
+         self.low_order_branch = nn.Sequential(
+             *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
+         )
+
+         self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
+
+     def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+         B2, B3, B4, B5 = x
+
+         B, _, H4, W4 = B4.shape
+
+         # resize all pyramid levels to the P4 resolution before fusing
+         B2_resized = F.interpolate(B2, size=(H4, W4), mode='bilinear', align_corners=False)
+         B3_resized = F.interpolate(B3, size=(H4, W4), mode='bilinear', align_corners=False)
+         B5_resized = F.interpolate(B5, size=(H4, W4), mode='bilinear', align_corners=False)
+
+         x_b = self.fuse_conv(torch.cat((B2_resized, B3_resized, B4, B5_resized), dim=1))
+
+         x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)
+
+         x_h_outs = [m(x_h) for m in self.high_order_branch]
+         x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))
+
+         x_l_out = self.low_order_branch(x_l)
+
+         y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))
+
+         return y
+
+
+ class GatedFusion(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         # gate starts at zero so the fusion begins as an identity mapping
+         self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
+
+     def forward(self, f_in, h):
+         if f_in.shape[1] != h.shape[1]:
+             raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
+         return f_in + self.gamma * h
+
+
+ class Backbone(nn.Module):
+     def __init__(self, in_channels=256, base_channels=64, base_depth=3):
+         super().__init__()
+         c2 = base_channels
+         c3 = 256
+         c4 = 384
+         c5 = 512
+         c6 = 768
+
+         # stride (2, 1): downsample along time only, keep the band axis
+         self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)
+
+         self.p2 = nn.Sequential(
+             DSConv(c2, c3, k=3, s=(2, 1), p=1),
+             DS_C3k2(c3, c3, n=base_depth)
+         )
+
+         self.p3 = nn.Sequential(
+             DSConv(c3, c4, k=3, s=(2, 1), p=1),
+             DS_C3k2(c4, c4, n=base_depth * 2)
+         )
+
+         self.p4 = nn.Sequential(
+             DSConv(c4, c5, k=3, s=2, p=1),
+             DS_C3k2(c5, c5, n=base_depth * 2)
+         )
+
+         self.p5 = nn.Sequential(
+             DSConv(c5, c6, k=3, s=2, p=1),
+             DS_C3k2(c6, c6, n=base_depth)
+         )
+
+         self.out_channels = [c3, c4, c5, c6]
+
+     def forward(self, x):
+         x = self.stem(x)
+         x2 = self.p2(x)
+         x3 = self.p3(x2)
+         x4 = self.p4(x3)
+         x5 = self.p5(x4)
+         return [x2, x3, x4, x5]
+
+
+ class Decoder(nn.Module):
+     def __init__(self, encoder_channels: List[int], hyperace_out_c: int, decoder_channels: List[int]):
+         super().__init__()
+         c_p2, c_p3, c_p4, c_p5 = encoder_channels
+         c_d2, c_d3, c_d4, c_d5 = decoder_channels
+
+         self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
+         self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
+         self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
+         self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
+
+         self.fusion_d5 = GatedFusion(c_d5)
+         self.fusion_d4 = GatedFusion(c_d4)
+         self.fusion_d3 = GatedFusion(c_d3)
+         self.fusion_d2 = GatedFusion(c_d2)
+
+         self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
+         self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
+         self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
+         self.skip_p2 = Conv(c_p2, c_d2, 1, 1)
+
+         self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
+         self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
+         self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)
+
+         self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
+
+     def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
+         p2, p3, p4, p5 = enc_feats
+
+         # each decoder level mixes an encoder skip with the shared HyperACE features
+         d5 = self.skip_p5(p5)
+         h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode='bilinear'))
+         d5 = self.fusion_d5(d5, h_d5)
+
+         d5_up = F.interpolate(d5, size=p4.shape[2:], mode='bilinear')
+         d4_skip = self.skip_p4(p4)
+         d4 = self.up_d5(d5_up) + d4_skip
+
+         h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode='bilinear'))
+         d4 = self.fusion_d4(d4, h_d4)
+
+         d4_up = F.interpolate(d4, size=p3.shape[2:], mode='bilinear')
+         d3_skip = self.skip_p3(p3)
+         d3 = self.up_d4(d4_up) + d3_skip
+
+         h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode='bilinear'))
+         d3 = self.fusion_d3(d3, h_d3)
+
+         d3_up = F.interpolate(d3, size=p2.shape[2:], mode='bilinear')
+         d2_skip = self.skip_p2(p2)
+         d2 = self.up_d3(d3_up) + d2_skip
+
+         h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode='bilinear'))
+         d2 = self.fusion_d2(d2, h_d2)
+
+         d2_final = self.final_d2(d2)
+
+         return d2_final
+
+ class TFC_TDF(nn.Module):
+     # time-frequency convolutions with a frequency-axis fully-connected bottleneck
+     def __init__(self, in_c, c, l, f, bn=4):
+         super().__init__()
+
+         self.blocks = nn.ModuleList()
+         for i in range(l):
+             block = nn.Module()
+
+             block.tfc1 = nn.Sequential(
+                 nn.InstanceNorm2d(in_c, affine=True, eps=1e-8),
+                 nn.SiLU(),
+                 nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
+             )
+             block.tdf = nn.Sequential(
+                 nn.InstanceNorm2d(c, affine=True, eps=1e-8),
+                 nn.SiLU(),
+                 nn.Linear(f, f // bn, bias=False),
+                 nn.InstanceNorm2d(c, affine=True, eps=1e-8),
+                 nn.SiLU(),
+                 nn.Linear(f // bn, f, bias=False),
+             )
+             block.tfc2 = nn.Sequential(
+                 nn.InstanceNorm2d(c, affine=True, eps=1e-8),
+                 nn.SiLU(),
+                 nn.Conv2d(c, c, 3, 1, 1, bias=False),
+             )
+             block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
+
+             self.blocks.append(block)
+             in_c = c
+
+     def forward(self, x):
+         for block in self.blocks:
+             s = block.shortcut(x)
+             x = block.tfc1(x)
+             x = x + block.tdf(x)
+             x = block.tfc2(x)
+             x = x + s
+         return x
+
+
+ class FreqPixelShuffle(nn.Module):
+     # upsample the frequency axis by `scale` via a channel-to-frequency pixel shuffle
+     def __init__(self, in_channels, out_channels, scale, f):
+         super().__init__()
+         self.scale = scale
+         self.conv = DSConv(in_channels, out_channels * scale)
+         self.out_conv = TFC_TDF(out_channels, out_channels, 2, f)
+
+     def forward(self, x):
+         x = self.conv(x)
+         B, C_r, H, W = x.shape
+         out_c = C_r // self.scale
+
+         x = x.view(B, out_c, self.scale, H, W)
+
+         x = x.permute(0, 1, 3, 4, 2).contiguous()
+         x = x.view(B, out_c, H, W * self.scale)
+
+         return self.out_conv(x)
+
+
+ class ProgressiveUpsampleHead(nn.Module):
+     # 16x frequency upsampling in four pixel-shuffle stages, then projection to output channels
+     def __init__(self, in_channels, out_channels, target_bins=1025, in_bands=62):
+         super().__init__()
+         self.target_bins = target_bins
+
+         c = in_channels
+
+         self.block1 = FreqPixelShuffle(c, c // 2, scale=2, f=in_bands * 2)
+         self.block2 = FreqPixelShuffle(c // 2, c // 4, scale=2, f=in_bands * 4)
+         self.block3 = FreqPixelShuffle(c // 4, c // 8, scale=2, f=in_bands * 8)
+         self.block4 = FreqPixelShuffle(c // 8, c // 16, scale=2, f=in_bands * 16)
+
+         self.final_conv = nn.Conv2d(c // 16, out_channels, kernel_size=3, stride=1, padding='same', bias=False)
+
+     def forward(self, x):
+         x = self.block1(x)
+         x = self.block2(x)
+         x = self.block3(x)
+         x = self.block4(x)
+
+         if x.shape[-1] != self.target_bins:
+             x = F.interpolate(x, size=(x.shape[2], self.target_bins), mode='bilinear', align_corners=False)
+
+         x = self.final_conv(x)
+         return x
+
+
+ class SegmModel(nn.Module):
+     def __init__(self, in_bands=62, in_dim=256, out_bins=1025, out_channels=4,
+                  base_channels=64, base_depth=2,
+                  num_hyperedges=32, num_heads=8):
+         super().__init__()
+
+         self.backbone = Backbone(in_channels=in_dim, base_channels=base_channels, base_depth=base_depth)
+         enc_channels = self.backbone.out_channels
+         c2, c3, c4, c5 = enc_channels
+
+         hyperace_in_channels = enc_channels
+         hyperace_out_channels = c4
+         self.hyperace = HyperACE(
+             hyperace_in_channels, hyperace_out_channels,
+             num_hyperedges, num_heads, k=2, l=1
+         )
+
+         decoder_channels = [c2, c3, c4, c5]
+         self.decoder = Decoder(
+             enc_channels, hyperace_out_channels, decoder_channels
+         )
+
+         self.upsample_head = ProgressiveUpsampleHead(
+             in_channels=decoder_channels[0],
+             out_channels=out_channels,
+             target_bins=out_bins,
+             in_bands=in_bands
+         )
+
+     def forward(self, x):
+         H, W = x.shape[2:]
+
+         enc_feats = self.backbone(x)
+
+         h_ace_feats = self.hyperace(enc_feats)
+
+         dec_feat = self.decoder(enc_feats, h_ace_feats)
+
+         # restore the time resolution lost in the strided backbone stages
+         feat_time_restored = F.interpolate(dec_feat, size=(H, dec_feat.shape[-1]), mode='bilinear', align_corners=False)
+
+         out = self.upsample_head(feat_time_restored)
+
+         return out
+
+
+ def MLP(
+     dim_in,
+     dim_out,
+     dim_hidden=None,
+     depth=1,
+     activation=nn.Tanh
+ ):
+     dim_hidden = default(dim_hidden, dim_in)
+
+     net = []
+     dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+     for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+         is_last = ind == (len(dims) - 2)
+
+         net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+         if is_last:
+             continue
+
+         net.append(activation())
+
+     return nn.Sequential(*net)
+
+
+ class MaskEstimator(Module):
+     @beartype
+     def __init__(
+         self,
+         dim,
+         dim_inputs: Tuple[int, ...],
+         depth,
+         mlp_expansion_factor=4
+     ):
+         super().__init__()
+         self.dim_inputs = dim_inputs
+         self.to_freqs = ModuleList([])
+         dim_hidden = dim * mlp_expansion_factor
+
+         for dim_in in dim_inputs:
+             mlp = nn.Sequential(
+                 MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+                 nn.GLU(dim=-1)
+             )
+
+             self.to_freqs.append(mlp)
+
+         self.segm = SegmModel(in_bands=len(dim_inputs), in_dim=dim, out_bins=sum(dim_inputs) // 4)
+
+     def forward(self, x):
+         # convolutional branch over the (time, band) grid
+         y = rearrange(x, 'b t f c -> b c t f')
+         y = self.segm(y)
+         y = rearrange(y, 'b c t f -> b t (f c)')
+
+         # per-band MLP branch
+         x = x.unbind(dim=-2)
+
+         outs = []
+
+         for band_features, mlp in zip(x, self.to_freqs):
+             freq_out = mlp(band_features)
+             outs.append(freq_out)
+
+         return torch.cat(outs, dim=-1) + y
+
+ # main class
+
+ DEFAULT_FREQS_PER_BANDS = (
+     2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+     2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+     2, 2, 2, 2,
+     4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+     12, 12, 12, 12, 12, 12, 12, 12,
+     24, 24, 24, 24, 24, 24, 24, 24,
+     48, 48, 48, 48, 48, 48, 48, 48,
+     128, 129,
+ )
+
+
+ class BSRoformer(Module):
+
+     @beartype
+     def __init__(
+         self,
+         dim,
+         *,
+         depth,
+         stereo=False,
+         num_stems=1,
+         time_transformer_depth=2,
+         freq_transformer_depth=2,
+         linear_transformer_depth=0,
+         freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
+         # in the paper, they divide into ~60 bands; test with 1 for starters
+         dim_head=64,
+         heads=8,
+         attn_dropout=0.,
+         ff_dropout=0.,
+         flash_attn=True,
+         dim_freqs_in=1025,
+         stft_n_fft=2048,
+         stft_hop_length=512,
+         # ~10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+         stft_win_length=2048,
+         stft_normalized=False,
+         stft_window_fn: Optional[Callable] = None,
+         mask_estimator_depth=2,
+         multi_stft_resolution_loss_weight=1.,
+         multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+         multi_stft_hop_size=147,
+         multi_stft_normalized=False,
+         multi_stft_window_fn: Callable = torch.hann_window,
+         mlp_expansion_factor=4,
+         use_torch_checkpoint=False,
+         skip_connection=False,
+         sage_attention=False,
+     ):
+         super().__init__()
+
+         self.stereo = stereo
+         self.audio_channels = 2 if stereo else 1
+         self.num_stems = num_stems
+         # accepted for config compatibility; not used by this variant's forward pass
+         self.use_torch_checkpoint = use_torch_checkpoint
+         self.skip_connection = skip_connection
+
+         self.layers = ModuleList([])
+
+         if sage_attention:
+             print("Use Sage Attention")
+
+         transformer_kwargs = dict(
+             dim=dim,
+             heads=heads,
+             dim_head=dim_head,
+             attn_dropout=attn_dropout,
+             ff_dropout=ff_dropout,
+             flash_attn=flash_attn,
+             norm_output=False,
+             sage_attention=sage_attention,
+         )
+
+         time_rotary_embed = RotaryEmbedding(dim=dim_head)
+         freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+         for _ in range(depth):
+             tran_modules = []
+             tran_modules.append(
+                 Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+             )
+             tran_modules.append(
+                 Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+             )
+             self.layers.append(nn.ModuleList(tran_modules))
+
+         self.final_norm = RMSNorm(dim)
+
+         self.stft_kwargs = dict(
+             n_fft=stft_n_fft,
+             hop_length=stft_hop_length,
+             win_length=stft_win_length,
+             normalized=stft_normalized
+         )
+
+         self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+
+         # probe the STFT once to determine the number of frequency bins
+         freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
+
+         assert len(freqs_per_bands) > 1
+         assert sum(freqs_per_bands) == freqs, \
+             f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
+
+         freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
+
+         self.band_split = BandSplit(
+             dim=dim,
+             dim_inputs=freqs_per_bands_with_complex
+         )
+
+         self.mask_estimators = nn.ModuleList([])
+
+         for _ in range(num_stems):
+             mask_estimator = MaskEstimator(
+                 dim=dim,
+                 dim_inputs=freqs_per_bands_with_complex,
+                 depth=mask_estimator_depth,
+                 mlp_expansion_factor=mlp_expansion_factor,
+             )
+
+             self.mask_estimators.append(mask_estimator)
+
+         # for the multi-resolution stft loss
+
+         self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+         self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+         self.multi_stft_n_fft = stft_n_fft
+         self.multi_stft_window_fn = multi_stft_window_fn
+
+         self.multi_stft_kwargs = dict(
+             hop_length=multi_stft_hop_size,
+             normalized=multi_stft_normalized
+         )
+
+     def forward(
+         self,
+         raw_audio,
+         target=None,
+         return_loss_breakdown=False
+     ):
+         """
+         einops
+
+         b - batch
+         f - freq
+         t - time
+         s - audio channel (1 for mono, 2 for stereo)
+         n - number of 'stems'
+         c - complex (2)
+         d - feature dimension
+         """
+
+         device = raw_audio.device
+
+         # whether the model is running on MPS (MacOS GPU accelerator)
+         x_is_mps = device.type == "mps"
+
+         if raw_audio.ndim == 2:
+             raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+         channels = raw_audio.shape[1]
+         assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), \
+             'stereo must be True when passing stereo audio (channel dimension of 2) and False for mono (channel dimension of 1)'
+
+         # to stft
+
+         raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+         stft_window = self.stft_window_fn(device=device)
+
+         # RuntimeError: FFT operations are only supported on MacOS 14+
+         # since it's tedious to detect the exact MacOS version, a simple try/except is used
+         try:
+             stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+         except RuntimeError:
+             stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs,
+                                    window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
+         stft_repr = torch.view_as_real(stft_repr)
+
+         stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+
+         # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+         stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')
+
+         x = rearrange(stft_repr, 'b f t c -> b t (f c)')
+
+         x = self.band_split(x)
+
+         # axial / hierarchical attention
+
+         for transformer_block in self.layers:
+             time_transformer, freq_transformer = transformer_block
+
+             # attend over time, within each band
+             x = rearrange(x, 'b t f d -> b f t d')
+             x, ps = pack([x], '* t d')
+
+             x = time_transformer(x)
+
+             x, = unpack(x, ps, '* t d')
+
+             # attend over bands, within each time step
+             x = rearrange(x, 'b f t d -> b t f d')
+             x, ps = pack([x], '* f d')
+
+             x = freq_transformer(x)
+
+             x, = unpack(x, ps, '* f d')
+
+         x = self.final_norm(x)
+
+         num_stems = len(self.mask_estimators)
+
+         mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+         mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
+
+         # modulate frequency representation
+
+         stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+         # complex multiplication of the mixture spectrogram by the predicted mask
+         stft_repr = torch.view_as_complex(stft_repr)
+         mask = torch.view_as_complex(mask)
+
+         stft_repr = stft_repr * mask
+
+         # istft
+
+         stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+         # same try/except as the forward stft, for MPS devices without FFT support
+         try:
+             recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
+         except RuntimeError:
+             recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
+
+         recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
+
+         if num_stems == 1:
+             recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+
+         # if a target is passed in, calculate loss for learning
+
+         if not exists(target):
+             return recon_audio
+
+         if self.num_stems > 1:
+             assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+         if target.ndim == 2:
+             target = rearrange(target, '... t -> ... 1 t')
+
+         target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft
+
+         loss = F.l1_loss(recon_audio, target)
+
+         multi_stft_resolution_loss = 0.
+
+         for window_size in self.multi_stft_resolutions_window_sizes:
+             res_stft_kwargs = dict(
+                 n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
+                 win_length=window_size,
+                 return_complex=True,
+                 window=self.multi_stft_window_fn(window_size, device=device),
+                 **self.multi_stft_kwargs,
+             )
+
+             recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
+             target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+
+             multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+
+         weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+
+         total_loss = loss + weighted_multi_resolution_loss
+
+         if not return_loss_breakdown:
+             return total_loss
+
+         return total_loss, (loss, multi_stft_resolution_loss)
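
The uploaded module ships without a usage example, so here is a minimal inference sketch (not part of the commit). It assumes bs_roformer.py is importable under the name bs_roformer (the import path is hypothetical) and uses the hyperparameters from the v2_voc/config.yaml below; shapes follow the einops legend in BSRoformer.forward.

import torch
from bs_roformer import BSRoformer  # hypothetical import path

# model hyperparameters mirror the 'model' section of v2_voc/config.yaml
model = BSRoformer(
    dim=256,
    depth=12,
    stereo=True,
    num_stems=1,
    time_transformer_depth=1,
    freq_transformer_depth=1,
    mask_estimator_depth=2,
).eval()

audio = torch.randn(1, 2, 44100 * 2)  # (batch, channels, samples): 2 s of stereo noise

with torch.no_grad():
    stem = model(audio)               # no target passed -> returns the separated stem

print(stem.shape)                     # (1, 2, 88200): vocals estimate

# passing a target instead returns the combined L1 + multi-resolution STFT loss
loss = model(audio, target=torch.randn(1, 2, 44100 * 2))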
v2_voc/bs_roformer_voc_hyperacev2.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54cf516f621f2f460bf660ed137e244b8931bf7a2ce85ddceecff816dbc4d668
+ size 288724430
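
The .ckpt entry is a Git LFS pointer: the ~288 MB of weights live in LFS storage, and cloning without LFS fetches only the three lines above. A hedged loading sketch, continuing the one above; whether the file stores a bare state_dict or wraps it in a dict is an assumption to verify against the actual checkpoint.

import torch

ckpt = torch.load('v2_voc/bs_roformer_voc_hyperacev2.ckpt', map_location='cpu')
# handle either a bare state_dict or a {'state_dict': ...} wrapper
state_dict = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
model.load_state_dict(state_dict)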
v2_voc/config.yaml ADDED
@@ -0,0 +1,129 @@
+ audio:
+   chunk_size: 960000
+   dim_f: 1024
+   dim_t: 801 # not used (the model derives this)
+   hop_length: 441 # not used (the model uses stft_hop_length)
+   n_fft: 2048
+   num_channels: 2
+   sample_rate: 44100
+   min_mean_abs: 0.0001
+
+ model:
+   dim: 256
+   depth: 12
+   stereo: true
+   num_stems: 1
+   time_transformer_depth: 1
+   freq_transformer_depth: 1
+   linear_transformer_depth: 0
+   freqs_per_bands: !!python/tuple
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 2
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 4
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 12
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 24
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 48
+     - 128
+     - 129
+   dim_head: 64
+   heads: 8
+   attn_dropout: 0.0
+   ff_dropout: 0.0
+   flash_attn: true
+   dim_freqs_in: 1025
+   stft_n_fft: 2048
+   stft_hop_length: 512
+   stft_win_length: 2048
+   stft_normalized: false
+   mask_estimator_depth: 2
+   multi_stft_resolution_loss_weight: 1.0
+   multi_stft_resolutions_window_sizes: !!python/tuple
+     - 4096
+     - 2048
+     - 1024
+     - 512
+     - 256
+   multi_stft_hop_size: 147
+   multi_stft_normalized: False
+   mlp_expansion_factor: 4
+   use_torch_checkpoint: True
+   skip_connection: False
+
+
+ training:
+   batch_size: 1
+   gradient_accumulation_steps: 1
+   grad_clip: 0
+   instruments: ['vocals', 'instrument']
+   lr: 1.0e-5
+   patience: 5
+   reduce_factor: 0.9
+   target_instrument: vocals
+   num_epochs: 1000
+   num_steps: 1000
+   q: 0.95
+   coarse_loss_clip: true
+   ema_momentum: 0.999
+   optimizer: adam
+   other_fix: false # needed when checking on the multisong dataset whether 'other' is actually instrumental
+   use_amp: true # enable mixed precision (float16); usually should be true
+
+
+ inference:
+   batch_size: 2
+   dim_t: 1876
+   num_overlap: 4
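
Note that freqs_per_bands and multi_stft_resolutions_window_sizes carry the !!python/tuple tag, which PyYAML's safe loader rejects. A small loading sketch (assuming PyYAML; only unsafe-load config files you trust):

import yaml
from bs_roformer import BSRoformer  # hypothetical import path, as above

with open('v2_voc/config.yaml') as f:
    config = yaml.unsafe_load(f)  # required for the !!python/tuple tags

# every key in the 'model' section matches a BSRoformer constructor argument
model = BSRoformer(**config['model'])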