import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
import json
import os
from tokenizers import Tokenizer
from typing import Optional, Tuple, List, Iterator

# ============================================================================
# 1. MODEL ARCHITECTURE (Exact match for i3HybridChatModel from fine-tune.py)
# ============================================================================

def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor):
    # Core RWKV linear-attention (WKV) kernel: a numerically stable recurrence
    # over time that tracks running numerator/denominator states in log space.
    y = torch.zeros_like(v)
    state_aa = torch.zeros(B, C, device=v.device)
    state_bb = torch.zeros(B, C, device=v.device)
    state_pp = torch.ones(B, C, device=v.device) * -1e30
    for t in range(T):
        rt = r[:, t, :]  # receptance is applied outside this kernel, in TimeMixing.forward
        kt = k[:, t, :]
        vt = v[:, t, :]
        ww = kt + u
        p = torch.maximum(state_pp, ww)
        e1 = torch.exp(state_pp - p)
        e2 = torch.exp(ww - p)
        y[:, t, :] = (e1 * state_aa + e2 * vt) / (e1 * state_bb + e2)
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = e1 * state_aa + e2 * vt
        state_bb = e1 * state_bb + e2
        state_pp = p
    return y
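
# Illustrative sketch only -- this helper is not called anywhere in the app. It spells
# out, in unstabilized form, the weighted average the log-space loop above computes:
# y[t] = (sum_{i<t} exp(w*(t-1-i) + k[i]) * v[i] + exp(u + k[t]) * v[t]) / (same sum over 1's).
def _wkv_reference(k: torch.Tensor, v: torch.Tensor,
                   w: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    # k, v: (T, C); w, u: (C,). Single sequence (no batch dim) for readability.
    T = k.shape[0]
    y = torch.zeros_like(v)
    for t in range(T):
        num = torch.exp(u + k[t]) * v[t]    # bonus term for the current token
        den = torch.exp(u + k[t])
        for i in range(t):
            decay = torch.exp(w * (t - 1 - i) + k[i])   # exponential time decay
            num = num + decay * v[i]
            den = den + decay
        y[t] = num / den
    return y
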
class TimeMixing(nn.Module):
    def __init__(self, n_embd, layer_id):
        super().__init__()
        self.n_embd = n_embd
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(n_embd))
        self.time_first = nn.Parameter(torch.ones(n_embd) * 0.5)
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.output = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        # Token shift: mix each position with the previous token's features (simplified for inference)
        xx = F.pad(x, (0, 0, 1, -1))
        k = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        v = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        r = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(k)
        v = self.value(v)
        r = torch.sigmoid(self.receptance(r))
        w = -torch.exp(self.time_decay)
        u = self.time_first
        return self.output(r * rwkv_linear_attention(B, T, C, r, k, v, w, u))
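
# Illustrative sketch only (never called): what the F.pad(x, (0, 0, 1, -1)) token shift
# above does on a tiny (B=1, T=3, C=2) tensor -- each position is paired with the
# previous token's features, and position 0 sees zeros.
def _token_shift_demo() -> torch.Tensor:
    x = torch.arange(1, 7, dtype=torch.float32).view(1, 3, 2)
    shifted = F.pad(x, (0, 0, 1, -1))
    # x[0]       == [[1., 2.], [3., 4.], [5., 6.]]
    # shifted[0] == [[0., 0.], [1., 2.], [3., 4.]]
    return shifted
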
class ChannelMixing(nn.Module):
    def __init__(self, n_embd, layer_id):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
        self.key = nn.Linear(n_embd, n_embd * 4, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd * 4, n_embd, bias=False)

    def forward(self, x):
        xx = F.pad(x, (0, 0, 1, -1))
        k = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        r = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = torch.square(torch.relu(self.key(k)))
        r = torch.sigmoid(self.receptance(r))
        return r * self.value(k)
class StandardAttention(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.qkv = nn.Linear(n_embd, n_embd * 3)
        self.out_proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, 1, C).transpose(1, 2)
        k = k.view(B, T, 1, C).transpose(1, 2)
        v = v.view(B, T, 1, C).transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(y)
class Block(nn.Module):
    def __init__(self, n_embd, layer_id, is_attn=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        if is_attn:
            self.attn = StandardAttention(n_embd)
            self.ffn = nn.Sequential(
                nn.Linear(n_embd, n_embd * 4),
                nn.GELU(),
                nn.Linear(n_embd * 4, n_embd)
            )
        else:
            self.att = TimeMixing(n_embd, layer_id)
            self.ffn = ChannelMixing(n_embd, layer_id)

    def forward(self, x):
        # RWKV blocks register 'att', attention blocks register 'attn';
        # the attribute name decides which path this block takes.
        if hasattr(self, 'att'):
            x = x + self.att(self.ln1(x))
        else:
            x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
class I3Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config['vocab_size'], config['n_embd'])
        # Build hybrid layers: RWKV first, then Attention
        self.layers = nn.ModuleList()
        for i in range(config['n_layer']):
            is_attn = i >= config['rwkv_layers']
            self.layers.append(Block(config['n_embd'], i, is_attn=is_attn))
        self.ln_f = nn.LayerNorm(config['n_embd'])
        self.head = nn.Linear(config['n_embd'], config['vocab_size'], bias=False)

    def forward(self, idx):
        x = self.embed(idx)
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        return self.head(x)
    @torch.no_grad()  # inference only: avoid building autograd graphs during generation
    def generate_stream(self, idx, max_new_tokens, temperature=1.0, top_p=0.9, eos_id=0):
        """Generator yielding token IDs one by one (assumes a single sequence, batch size 1)."""
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config['ctx_len'] else idx[:, -self.config['ctx_len']:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            # Nucleus (top-p) sampling: mask out the low-probability tail before sampling
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[:, indices_to_remove] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            yield idx_next.item()
            if idx_next.item() == eos_id:
                break
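
# Illustrative sketch only (never called): the same top-p (nucleus) filtering used in
# generate_stream, applied to a toy batch-of-one logits row. Tokens are sorted by
# probability, the smallest prefix whose cumulative probability exceeds top_p is kept,
# and the remaining tail is masked to -inf before sampling.
def _top_p_filter_demo(top_p: float = 0.9) -> torch.Tensor:
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])   # probs ~ [0.61, 0.22, 0.14, 0.03, 0.004]
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token that crosses the threshold is still kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices[sorted_indices_to_remove]
    logits[:, indices_to_remove] = -float('Inf')
    return F.softmax(logits, dim=-1)   # only the top-3 tokens keep non-zero probability
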
# ============================================================================
# 2. LOADING & INFERENCE SETUP
# ============================================================================
REPO_ID = "i3-lab/i3-4096ctx-chat"
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model():
    print(f"Loading i3-Hybrid model from {REPO_ID}...")
    # Download essentials
    model_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
    config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.json")
    with open(config_path, 'r') as f:
        raw_config = json.load(f)
    # Map configuration from fine-tune.py structure
    config = {
        'n_embd': raw_config['d_model'],
        'n_layer': raw_config['n_layers'],
        'rwkv_layers': raw_config.get('rwkv_layers', 12),
        'vocab_size': raw_config['vocab_size'],
        'ctx_len': raw_config.get('inference_context_window', 4096),
        'special_tokens': raw_config.get('special_tokens', {})
    }
    tokenizer = Tokenizer.from_file(tokenizer_path)
    model = I3Model(config)
    # Load weights
    state_dict = torch.load(model_path, map_location=device)
    if 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']
    # strict=False allows ignoring auxiliary 'compressor' weights in the state_dict
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    return model, tokenizer, config
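
# For reference, the config.json keys that load_model() reads. The values below are
# purely illustrative placeholders (the real ones live in the i3-lab/i3-4096ctx-chat
# repo); with these numbers, layers 0-11 would be RWKV blocks and layers 12-15
# standard-attention blocks.
_EXAMPLE_RAW_CONFIG = {
    "d_model": 768,                      # hidden size  -> config['n_embd']
    "n_layers": 16,                      # total blocks -> config['n_layer']
    "rwkv_layers": 12,                   # first N blocks use TimeMixing/ChannelMixing
    "vocab_size": 32000,
    "inference_context_window": 4096,    # -> config['ctx_len']
    "special_tokens": {
        "user_token": "<|user|>",
        "assistant_token": "<|assistant|>",
        "eos_token": "<EOS>",
    },
}
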
try:
    model, tokenizer, config = load_model()
    # Special Tokens
    U_TOKEN = config['special_tokens'].get('user_token', '<|user|>')
    A_TOKEN = config['special_tokens'].get('assistant_token', '<|assistant|>')
    E_TOKEN = config['special_tokens'].get('eos_token', '<EOS>')
    EOS_ID = tokenizer.token_to_id(E_TOKEN) or 0
    print("✅ i3-Hybrid System Ready!")
except Exception as e:
    print(f"❌ Initialization failed: {e}")
    model, tokenizer = None, None

def chat_response(message, history, temperature, max_tokens, top_p):
    if model is None:
        yield "Error: Model failed to load. Check server logs."
        return
    # Build the prompt with the chat-template tokens (history arrives as (user, assistant) pairs)
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg and assistant_msg:
            prompt += f"{U_TOKEN}{user_msg}{E_TOKEN}{A_TOKEN}{assistant_msg}{E_TOKEN}"
    prompt += f"{U_TOKEN}{message}{E_TOKEN}{A_TOKEN}"
    input_ids = torch.tensor([tokenizer.encode(prompt).ids], device=device)
    generated_text = ""
    # Use the generator for streaming
    for token_id in model.generate_stream(
        input_ids,
        max_new_tokens=int(max_tokens),  # sliders may return floats
        temperature=temperature,
        top_p=top_p,
        eos_id=EOS_ID
    ):
        word = tokenizer.decode([token_id])
        generated_text += word
        yield generated_text.strip()
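
# Illustrative sketch only (never called): how chat_response flattens a tuple-style
# Gradio history plus the new message into a single prompt string, assuming the
# default special tokens above.
def _prompt_format_demo() -> str:
    history = [("Hi!", "Hello, how can I help?")]
    message = "Tell me a joke."
    prompt = ""
    for user_msg, assistant_msg in history:
        prompt += f"{U_TOKEN}{user_msg}{E_TOKEN}{A_TOKEN}{assistant_msg}{E_TOKEN}"
    prompt += f"{U_TOKEN}{message}{E_TOKEN}{A_TOKEN}"
    # With the defaults this returns:
    # "<|user|>Hi!<EOS><|assistant|>Hello, how can I help?<EOS><|user|>Tell me a joke.<EOS><|assistant|>"
    return prompt
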
# ============================================================================
# 3. GRADIO UI
# ============================================================================
demo = gr.ChatInterface(
    fn=chat_response,
    title="i3-Hybrid Chat",
    description="Inference for the i3-4096ctx RWKV/Attention Hybrid model.",
    examples=[
        ["Hello! Who are you?", 0.7, 512, 0.9],
        ["Write a Python script to sort a list.", 0.5, 1024, 0.9],
        ["Tell me a joke.", 1.0, 128, 0.8]
    ],
    additional_inputs=[
        gr.Slider(0.1, 1.5, value=0.7, label="Temperature"),
        gr.Slider(64, 2048, value=512, step=64, label="Max New Tokens"),
        gr.Slider(0.5, 1.0, value=0.9, label="Top-P"),
    ],
)

if __name__ == "__main__":
    demo.launch()