from typing import Optional, Tuple

import torch as T
import torch.nn as nn
import torch.nn.functional as F

from ioblocks import GaussianMixtureIOLayer, FSQ

from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
from tokenizer import make_tokenizer

from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
from utils import load_ckpt


@si_module  # wraps the nested Config so that calling a Config instance builds the module (see utils.si_module)
class LatentQuantizer(nn.Module):
    class Config:
        compressor_config: Optional[FSQ.Config] = None
        dim: Optional[int] = None
        ff_dim: Optional[int] = None
        input_dim: Optional[int] = None

        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert exists(c.compressor_config), f'hmm {c}'

        self.compressor = c.compressor_config()
        self.ffnn = FFNN(c.dim, c.ff_dim)
        self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()

        if exists(c.from_pretrained):
            self.load_state_dict(checkpoint)

    def forward(self, x, return_latent=False, known_latent=None):
        """
        x: (B, S, D)
        """
        if exists(known_latent):
            return self.compressor.indices_to_codes(known_latent)

        x = self.input(x)
        x = self.ffnn(x)
        x, tokens = self.compressor(x)

        if return_latent:
            return x, tokens
        return x
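
# Illustrative use of LatentQuantizer (not part of the original file; the config values
# mirror get_hertz_dev_config below, and fsq_config is a stand-in for the FSQ.Config used there):
#
#   quantizer = LatentQuantizer.Config(compressor_config=fsq_config, dim=2048, ff_dim=8192, input_dim=32)()
#   quantized, tokens = quantizer(x, return_latent=True)   # x: (B, S, 32) -> FSQ codes and token indices
#   codes = quantizer(None, known_latent=tokens)           # decode codes straight from token indices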


@si_module
class TransformerVAE(nn.Module):
    class Config:
        io_config: Optional[GaussianMixtureIOLayer.Config] = None
        stack_config: Optional[Stack.Config] = None
        quantizer_config: Optional[LatentQuantizer.Config] = None
        plex_layer: Optional[int] = None
        plex_roll: int = 1
        split: bool = True

        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'

        self.io = c.io_config()
        self.stack = c.stack_config()

        self.plex_layer = c.stack_config.layers//2
        self.plex_roll = c.plex_roll
        self.plex_dim = c.quantizer_config.dim

        assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
        self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
        self.out_norm = Norm(c.stack_config.dim)

        if c.split:
            self.io2 = c.io_config()
            self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)

            # The shared self.io.output head is applied to both streams in forward(),
            # so the second IO layer's mixture heads are unused and dropped here.
            self.io2.fc_loc = None
            self.io2.fc_scale = None
            self.io2.fc_weight = None

        kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        head_dim = c.stack_config.dim // c.stack_config.n_head
        # In split mode every layer after the plex layer runs twice (once per stream),
        # so each of those runs gets its own KV-cache slot.
        self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
        cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
        self.cache_shape = cache_shape
        self.cache = [None] * self.cache_num_layers

        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        self.quantizer = c.quantizer_config().eval()
        self.quantizer.requires_grad = False

    def quantize(self, x):
        if self.c.split:
            x1, x2 = x.chunk(2, dim=-1)
            with T.autocast(device_type='cuda', dtype=T.bfloat16):
                quantized1 = self.quantizer(x1)
                quantized2 = self.quantizer(x2)
            return quantized1, quantized2
        else:
            with T.autocast(device_type='cuda', dtype=T.bfloat16):
                return self.quantizer(x)

    def untokenize(self, token_data):
        return self.quantizer(None, known_latent=token_data)

    def init_cache(self, bsize, device, dtype, length: int = None):
        cache_shape = self.cache_shape.copy()
        cache_shape[1] = length or cache_shape[1]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)

    def deinit_cache(self):
        self.cache = [None] * self.cache_num_layers
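
    # KV-cache layout after init_cache (a reading of the code above, not original
    # documentation): the allocated tensor is transposed to
    #   [cache_num_layers, bsize, seq_len, 2, kv_heads, head_dim],
    # so self.cache[l] hands layer l its own (key, value) slab. In split mode the extra
    # (layers - plex_layer) slots give the second stream a separate cache for every
    # post-plex layer.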

    def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.io.input(x1) + self.io2.input(x2)
        else:
            x = self.io.input(data)

        cache_idx = 0
        for l, layer in enumerate(self.stack.layers):
            if l == self.plex_layer:
                if self.c.split:
                    plex1, plex2 = self.quantize(data)
                    plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
                    plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex1[:, -1:] = self.untokenize(next_tokens[0])
                        plex2[:, -1:] = self.untokenize(next_tokens[1])
                    x1 = x + self.plex_projection(plex1)
                    x2 = x + self.plex_projection2(plex2)
                else:
                    plex = self.quantize(data)
                    plex = T.roll(plex, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex[:, -1:] = self.untokenize(next_tokens)
                    x = x + self.plex_projection(plex)

            if l < self.plex_layer:
                x = layer(x, kv=self.cache[l])
            else:
                if self.c.split:
                    x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
                    cache_idx += 1
                    x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
                    cache_idx += 1
                else:
                    x = layer(x, kv=self.cache[l])

        with T.autocast(device_type='cuda', dtype=T.bfloat16):
            if self.c.split:
                x1, x2 = self.out_norm(x1), self.out_norm(x2)
                out1, out2 = self.io.output(x1), self.io.output(x2)
            else:
                x = self.out_norm(x)
                out = self.io.output(x)

        if isnt(temps):
            if self.c.split:
                return out1, out2
            else:
                return out
        else:
            if self.c.split:
                next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
                next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
                next_data = T.cat([next_data1, next_data2], dim=-1)
                return next_data
            else:
                next_data = self.io.temp_sample(out, temps)[:, -1:, :]
                return next_data
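
# Note on TransformerVAE.forward (a summary of the code above, not original
# documentation): the input is re-quantized into "plex" features, shifted by plex_roll
# so each position is paired with the following frame's codes, and injected halfway
# through the stack. With temps=None the Gaussian-mixture output parameters for all
# positions are returned; with temps set, only the final frame is sampled, and
# next_tokens lets the caller overwrite that frame's plex features with freshly decoded
# token indices (see HertzDevModel.next_latent below).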


@si_module
class HertzDevModel(nn.Module):
    class Config:
        dim: int
        vocab_size: int
        stack_config: Optional[Stack.Config] = None

        latent_size: int = 32
        split: bool = True

        quantizer_config: Optional[LatentQuantizer.Config] = None
        resynthesizer_config: Optional[TransformerVAE.Config] = None

        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert exists(c.stack_config), f'hmm {c}'

        self.input = nn.Linear(c.latent_size, c.dim)
        if self.c.split:
            self.input2 = nn.Linear(c.latent_size, c.dim)

        self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)

        self.layers = nn.ModuleList([
            PerfBlock(
                dim=c.stack_config.dim,
                layer_id=l,
                n_head=c.stack_config.n_head,
                kv_heads=c.stack_config.kv_heads,
                ff_dim=c.stack_config.ff_dim,
                eps=c.stack_config.eps,
                shape_rotator=self.shape_rotator,
            ) for l in range(c.stack_config.layers)
        ])

        self.output = GPTOutput(c.dim, c.vocab_size)
        if self.c.split:
            self.output2 = GPTOutput(c.dim, c.vocab_size)

        self.cache = [None] * c.stack_config.layers
        self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        self.head_dim = c.stack_config.dim // c.stack_config.n_head

        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        self.resynthesizer = c.resynthesizer_config().eval()
        self.resynthesizer.requires_grad = False

        self.audio_tokenizer = make_tokenizer(device='cpu')
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    def tokenize(self, audio_data):
        orig_audio_shape = audio_data.shape
        if exists(self.audio_cache):
            # Keep a rolling cache of the most recent 6*16_000 samples as left context.
            audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
            self.audio_cache = audio_data[..., -(6*16_000):]
        elif self.use_audio_cache:
            self.audio_cache = audio_data[..., -(6*16_000):]

        if audio_data.shape[1] == 2:
            # Stereo: encode each channel separately and concatenate the latents.
            enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
            enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
            return T.cat([enc_ch1, enc_ch2], dim=-1)[:, -(orig_audio_shape[-1]//2000):]
        else:
            return self.audio_tokenizer.latent_from_data(audio_data)[:, -(orig_audio_shape[-1]//2000):]

    def untokenize(self, token_data):
        if exists(self.audio_latent_cache):
            # Keep a rolling cache of the most recent 6*8 latent frames as left context.
            token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
            self.audio_latent_cache = token_data[:, -(6*8):]
        elif self.use_audio_cache:
            self.audio_latent_cache = token_data[:, -(6*8):]

        if token_data.shape[-1] == 2*self.c.latent_size:
            # Stereo: split the concatenated latents back into per-channel halves along the last dim.
            dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[..., :self.c.latent_size])
            dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[..., self.c.latent_size:])
            return T.cat([dec_ch1, dec_ch2], dim=1)[..., -(token_data.shape[1]*2000):]
        else:
            return self.audio_tokenizer.data_from_latent(token_data)[..., -(token_data.shape[1]*2000):]
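
    # Rates implied by the constants above (an inference from the code, not original
    # documentation): audio is handled at 16,000 samples per second and the codec emits
    # one latent frame per 2,000 samples, i.e. 8 frames per second. The two rolling
    # caches therefore cover the same window: 6*16_000 audio samples and 6*8 latent
    # frames are both six seconds of context.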

    def init_cache(self, bsize, device, dtype, length: int = None):
        cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
        self.resynthesizer.init_cache(bsize, device, dtype, length)
        self.use_audio_cache = True

    def deinit_cache(self):
        self.cache = [None] * len(self.layers)
        self.resynthesizer.deinit_cache()
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    def forward(self, data):
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.input(x1) + self.input2(x2)
        else:
            x = self.input(data)

        for l, layer in enumerate(self.layers):
            x = layer(x, kv=self.cache[l])

        if self.c.split:
            return self.output(x), self.output2(x)
        else:
            return self.output(x)

    def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
        latents_in = self.tokenize(audio_data)
        next_latents = self.next_latent(latents_in, temps)
        next_model_latent = next_latents[..., self.c.latent_size:]
        audio_decoded = self.untokenize(next_model_latent)[..., -2000:]
        return audio_decoded

    def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
        if self.c.split:
            logits1, logits2 = self.forward(model_input)
            next_logits1 = logits1[:, -1]
            next_logits2 = logits2[:, -1]
            next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1)
            next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1)

            next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1])
        else:
            logits = self.forward(model_input)
            next_logits = logits[:, -1]
            next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1)

            next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1])

        return next_input

    def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
        """
        Only accepts latent-space data.
        """
        if use_cache:
            self.init_cache(data.shape[0], data.device, T.bfloat16)

        next_input = generated = data

        target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)

        for _ in tqdm0(range(data.shape[1], target_len)):
            model_input = next_input if use_cache else generated
            next_input = self.next_latent(model_input, temps)
            generated = T.cat([generated, next_input], dim=1)

        if use_cache:
            self.deinit_cache()
        return generated
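
# Note on completion() (a reading of the code above, not original documentation): with
# use_cache=True it allocates KV caches here and in the resynthesizer and feeds only the
# newly sampled latent frame back in on each step; with use_cache=False the full
# generated sequence is re-run every iteration. Generation stops once the sequence
# length would exceed stack_config.seq_len.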


def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False):
    if is_split:
        checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')]
    elif not use_pure_audio_ablation:
        checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')]
    else:
        checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')]

    quantizer_config = LatentQuantizer.Config(
        from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
        compressor_config=FSQ.Config(
            levels=[8, 8, 8, 8, 8],
            dim=2048,
            num_codebooks=1,
            keep_num_codebooks_dim=None,
            scale=None,
            allowed_dtypes=['float32', 'float64', 'bfloat16'],
            channel_first=False,
            projection_has_bias=True,
            return_indices=True,
            force_quantization_f32=True,
            use_rms=False,
        ),
        dim=2048,
        ff_dim=8192,
        input_dim=32,
    )

    resynthesizer_config = TransformerVAE.Config(
        io_config=GaussianMixtureIOLayer.Config(
            latent_dim=32,
            dim=4096,
            num_components=8,
        ),
        stack_config=Stack.Config(
            layers=8,
            dim=4096,
            seq_len=8192,
            n_head=16,
            ff_dim=11008,
            kv_heads=16,
            eps=1e-5,
            theta=10_000,
        ),
        quantizer_config=quantizer_config,
        plex_layer=None,
        plex_roll=1,
        split=is_split,
        from_pretrained=checkpoints[0],
    )

    return HertzDevModel.Config(
        dim=4096,
        vocab_size=32_768,
        stack_config=Stack.Config(
            layers=32,
            dim=4096,
            seq_len=2048,
            n_head=32,
            ff_dim=None,
            kv_heads=None,
            eps=1e-5,
            theta=10_000,
        ),
        quantizer_config=quantizer_config,
        resynthesizer_config=resynthesizer_config,
        split=is_split,
        from_pretrained=checkpoints[1],
    )
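

# Illustrative end-to-end use (not part of the original file; device and dtype handling
# are assumptions suggested by the hard-coded device_type='cuda' autocast blocks above):
#
#   model_config = get_hertz_dev_config(is_split=True)
#   model = model_config()                  # si_module pattern: calling the Config builds HertzDevModel
#   model = model.eval().bfloat16().cuda()
#
#   # Latent-space continuation of a prompt already encoded with model.tokenize(...):
#   generated = model.completion(prompt_latents, temps=(0.8, (0.5, 0.1)), gen_len=80)   # ~10 s at 8 frames/s
#
#   # Or streaming, one 2000-sample (125 ms) frame per call:
#   model.init_cache(bsize=1, device='cuda', dtype=T.bfloat16)
#   next_frame = model.next_audio_from_audio(audio_frame)   # audio_frame: (B, channels, 2000) at 16 kHz
#   model.deinit_cache()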