import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
import json
import os
from tokenizers import Tokenizer
from typing import Optional, Tuple, List, Iterator

# ============================================================================
# 1. MODEL ARCHITECTURE (Exact match for i3HybridChatModel from fine-tune.py)
# ============================================================================

@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor):
    # RWKV WKV linear-attention recurrence in its numerically stable log-space
    # form. The receptance gate r is applied by the caller (TimeMixing.forward).
    y = torch.zeros_like(v)
    state_aa = torch.zeros(B, C, device=v.device)
    state_bb = torch.zeros(B, C, device=v.device)
    state_pp = torch.ones(B, C, device=v.device) * -1e30
    for t in range(T):
        kt = k[:, t, :]
        vt = v[:, t, :]
        # Output for the current token (bonus u applied to the current key)
        ww = kt + u
        p = torch.maximum(state_pp, ww)
        e1 = torch.exp(state_pp - p)
        e2 = torch.exp(ww - p)
        y[:, t, :] = (e1 * state_aa + e2 * vt) / (e1 * state_bb + e2)
        # Decay the running state by w and fold in the current token
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = e1 * state_aa + e2 * vt
        state_bb = e1 * state_bb + e2
        state_pp = p
    return y


class TimeMixing(nn.Module):
    def __init__(self, n_embd, layer_id):
        super().__init__()
        self.n_embd = n_embd
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(n_embd))
        self.time_first = nn.Parameter(torch.ones(n_embd) * 0.5)
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.output = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        # Token shift: mix each position with the previous timestep
        # (zero-padded at t = 0). Simplified full-sequence form for inference.
        xx = F.pad(x, (0, 0, 1, -1))
        k = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        v = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        r = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(k)
        v = self.value(v)
        r = torch.sigmoid(self.receptance(r))
        w = -torch.exp(self.time_decay)
        u = self.time_first
        return self.output(r * rwkv_linear_attention(B, T, C, r, k, v, w, u))


class ChannelMixing(nn.Module):
    def __init__(self, n_embd, layer_id):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
        self.key = nn.Linear(n_embd, n_embd * 4, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd * 4, n_embd, bias=False)

    def forward(self, x):
        # Token shift followed by squared-ReLU channel mixing
        xx = F.pad(x, (0, 0, 1, -1))
        k = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        r = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = torch.square(torch.relu(self.key(k)))
        r = torch.sigmoid(self.receptance(r))
        return r * self.value(k)


class StandardAttention(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.qkv = nn.Linear(n_embd, n_embd * 3)
        self.out_proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Single-head attention: insert a head dimension of 1 for SDPA
        q = q.view(B, T, 1, C).transpose(1, 2)
        k = k.view(B, T, 1, C).transpose(1, 2)
        v = v.view(B, T, 1, C).transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(y)
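# Quick shape check for the mixing blocks above (illustrative only; the width
# 128 is a made-up value, not the released model's embedding size):
#   >>> TimeMixing(n_embd=128, layer_id=0)(torch.randn(2, 8, 128)).shape
#   torch.Size([2, 8, 128])
#   >>> StandardAttention(128)(torch.randn(2, 8, 128)).shape
#   torch.Size([2, 8, 128])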
class Block(nn.Module):
    def __init__(self, n_embd, layer_id, is_attn=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        if is_attn:
            self.attn = StandardAttention(n_embd)
            self.ffn = nn.Sequential(
                nn.Linear(n_embd, n_embd * 4),
                nn.GELU(),
                nn.Linear(n_embd * 4, n_embd)
            )
        else:
            self.att = TimeMixing(n_embd, layer_id)
            self.ffn = ChannelMixing(n_embd, layer_id)

    def forward(self, x):
        # RWKV blocks register 'att'; attention blocks register 'attn'
        if hasattr(self, 'att'):
            x = x + self.att(self.ln1(x))
        else:
            x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class I3Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config['vocab_size'], config['n_embd'])
        # Build hybrid layers: RWKV first, then attention
        self.layers = nn.ModuleList()
        for i in range(config['n_layer']):
            is_attn = i >= config['rwkv_layers']
            self.layers.append(Block(config['n_embd'], i, is_attn=is_attn))
        self.ln_f = nn.LayerNorm(config['n_embd'])
        self.head = nn.Linear(config['n_embd'], config['vocab_size'], bias=False)

    def forward(self, idx):
        x = self.embed(idx)
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate_stream(self, idx, max_new_tokens, temperature=1.0, top_p=0.9, eos_id=0):
        """Generator yielding token IDs one by one."""
        for _ in range(max_new_tokens):
            # Crop the running sequence to the model's context window
            idx_cond = idx if idx.size(1) <= self.config['ctx_len'] else idx[:, -self.config['ctx_len']:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)

            # Nucleus (top-p) sampling; assumes batch size 1
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[:, indices_to_remove] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

            yield idx_next.item()
            if idx_next.item() == eos_id:
                break
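# Tiny architecture smoke test (illustrative values only; the real config is
# read from the Hub in load_model() below):
#   cfg = {'vocab_size': 1000, 'n_embd': 128, 'n_layer': 4,
#          'rwkv_layers': 3, 'ctx_len': 64, 'special_tokens': {}}
#   m = I3Model(cfg)
#   m(torch.randint(0, 1000, (1, 16))).shape   # -> torch.Size([1, 16, 1000])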
# ============================================================================
# 2. LOADING & INFERENCE SETUP
# ============================================================================

REPO_ID = "i3-lab/i3-4096ctx-chat"
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model():
    print(f"Loading i3-Hybrid model from {REPO_ID}...")
    # Download essentials
    model_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
    config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.json")

    with open(config_path, 'r') as f:
        raw_config = json.load(f)

    # Map configuration from the fine-tune.py structure
    config = {
        'n_embd': raw_config['d_model'],
        'n_layer': raw_config['n_layers'],
        'rwkv_layers': raw_config.get('rwkv_layers', 12),
        'vocab_size': raw_config['vocab_size'],
        'ctx_len': raw_config.get('inference_context_window', 4096),
        'special_tokens': raw_config.get('special_tokens', {})
    }

    tokenizer = Tokenizer.from_file(tokenizer_path)
    model = I3Model(config)

    # Load weights; strict=False ignores auxiliary 'compressor' weights in the checkpoint
    state_dict = torch.load(model_path, map_location=device)
    if 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    return model, tokenizer, config


try:
    model, tokenizer, config = load_model()
    # Special tokens (fall back to defaults if absent from the config)
    U_TOKEN = config['special_tokens'].get('user_token', '<|user|>')
    A_TOKEN = config['special_tokens'].get('assistant_token', '<|assistant|>')
    E_TOKEN = config['special_tokens'].get('eos_token', '')
    EOS_ID = tokenizer.token_to_id(E_TOKEN) or 0
    print("✅ i3-Hybrid System Ready!")
except Exception as e:
    print(f"❌ Initialization failed: {e}")
    model, tokenizer = None, None


def chat_response(message, history, temperature, max_tokens, top_p):
    if model is None:
        yield "Error: Model failed to load. Check server logs."
        return

    # Build a ChatML-style prompt from the conversation history
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg and assistant_msg:
            prompt += f"{U_TOKEN}{user_msg}{E_TOKEN}{A_TOKEN}{assistant_msg}{E_TOKEN}"
    prompt += f"{U_TOKEN}{message}{E_TOKEN}{A_TOKEN}"

    input_ids = torch.tensor([tokenizer.encode(prompt).ids], device=device)
    generated_text = ""

    # Stream partial responses as tokens are sampled
    for token_id in model.generate_stream(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        eos_id=EOS_ID
    ):
        word = tokenizer.decode([token_id])
        generated_text += word
        yield generated_text.strip()


# ============================================================================
# 3. GRADIO UI
# ============================================================================

demo = gr.ChatInterface(
    fn=chat_response,
    title="i3-Hybrid Chat",
    description="Inference for the i3-4096ctx RWKV/Attention Hybrid model.",
    examples=[
        ["Hello! Who are you?", 0.7, 512, 0.9],
        ["Write a Python script to sort a list.", 0.5, 1024, 0.9],
        ["Tell me a joke.", 1.0, 128, 0.8]
    ],
    additional_inputs=[
        gr.Slider(0.1, 1.5, value=0.7, label="Temperature"),
        gr.Slider(64, 2048, value=512, step=64, label="Max New Tokens"),
        gr.Slider(0.5, 1.0, value=0.9, label="Top-P"),
    ],
)

if __name__ == "__main__":
    demo.launch()
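# Manual check without launching the UI (illustrative; assumes load_model()
# succeeded and uses the tuple-style history format expected by chat_response):
#   last = ""
#   for last in chat_response("Hello! Who are you?", [], 0.7, 64, 0.9):
#       pass
#   print(last)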