# i3-4096ctx-chat / app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
import json
from tokenizers import Tokenizer
# ============================================================================
# 1. MODEL ARCHITECTURE (Exact match for i3HybridChatModel from fine-tune.py)
# ============================================================================
@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
w: torch.Tensor, u: torch.Tensor):
    # Numerically stable RWKV linear-attention (WKV) kernel, scanned step by step over time.
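    # Per channel, with a_0 = b_0 = 0, the loop below computes:
    #   y_t = (a_{t-1} + exp(u + k_t) * v_t) / (b_{t-1} + exp(u + k_t))
    #   a_t = exp(w) * a_{t-1} + exp(k_t) * v_t
    #   b_t = exp(w) * b_{t-1} + exp(k_t)
    # where w is negative, so exp(w) acts as a per-channel decay. state_aa/state_bb
    # hold a and b rescaled by exp(-pp), with pp the running max exponent, to avoid overflow.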
y = torch.zeros_like(v)
state_aa = torch.zeros(B, C, device=v.device)
state_bb = torch.zeros(B, C, device=v.device)
state_pp = torch.ones(B, C, device=v.device) * -1e30
for t in range(T):
rt = r[:, t, :]
kt = k[:, t, :]
vt = v[:, t, :]
ww = kt + u
p = torch.maximum(state_pp, ww)
e1 = torch.exp(state_pp - p)
e2 = torch.exp(ww - p)
y[:, t, :] = (e1 * state_aa + e2 * vt) / (e1 * state_bb + e2)
ww = w + state_pp
p = torch.maximum(ww, kt)
e1 = torch.exp(ww - p)
e2 = torch.exp(kt - p)
state_aa = e1 * state_aa + e2 * vt
state_bb = e1 * state_bb + e2
state_pp = p
return y
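# Expected shapes for the kernel above: r, k, v are (B, T, C); w and u are (C,) and
# broadcast over the batch dimension. TimeMixing below builds these inputs from
# token-shifted linear projections of the hidden states.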
class TimeMixing(nn.Module):
def __init__(self, n_embd, layer_id):
super().__init__()
self.n_embd = n_embd
self.layer_id = layer_id
self.time_decay = nn.Parameter(torch.ones(n_embd))
self.time_first = nn.Parameter(torch.ones(n_embd) * 0.5)
self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
self.time_mix_v = nn.Parameter(torch.ones(1, 1, n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
self.key = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(n_embd, n_embd, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.output = nn.Linear(n_embd, n_embd, bias=False)
def forward(self, x):
B, T, C = x.size()
        # Token shift: xx is x shifted right by one step along T, so each position
        # mixes its own features with the previous token's features.
        xx = F.pad(x, (0, 0, 1, -1))
k = x * self.time_mix_k + xx * (1 - self.time_mix_k)
v = x * self.time_mix_v + xx * (1 - self.time_mix_v)
r = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(k)
v = self.value(v)
r = torch.sigmoid(self.receptance(r))
w = -torch.exp(self.time_decay)
u = self.time_first
return self.output(r * rwkv_linear_attention(B, T, C, r, k, v, w, u))
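# RWKV channel-mixing (feed-forward) block: a token-shifted key projection with a
# squared-ReLU activation, gated elementwise by a sigmoid receptance.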
class ChannelMixing(nn.Module):
def __init__(self, n_embd, layer_id):
super().__init__()
self.time_mix_k = nn.Parameter(torch.ones(1, 1, n_embd))
self.time_mix_r = nn.Parameter(torch.ones(1, 1, n_embd))
self.key = nn.Linear(n_embd, n_embd * 4, bias=False)
self.receptance = nn.Linear(n_embd, n_embd, bias=False)
self.value = nn.Linear(n_embd * 4, n_embd, bias=False)
def forward(self, x):
xx = F.pad(x, (0, 0, 1, -1))
k = x * self.time_mix_k + xx * (1 - self.time_mix_k)
r = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = torch.square(torch.relu(self.key(k)))
r = torch.sigmoid(self.receptance(r))
return r * self.value(k)
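# Single-head causal self-attention, used for the upper (non-RWKV) layers of the
# hybrid stack.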
class StandardAttention(nn.Module):
def __init__(self, n_embd):
super().__init__()
self.qkv = nn.Linear(n_embd, n_embd * 3)
self.out_proj = nn.Linear(n_embd, n_embd)
def forward(self, x):
B, T, C = x.size()
q, k, v = self.qkv(x).chunk(3, dim=-1)
q = q.view(B, T, 1, C).transpose(1, 2)
k = k.view(B, T, 1, C).transpose(1, 2)
v = v.view(B, T, 1, C).transpose(1, 2)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
y = y.transpose(1, 2).reshape(B, T, C)
return self.out_proj(y)
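# Hybrid residual block. Note the attribute naming: RWKV blocks register the
# time-mixing module as 'att' (with ChannelMixing as the FFN), while attention
# blocks register 'attn' (with a GELU MLP); forward() uses this to tell the two
# variants apart.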
class Block(nn.Module):
def __init__(self, n_embd, layer_id, is_attn=False):
super().__init__()
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
if is_attn:
self.attn = StandardAttention(n_embd)
self.ffn = nn.Sequential(
nn.Linear(n_embd, n_embd * 4),
nn.GELU(),
nn.Linear(n_embd * 4, n_embd)
)
else:
self.att = TimeMixing(n_embd, layer_id)
self.ffn = ChannelMixing(n_embd, layer_id)
def forward(self, x):
if hasattr(self, 'att'):
x = x + self.att(self.ln1(x))
else:
x = x + self.attn(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
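# Full model: token embedding -> the first `rwkv_layers` blocks use RWKV
# time/channel mixing, the remaining blocks use standard attention -> final
# LayerNorm -> untied LM head.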
class I3Model(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embed = nn.Embedding(config['vocab_size'], config['n_embd'])
# Build hybrid layers: RWKV first, then Attention
self.layers = nn.ModuleList()
for i in range(config['n_layer']):
is_attn = i >= config['rwkv_layers']
self.layers.append(Block(config['n_embd'], i, is_attn=is_attn))
self.ln_f = nn.LayerNorm(config['n_embd'])
self.head = nn.Linear(config['n_embd'], config['vocab_size'], bias=False)
def forward(self, idx):
x = self.embed(idx)
for layer in self.layers:
x = layer(x)
x = self.ln_f(x)
return self.head(x)
@torch.no_grad()
def generate_stream(self, idx, max_new_tokens, temperature=1.0, top_p=0.9, eos_id=0):
"""Generator yielding token IDs one by one."""
for _ in range(max_new_tokens):
idx_cond = idx if idx.size(1) <= self.config['ctx_len'] else idx[:, -self.config['ctx_len']:]
logits = self(idx_cond)
logits = logits[:, -1, :] / max(temperature, 1e-5)
            # Nucleus (top-p) sampling; the flat index gather below assumes batch size 1.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[:, indices_to_remove] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
yield idx_next.item()
if idx_next.item() == eos_id:
break
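# Usage sketch for generate_stream outside the Gradio app (illustrative only; the
# prompt string is an arbitrary example, and it assumes model/tokenizer/EOS_ID
# have been set up as in section 2 below):
#   ids = torch.tensor([tokenizer.encode(prompt).ids], device=device)
#   for tok in model.generate_stream(ids, max_new_tokens=64, eos_id=EOS_ID):
#       print(tokenizer.decode([tok]), end="", flush=True)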
# ============================================================================
# 2. LOADING & INFERENCE SETUP
# ============================================================================
REPO_ID = "i3-lab/i3-4096ctx-chat"
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model():
print(f"Loading i3-Hybrid model from {REPO_ID}...")
# Download essentials
model_path = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
config_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.json")
with open(config_path, 'r') as f:
raw_config = json.load(f)
# Map configuration from fine-tune.py structure
config = {
'n_embd': raw_config['d_model'],
'n_layer': raw_config['n_layers'],
'rwkv_layers': raw_config.get('rwkv_layers', 12),
'vocab_size': raw_config['vocab_size'],
'ctx_len': raw_config.get('inference_context_window', 4096),
'special_tokens': raw_config.get('special_tokens', {})
}
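    # The repo's config.json must provide at least d_model, n_layers and vocab_size;
    # the remaining keys fall back to the defaults above. Illustrative shape
    # (values are examples only, not the actual repo config):
    #   {"d_model": 512, "n_layers": 24, "rwkv_layers": 12, "vocab_size": 32000,
    #    "inference_context_window": 4096, "special_tokens": {"user_token": "<|user|>"}}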
tokenizer = Tokenizer.from_file(tokenizer_path)
model = I3Model(config)
# Load weights
state_dict = torch.load(model_path, map_location=device)
if 'model_state_dict' in state_dict:
state_dict = state_dict['model_state_dict']
# strict=False allows ignoring auxiliary 'compressor' weights in the state_dict
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()
return model, tokenizer, config
try:
model, tokenizer, config = load_model()
# Special Tokens
U_TOKEN = config['special_tokens'].get('user_token', '<|user|>')
A_TOKEN = config['special_tokens'].get('assistant_token', '<|assistant|>')
E_TOKEN = config['special_tokens'].get('eos_token', '<EOS>')
    _eos = tokenizer.token_to_id(E_TOKEN)
    EOS_ID = _eos if _eos is not None else 0
print("✅ i3-Hybrid System Ready!")
except Exception as e:
print(f"❌ Initialization failed: {e}")
model, tokenizer = None, None
def chat_response(message, history, temperature, max_tokens, top_p):
if model is None:
yield "Error: Model failed to load. Check server logs."
return
    # Build the chat prompt from the model's special tokens (ChatML-like turn format)
prompt = ""
for user_msg, assistant_msg in history:
if user_msg and assistant_msg:
prompt += f"{U_TOKEN}{user_msg}{E_TOKEN}{A_TOKEN}{assistant_msg}{E_TOKEN}"
prompt += f"{U_TOKEN}{message}{E_TOKEN}{A_TOKEN}"
input_ids = torch.tensor([tokenizer.encode(prompt).ids], device=device)
generated_text = ""
# Use the generator for streaming
for token_id in model.generate_stream(
input_ids,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
eos_id=EOS_ID
):
word = tokenizer.decode([token_id])
generated_text += word
yield generated_text.strip()
# ============================================================================
# 3. GRADIO UI
# ============================================================================
demo = gr.ChatInterface(
fn=chat_response,
title="i3-Hybrid Chat",
description="Inference for the i3-4096ctx RWKV/Attention Hybrid model.",
examples=[
["Hello! Who are you?", 0.7, 512, 0.9],
["Write a Python script to sort a list.", 0.5, 1024, 0.9],
["Tell me a joke.", 1.0, 128, 0.8]
],
additional_inputs=[
gr.Slider(0.1, 1.5, value=0.7, label="Temperature"),
gr.Slider(64, 2048, value=512, step=64, label="Max New Tokens"),
gr.Slider(0.5, 1.0, value=0.9, label="Top-P"),
],
)
if __name__ == "__main__":
demo.launch()