Porting nanochat to Transformers: an AI modeling history lesson
tldr: There is a lot to learn about ML from nanochat, and even more to learn about the history of the transformer architecture.
Recently I was helping students of the nanochat project share their models and discuss their learning on Hugging Face. In the process, I thought it would be useful if the model were integrated into the transformers library. This would allow others to use their nanochat models in downstream libraries like vLLM for inference or TRL for post-training.
You can now use nanochat models in transformers and tap into all those educational gains across the ecosystem. But along the way, I uncovered a further treasure trove of education about how canonical models relate to each other, and the components they share.
I received the lesson from a simple teacher: class inheritance and transformers' modular philosophy. If you want to learn more about that, check out this guide here.
Here, let’s tuck into a deep dive on how nanochat relates to the lineage of transformer architectures.
What is nanochat?
On October 13th 2025, Andrej Karpathy unceremoniously dropped the nanochat repo into the unsuspecting AI world. To hype seekers, this was just a small and pretty average LLM. To ML devotees, this was nirvana. A raw unadulterated chance to tinker, fiddle, and play with a transformer model defined in pure pytorch. Nothing was hidden away in fancy torch methods or inherited from complex class structures. It was all there in a simple file.
Karpathy painstakingly implemented an end-to-end build of an LLM system without most of the major libraries, even though in real-world situations most of us rely on transformers, tokenizers, datasets, trl, etc. This back-to-basics approach gives us the chance to genuinely learn and understand something from the ground up.
Personally, I found the process to be one of the most educational experiences I can remember.
What is transformers and how is it educational?
Most of us know the transformers library as the backbone of modern machine learning, but if we dig a little deeper, it’s also a powerful piece of education.
If you don’t know… transformers is the de facto implementation of the modern AI models that bear the same name: ‘transformer’ models like the GPT, DeepSeek, and Claude series. transformers is a special project because it contains the implementations of all major open model architectures, and those architectures are modularized to reuse functionality from each other. If you want to explore the philosophy and lineage behind transformers’ modularity, check out this guide here.
In general, scientists at AI research labs design, implement, and train their models in their framework of choice, be that torch, JAX, etc. When they come to share their open model with the community, they will open a PR on transformers and refactor their code to use relevant modules.
Because transformers contains most major model implementations, researchers inherit model architecture components from other canonical models wherever possible. This makes it, in every sense, a ‘single source of truth’.
This practical feature of the library has an amazingly educational quality to it. We can read a model implementation as a series of references to other usages of those architectural features. When one model uses a certain type of RMSNorm, we can plainly see that it is the same implementation as another model's because it inherits that class entirely. For example, check out nanochat’s RMSNorm:
class NanoChatRMSNorm(Llama4TextL2Norm):
pass
The transformers library then converts the modular_* implementation into a modeling_* implementation, which contains the complete torch native implementation:
class NanoChatRMSNorm(torch.nn.Module):
def __init__(self, eps: float = 1e-6):
super().__init__()
self.eps = eps
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x)
def extra_repr(self):
return f"eps={self.eps}"
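As a quick sanity check (my own, not from the library), this parameter-free norm matches PyTorch's functional rms_norm when given the same eps (requires a recent torch that provides F.rms_norm):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 1280)
norm = NanoChatRMSNorm(eps=1e-6)
# x * rsqrt(mean(x^2) + eps) is mathematically the same as x / sqrt(mean(x^2) + eps)
print(torch.allclose(norm(x), F.rms_norm(x, (x.size(-1),), eps=1e-6), atol=1e-5))  # True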
When we review a model in transformers, we can read both sides, modular and modeling, and learn from the math and the literature behind the model’s implementation. Given the educational nature of nanochat, I thought it was a perfect opportunity to explore this aspect of transformers and share what I learnt with students.
Why do we need nanochat in transformers?
It might seem counterintuitive to support an educational model like nanochat in a production-grade library like transformers. After all, we can see from nanochat’s benchmark scores that it does not rival state-of-the-art models like Qwen3, SmolLM3, Gemma3, or Olmo3.
Nanochat was never really intended as a production-grade model. It was meant as an educational tool, and that’s exactly why we need it in transformers. There are four main reasons:
- transformers as a single source of truth teaches us about nanochat's lineage.
- Use the nanochat model in other libraries.
- Save money by reusing nanochat checkpoints for fine-tuning.
- Compare nanochat fine-tuning with other open model checkpoints.
Firstly, as mentioned above, transformers teaches us about the modeling conventions that Karpathy borrows from other canonical implementations.
Secondly, because transformers is a standard within the ecosystem, it unlocks more downstream learning in post-training libraries, quantisation tools, inference libraries, and device integrations. In practical terms, here are some examples of what nanochat students could do on top of transformers:
- Quantize models in llama.cpp ($0)
- Integrate models into the browser and WebGPU ($0)
- SFT training in TRL/torch on Google Colab ($0)
- RL training in TRL/torch on Google Colab ($0 - $9)
- Agentic RL in TRL on Google Colab ($0 - $9)
Finally, training AI models is expensive. Running the nanochat speedrun.sh costs between $200 and $2k depending on the model size we use, which is little compared to the millions of dollars invested by frontier labs, but still a significant sum for students, who learn best by taking a few chances to fail and build experience.
In short, let’s unlock more opportunities for education!
The nanochat architecture
As described by Karpathy, nanochat uses an archetypal architecture that is common across the field, which makes it an excellent choice for an educational resource: folks get to learn from what works.
The core model implementation (nanochat/gpt.py, 291 lines) demonstrates modern transformer architecture, with every design decision documented and justified.
The configuration uses a single complexity slider: depth. Set --depth=20 and everything else automatically adjusts. Model dimension equals depth × 64 (20 layers → 1,280 dimensions). Number of attention heads equals depth ÷ 2 (10 heads). Head dimension is fixed at 128. This "aspect ratio philosophy" simplifies scaling: if you want a more capable model or have a bigger budget, just increase depth to 26 ($300 budget) or 30 ($1,000 budget).
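A minimal sketch of that rule (my own paraphrase, not nanochat's actual config code):

def derived_config(depth: int, head_dim: int = 128):
    # "Aspect ratio" rule described above: everything hangs off depth
    n_embd = depth * 64      # model dimension, e.g. 20 -> 1280
    n_head = depth // 2      # attention heads, e.g. 20 -> 10
    # For the even depths used in the speedruns, head_dim works out to exactly 128
    assert n_embd % head_dim == 0 and n_embd // head_dim == n_head
    return {"n_layer": depth, "n_embd": n_embd, "n_head": n_head, "head_dim": head_dim}

print(derived_config(20))  # {'n_layer': 20, 'n_embd': 1280, 'n_head': 10, 'head_dim': 128}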
The architecture incorporates five key improvements over vanilla transformers. Let’s work through the components of this architecture and compare them across implementations:
Forward pass based on the Llama Architecture
The forward pass in nanochat handles both training and generation. We can simply read that the input x is embedded, updated by each layer, then normalized and projected by the head. During training, a loss is calculated and returned instead of the logits themselves.
def forward(self, x, targets=None, loss_reduction='mean'):
x = self.token_emb(x)
for layer in self.layers:
x = layer(x)
x = self.ln_f(x)
logits = self.lm_head(x)
if targets is not None:
loss = F.cross_entropy(
logits.view(-1, self.vocab_size),
targets.view(-1),
ignore_index=-1,
reduction=loss_reduction
)
return loss
return logits
By returning loss directly when targets are provided, the training loop becomes trivial. No separate loss computation, no manual masking logic—just loss = model(inputs, targets) followed by loss.backward().
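That pattern makes a minimal training step look like the following sketch (assuming model is an instance of the class above and batches yields (inputs, targets) pairs of token ids; this is not the actual nanochat training script):

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for inputs, targets in batches:      # targets are inputs shifted by one position
    loss = model(inputs, targets)    # forward returns the scalar loss directly
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()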
transformers has to make things a bit more complex to facilitate the downstream ecosystem that uses logits in a broad spectrum of ways. Therefore, loss calculation is dealt with in training-specific code, and the forward function returns BaseModelOutputWithPast.
class NanoChatModel(LlamaModel):
def __init__(self, config: NanoChatConfig):
super().__init__(config)
self.initial_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position: torch.Tensor = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
hidden_states = self.initial_norm(hidden_states) # Additional norm before the layers
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
Rotary Position Embeddings (RoPE)
Rotary Position Embeddings (RoPE) replace learned positional encodings by rotating query and key vectors using precomputed sin/cos frequencies:
def apply_rope(x, cos, sin):
x1, x2 = x[..., ::2], x[..., 1::2]
y1 = x1 * cos - x2 * sin
y2 = x1 * sin + x2 * cos
return torch.stack([y1, y2], dim=-1).flatten(-2)
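As a quick sanity check (my own, not from the repo), we can build cos/sin tables from the usual inverse-frequency schedule and confirm that apply_rope is a pure rotation, i.e. it preserves vector norms:

import torch

head_dim, seq_len = 128, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))   # (head_dim/2,)
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)                   # (seq_len, head_dim/2)
cos, sin = angles.cos()[None, :, None, :], angles.sin()[None, :, None, :]       # broadcast over batch/heads

q = torch.randn(1, seq_len, 4, head_dim)   # (batch, seq, heads, head_dim)
q_rot = apply_rope(q, cos, sin)
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True: rotation preserves norm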
In transformers, the rotary embeddings are implemented like so:
from ..llama.modeling_llama import (
LlamaDecoderLayer,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
pass
def rotate_half(x):
"""Rotates half the hidden dims of the input with flipped signs for NanoChat."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((x2, -x1), dim=-1)
NanoChatRotaryEmbedding inherits entirely from the original Llama implementation; the only NanoChat-specific change is a sign inversion in the rotate_half helper used when the embeddings are applied.
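To see the difference concretely, here is a small side-by-side (my own illustration) of the Llama convention against the NanoChat rotate_half defined above:

import torch

def llama_rotate_half(x):
    # Llama convention: (-x2, x1)
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

x = torch.arange(4.0)        # [0, 1, 2, 3]
print(llama_rotate_half(x))  # tensor([-2., -3.,  0.,  1.])
print(rotate_half(x))        # NanoChat variant: tensor([ 2.,  3., -0., -1.])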
QK Normalization
NanoChat applies RMSNorm to queries and keys before computing attention to stabilize training.
In the original gpt.py, this is achieved via a functional norm helper applied directly inside the attention forward pass:
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
class CausalSelfAttention(nn.Module):
...
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
...
In the modular transformers implementation, we see a fascinating mix of lineages. The NanoChatRMSNorm inherits directly from Llama4TextL2Norm, while the attention mechanism inherits from Qwen3Attention. We simply inject the QK normalization into the Qwen3 logic:
class NanoChatRMSNorm(Llama4TextL2Norm):
pass
class NanoChatAttention(Qwen3Attention):
def __init__(self, config: NanoChatConfig, layer_idx: int):
super().__init__(config, layer_idx)
del self.sliding_window
del self.layer_type
self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# RoPE -> Norm (instead of usual Norm -> RoPE)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
Untied Weights
Karpathy's implementation deliberately unties the weights between the token embedding and the language model head to provide the model with more flexibility. In gpt.py, these are initialized as two completely separate modules:
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# ... (rest of init)
In the modular implementation, we inherit from Gemma2ForCausalLM. This is a powerful simplification—Gemma 2 also supports untied weights and advanced output structures. By simply inheriting the class, we pull in all the necessary machinery for causal generation, while the configuration object (defined elsewhere) ensures the weights remain untied:
class NanoChatForCausalLM(Gemma2ForCausalLM):
def forward(self, **super_kwargs) -> CausalLMOutputWithPast:
super().forward(**super_kwargs)
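The untying itself lives in the configuration rather than in the class. A minimal sketch, assuming the config exposes the standard transformers tie_word_embeddings flag (field name not copied from the actual NanoChatConfig):

from transformers import NanoChatConfig, NanoChatForCausalLM

# tie_word_embeddings is the usual transformers switch for sharing (or not)
# the embedding matrix with the LM head; nanochat keeps them separate.
config = NanoChatConfig(tie_word_embeddings=False)
model = NanoChatForCausalLM(config)
emb, head = model.get_input_embeddings().weight, model.get_output_embeddings().weight
print(emb is head)              # False: embedding and LM head are separate tensors
print(emb.shape == head.shape)  # True: both are (vocab_size, hidden_size)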
ReLU² Activation
The original implementation replaces the standard GELU activation with ReLU², which is simply ReLU squared. This provides a faster alternative without performance loss. In gpt.py, this is hardcoded into the MLP block:
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
In the modular file, we see another surprising inheritance: CLIPMLP. The CLIP architecture uses an MLP structure that fits our needs, so we inherit the structural definition from CLIP and let the configuration drive the specific activation function (ReLU²):
class NanoChatMLP(CLIPMLP):
def __init__(self, config):
super().__init__(config)
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
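The activation itself is then selected via the configuration rather than hardcoded in the class. As a plain-torch illustration of what that configured ReLU² activation computes (my own sketch, not the library's code):

import torch
import torch.nn.functional as F

def relu_squared(x: torch.Tensor) -> torch.Tensor:
    # ReLU²: zero out negatives, then square, the same math as the gpt.py MLP above
    return F.relu(x).square()

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(relu_squared(x))  # tensor([0.0000, 0.0000, 0.0000, 0.2500, 4.0000])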
Multi-Query Attention (MQA)
NanoChat reduces the memory footprint of the KV cache by giving the model fewer key/value heads than query heads (10 query heads but only 4 key/value heads in the default config), the grouped-query generalization of Multi-Query Attention (MQA).
In gpt.py, this logic is handled by passing distinct head counts and relying on PyTorch's functional attention to handle the broadcasting (or explicitly handling it during inference):
class CausalSelfAttention(nn.Module):
# ...
def forward(self, x, cos_sin, kv_cache):
# ...
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
prefix_len = Tk - Tq
if prefix_len > 0: # can't be negative but could be zero
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# ...
In modular_nanochat.py, we don't need to write this logic at all. As seen in the QK Normalization section above, NanoChatAttention inherits from Qwen3Attention. The Qwen3 implementation is robust and fully supports GQA/MQA out of the box. By using this parent class, we get a production-grade attention implementation "for free", allowing us to focus solely on the unique normalizations required by NanoChat.
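To get a feel for why the key/value head count matters, here is a back-of-the-envelope cache-size calculation (my own arithmetic, using the depth-20 shape and the head counts quoted above):

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # Keys and values (factor of 2), stored in bf16 (2 bytes per element)
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

full = kv_cache_bytes(n_layers=20, n_kv_heads=10, head_dim=128, seq_len=4096)  # one KV head per query head
gqa = kv_cache_bytes(n_layers=20, n_kv_heads=4, head_dim=128, seq_len=4096)    # nanochat's reduced KV heads
print(f"{full / 2**20:.0f} MiB vs {gqa / 2**20:.0f} MiB")  # 400 MiB vs 160 MiB per 4k-token sequence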
Conclusion
It’s very clear that Andrej Karpathy’s implementation offers ten times more to learn from than the transformers version, which inherits almost entirely from existing models and features. That said, we can still take plenty away from the inherited modular modeling implementation: models like Llama, Llama 4, Gemma 2, Qwen3, and CLIP are all reused to create a genuinely canonical implementation of a transformer.
Use Nanochat in Transformers
If you’d like to try out your own nanochat models in transformers:
- Download the nanochat-d34 checkpoint
hf download karpathy/nanochat-d34 --local-dir nanochat-d34
- Convert the checkpoint to transformers format
uv run \
--with "transformers @ git+https://github.com/huggingface/transformers.git@nanochat-implementation" \
--with "tiktoken>=0.12.0" \
https://raw.githubusercontent.com/huggingface/transformers/nanochat-implementation/src/transformers/models/nanochat/convert_nanochat_checkpoints.py \
--input_dir ./nanochat-d34 \
--output_dir ./nanochat-d34-hf
- (optional) Upload the checkpoint to the Hugging Face Hub
hf upload <username>/nanochat-d34 nanochat-d34
- Test the model
import torch
from transformers import AutoTokenizer, NanoChatForCausalLM
tokenizer = AutoTokenizer.from_pretrained("./nanochat-d34-hf")
model = NanoChatForCausalLM.from_pretrained("./nanochat-d34-hf")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
prompt = "Hello, how are you?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
inputs.pop("token_type_ids", None)
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Notebooks
If you want to train with these models, you can use these colab notebooks: