🧠
TransformerIntermediate

Mini-GPT (Shakespeare)

This project turns transformer theory into a working language model. You will build a small GPT from scratch, train it on Shakespeare, and learn how architecture, optimization, and sampling decisions change generation quality.这个项目会把 transformer 理论真正落成一个可运行语言模型。你会从零构建一个小型 GPT,在 Shakespeare 语料上训练,并观察架构、优化和采样策略如何影响生成质量。

Dataset / 数据集
Tiny Shakespeare / Tiny Shakespeare
Model / 模型
Decoder-only GPT / Decoder-only GPT
Target / 目标
Coherent samples / 生成连贯文本
Train time / 训练时间
~10–90 min / 约 10–90 分钟

Project Background / 项目背景

Mini-GPT is the bridge between reading transformer theory and actually owning the mechanics of autoregressive language modeling. Instead of treating GPT as a black box API, this project makes you build the moving parts yourself and see how modern LLM behavior emerges from simple components.
Mini-GPT 是连接“读懂 transformer 理论”和“真正掌握自回归语言模型机制”的桥梁。它不是把 GPT 当成黑盒 API 来用,而是让你亲手搭出核心部件,看到现代 LLM 的行为是如何从一组并不神秘的模块里涌现出来的。

Problem it solves / 它要解决什么问题

The real problem is not “generate Shakespeare text.” The real problem is understanding how a model turns discrete tokens into contextual representations, prevents future leakage, and learns next-token prediction stably enough to produce coherent samples. This page turns that abstract systems question into something you can inspect line by line.
这个项目真正要解决的并不是“生成莎士比亚文本”,而是理解模型如何把离散 token 变成上下文化表示,如何阻止未来信息泄漏,以及如何稳定地学会 next-token prediction,最终生成连贯文本。它把一个抽象的系统问题,变成你可以逐行检查的实现。

What you learn / 你会学到什么

  • How token embeddings and positional embeddings work together
    token embedding 和 positional embedding 如何配合
  • How causal masking prevents information leakage from future tokens
    causal masking 如何阻止未来信息泄漏
  • How transformer blocks stack attention and feed-forward computation
    transformer block 如何堆叠 attention 与前馈计算
  • How sampling strategy changes the quality and diversity of generated text
    采样策略如何改变文本质量与多样性
  • How to reason about context length, perplexity, and training stability
    如何理解 context length、perplexity 和训练稳定性

Starter Architecture Sketch / Starter Architecture Sketch

This sketch shows the minimum architectural skeleton. The goal is not to copy blindly. The goal is to understand why embeddings, masking, residual paths, and logits all connect the way they do.
这段代码给出最小架构骨架。目标不是盲抄,而是真正理解 embedding、mask、residual path 和 logits 为什么要这样连接。

import torch
import torch.nn as nn

class MiniGPT(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(...)
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        tok = self.token_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok + pos
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.head(x)

Code walkthrough / 代码要点解释

Embedding tables are learned lookup systems

The token embedding turns discrete ids into dense vectors. Positional embeddings add order information so the model can distinguish “A then B” from “B then A”.

Embedding 表本质上是可学习查表系统

token embedding 会把离散 id 映射成稠密向量。位置 embedding 再补上顺序信息,让模型知道“A 在 B 前面”和“B 在 A 前面”不是一回事。

Masked attention is the heart of autoregression

The lower-triangular mask forces each token to see only the past. If this mask is wrong, the model cheats during training and collapses at generation time.

带 mask 的 attention 是自回归建模核心

下三角 mask 强制每个 token 只能看见过去。如果 mask 写错,模型训练时就会作弊,到了生成时会明显崩掉。

Residual structure makes depth trainable

Each transformer block adds attention and feed-forward updates on top of the previous representation. Residual paths and normalization are what make deeper models train stably.

Residual 结构让深层模型可训练

每个 transformer block 都是在上一层表示的基础上再加 attention 和前馈更新。真正让深层训练稳定的关键是 residual path 和 normalization。

The final linear head projects hidden states back to vocabulary logits

This is where the model turns representation space into next-token probabilities. Training with cross-entropy teaches the model to push the right token above the rest.

最终线性头把隐藏状态投回词表 logits

这一步把表示空间重新映射成下一个 token 的概率分布。交叉熵训练的目标,就是让正确 token 的 logit 被持续推高。

The forward pass defines the whole learning contract / forward 过程定义了整个学习契约

The sequence idx enters token embeddings, receives positional information, flows through masked transformer blocks, and is finally projected into logits over the vocabulary. Every tensor shape in this path matters, because one silent mismatch is enough to make training unstable or generation nonsensical.
序列 idx 先进入 token embedding,再加入位置信息,流过带 mask 的 transformer block,最后投影回词表 logits。这个路径里每一个 tensor shape 都非常关键,因为任何一次静默 mismatch 都足以让训练变得不稳定,或者让生成结果失去意义。

Training and sampling are two different regimes / 训练和采样是两种不同制度

During training, the model sees many teacher-forced prefixes and learns a probability distribution. During sampling, it must live with its own previous outputs. Understanding this train-infer gap is essential if you want to reason about why samples degrade, repeat, or drift off-topic.
训练时,模型会在 teacher forcing 条件下看到大量前缀并学习概率分布;采样时,它必须和自己的历史输出共处。理解这种 train-infer gap,是你分析生成退化、重复或跑题的关键。

Full runnable code / 完整可运行代码

A compact PyTorch script that trains a tiny character-level GPT on a local text file. Save this as mini_gpt_train.py and install the listed dependencies for the project stack.
A compact PyTorch script that trains a tiny character-level GPT on a local text file. 可将下面代码保存为 mini_gpt_train.py,并安装对应项目依赖后直接运行。

Dependencies / 依赖

  • python>=3.10
  • torch
  • input.txt (Tiny Shakespeare or another text corpus)

Run commands / 运行命令

pip install torch
python mini_gpt_train.py

File tree / 目录结构

mini-gpt/
├── input.txt
├── mini_gpt_train.py
└── outputs/
    └── sample.txt
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

text = open("input.txt", "r", encoding="utf-8").read()
chars = sorted(list(set(text)))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
encode = lambda s: [stoi[c] for c in s]
decode = lambda xs: ''.join(itos[i] for i in xs)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size = 128
batch_size = 32
n_embd = 192
n_head = 6
n_layer = 6
dropout = 0.1
max_iters = 2000
eval_interval = 200
lr = 3e-4

ids = torch.tensor(encode(text), dtype=torch.long)
split = int(0.9 * len(ids))
train_ids, val_ids = ids[:split], ids[split:]


def get_batch(split_name: str):
    data = train_ids if split_name == 'train' else val_ids
    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)


class Head(nn.Module):
    def __init__(self, head_size: int):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = self.dropout(F.softmax(wei, dim=-1))
        v = self.value(x)
        return wei @ v


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, head_size: int):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.proj(torch.cat([h(x) for h in self.heads], dim=-1)))


class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class MiniGPT(nn.Module):
    def __init__(self):
        super().__init__()
        vocab_size = len(chars)
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens: int):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


model = MiniGPT().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

for step in range(max_iters):
    xb, yb = get_batch('train')
    _, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % eval_interval == 0:
        model.eval()
        vx, vy = get_batch('val')
        _, vloss = model(vx, vy)
        print(f"step={step} train_loss={loss.item():.4f} val_loss={vloss.item():.4f}")
        model.train()

context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, 300)[0].tolist()))

Build Steps / 构建步骤

1

Tokenizer / Tokenizer

Character-level tokenizer: build vocab, encode/decode strings, handle special tokens.
字符级 tokenizer,构建词表、实现 encode/decode,并处理特殊 token。

2

Embedding Layer / Embedding 层

Token + positional embeddings. Understand why we need both.
实现 token embedding 和 positional embedding,并理解两者都不可缺。

3

Attention Block / Attention Block

Causal multi-head attention with explicit masking and unit tests.
实现带显式 mask 和单元测试的 causal 多头注意力。

4

Transformer Block / Transformer Block

LayerNorm → Attention → residual → LayerNorm → FFN → residual.
LayerNorm → Attention → residual → LayerNorm → FFN → residual。

5

GPT Model / GPT 模型

Stack N blocks, add final norm and linear head, then count parameters.
堆叠 N 个 block,加入最终 norm 和线性头,并统计参数量。

6

Training / 训练

Use cross-entropy, AdamW, and cosine decay to train next-token prediction.
用交叉熵、AdamW 和 cosine decay 训练 next-token prediction。

7

Sampling / 采样

Compare greedy, top-k, top-p, and temperature-based decoding.
对比 greedy、top-k、top-p 和 temperature 解码。

8

Stretch Goals / 扩展目标

Try Flash Attention, longer context, DDP, or a better tokenizer.
尝试 Flash Attention、更长上下文、DDP 或更好的 tokenizer。

Common Pitfalls / 常见坑

⚠️ Wrong mask convention / mask 方向写错

One sign error in the causal mask lets tokens see the future. Unit-test on a 4-token example before trusting the loss.
causal mask 只要符号方向错一次,token 就会偷看未来。先在 4-token 的极小例子上做单元测试,再相信 loss。

⚠️ Forgetting to shift targets / 忘记错位标签

Language modeling predicts the next token. Inputs and labels must be offset by one position.
语言模型预测的是“下一个 token”,所以输入和标签必须错开一位。

⚠️ Sampling with dropout on / 采样时忘记关闭 dropout

If model.eval() is missing, output quality becomes noisy and unstable during generation.
如果忘记 model.eval(),生成结果会因为 dropout 仍然开启而变得噪声很大。

⚠️ Ignoring gradient clipping / 忽略梯度裁剪

Even small GPTs can spike gradients early in training. Clip and monitor loss explosions.
即使是小 GPT,在训练早期也会出现梯度尖峰。要做裁剪,并监控 loss 爆炸。

Hardware Comparison / 硬件对比

HardwareBase trainLong contextSweep
MacBook M4 Pro✅ Small run⚠️ Limited❌ Painful
RTX 4090✅ Comfortable✅ Good✅ Strong
A100 80GB✅ Easy✅ Excellent✅ Excellent

Grading Rubric / 完成标准

  • ✅ Model trains without mask bugs or divergence / 模型训练无 mask bug、无明显发散
  • ✅ Generation quality improves across checkpoints / 生成质量会随着 checkpoint 提升
  • ✅ Sampling script supports temperature and top-k / 采样脚本支持 temperature 和 top-k
  • ✅ You can explain why each architecture piece exists / 你能解释每个架构部件存在的原因

References / 参考资料