DQN Atari Pong / DQN 训练 Atari Pong
This project teaches value-based RL the hard way: delayed rewards, unstable bootstrap targets, and why replay buffers and target networks are absolutely necessary.
这个项目会让你真正理解 value-based RL 的难点,包括延迟奖励、bootstrap target 不稳定,以及 replay buffer 和 target network 为什么必要。
Project Background / 项目背景
DQN is one of the first deep reinforcement learning systems that clearly showed neural networks could learn control policies directly from pixels. Pong matters because it is simple enough to study but difficult enough to expose the instability of value learning.
DQN 是最早清楚证明“神经网络可以直接从像素学习控制策略”的深度强化学习系统之一。Pong 之所以重要,是因为它足够简单,适合研究;但又足够难,能真实暴露 value learning 的不稳定性。
Problem it solves / 它要解决什么问题
The problem is not just “play Pong.” The deeper problem is how to learn stable action values from high-dimensional visual input when rewards are delayed and each update changes the future target. DQN solves this with replay buffers, target networks, and convolutional Q estimation.
这个项目要解决的并不只是“玩 Pong”,更深层的问题是:当输入是高维图像、奖励是延迟到来的,而且每次更新都会改变未来目标时,如何稳定学出动作价值。DQN 给出的解法是 replay buffer、target network 和卷积式 Q 值估计。
What you learn / 你会学到什么
- ▸ Why replay buffers reduce temporal correlation / 为什么 replay buffer 可以降低时间相关性
- ▸ Why target networks stabilize Q-learning / 为什么 target network 能稳定 Q-learning
- ▸ How Atari preprocessing changes learning dynamics / Atari 预处理如何改变学习动态
- ▸ How to evaluate RL properly instead of cherry-picking rollouts / 如何正确评估 RL,而不是挑漂亮 rollout
Starter Code / 起始代码
class QNet(nn.Module):
def __init__(self, n_actions: int):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, 512), nn.ReLU(),
nn.Linear(512, n_actions),
)
def forward(self, x):
return self.net(x / 255.0)Code walkthrough / 代码要点解释
The CNN is not the hard part / CNN 不是最难的部分: most of the difficulty in DQN comes from unstable targets and correlated data, not from visual feature extraction itself. / DQN 真正的难点大多不在视觉特征提取,而在不稳定目标和强相关数据。
Replay changes the data distribution / replay 改变数据分布: without replay, sequential observations are too correlated and SGD becomes extremely noisy. / 没有 replay,连续 observation 相关性太强,SGD 会非常不稳定。
Target networks slow down feedback loops / target network 减缓反馈回路: bootstrapping on a network that is changing every step creates self-chasing dynamics. The target copy reduces that feedback loop. / 如果每一步都用快速变化的网络做 bootstrap,就会出现自我追逐。target copy 的存在就是为了减缓这个反馈回路。
The training step is the real engine / 真正的核心在训练步: sample a batch from replay, compute online Q-values, build the delayed target with the target network, and optimize TD error. If this path is wrong, the whole agent appears to learn but actually drifts. / 真正的引擎是训练步:从 replay 采样、计算 online Q 值、用 target network 构造延迟目标,再最小化 TD error。这个链条一旦写错,agent 看似在学,实际上会漂移。
Evaluation must be separated from exploration / 评估必须与探索解耦: if epsilon stays high during evaluation, you are mostly measuring noise rather than learned control. / 如果评估时 epsilon 仍然很高,你测到的基本只是噪声,而不是学到的控制能力。
Full runnable code / 完整可运行代码
A compact Atari DQN training script with replay buffer, target network, and evaluation loop. Save this as dqn_pong_train.py and install the listed dependencies for the project stack.
A compact Atari DQN training script with replay buffer, target network, and evaluation loop. 可将下面代码保存为 dqn_pong_train.py,并安装对应项目依赖后直接运行。
Dependencies / 依赖
- ▸ python>=3.10
- ▸ torch
- ▸ gymnasium[atari,accept-rom-license]
- ▸ ale-py
- ▸ numpy
Run commands / 运行命令
pip install torch numpy ale-py "gymnasium[atari,accept-rom-license]"
python dqn_pong_train.py
File tree / 目录结构
dqn-pong/
├── dqn_pong_train.py
├── replay/
│ └── buffer.cache
└── videos/
└── eval_episode.mp4import random
from collections import deque
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_env():
env = gym.make('ALE/Pong-v5', frameskip=4)
env = gym.wrappers.AtariPreprocessing(env, grayscale_obs=True, scale_obs=False)
env = gym.wrappers.FrameStackObservation(env, 4)
return env
class ReplayBuffer:
def __init__(self, capacity=100000):
self.buffer = deque(maxlen=capacity)
def push(self, *transition):
self.buffer.append(transition)
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
s, a, r, ns, d = zip(*batch)
return np.array(s), np.array(a), np.array(r), np.array(ns), np.array(d)
def __len__(self):
return len(self.buffer)
class QNet(nn.Module):
def __init__(self, n_actions):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, 512), nn.ReLU(),
nn.Linear(512, n_actions),
)
def forward(self, x):
return self.net(x.float() / 255.0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
env = make_env()
n_actions = env.action_space.n
q_net = QNet(n_actions).to(device)
target_net = QNet(n_actions).to(device)
target_net.load_state_dict(q_net.state_dict())
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-4)
replay = ReplayBuffer()
gamma = 0.99
batch_size = 32
warmup = 5000
target_update = 1000
epsilon_start, epsilon_end = 1.0, 0.05
obs, _ = env.reset(seed=42)
for step in range(20000):
epsilon = max(epsilon_end, epsilon_start - step / 100000)
if random.random() < epsilon:
action = env.action_space.sample()
else:
with torch.no_grad():
x = torch.tensor(np.array(obs), device=device).unsqueeze(0)
action = q_net(x).argmax(dim=1).item()
next_obs, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
replay.push(np.array(obs), action, reward, np.array(next_obs), done)
obs = next_obs
if done:
obs, _ = env.reset()
if len(replay) < warmup:
continue
states, actions, rewards, next_states, dones = replay.sample(batch_size)
states = torch.tensor(states, device=device)
actions = torch.tensor(actions, device=device).long()
rewards = torch.tensor(rewards, device=device).float()
next_states = torch.tensor(next_states, device=device)
dones = torch.tensor(dones, device=device).float()
q_values = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
with torch.no_grad():
target_q = target_net(next_states).max(dim=1).values
targets = rewards + gamma * (1 - dones) * target_q
loss = F.smooth_l1_loss(q_values, targets)
optimizer.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
optimizer.step()
if step % target_update == 0:
target_net.load_state_dict(q_net.state_dict())
print(f'step={step} loss={loss.item():.4f} epsilon={epsilon:.3f}')
Build Steps / 构建步骤
Environment Setup / 环境准备
Install gymnasium[atari] and apply standard wrappers such as frame skip, grayscale warp, and frame stacking. / 安装 gymnasium[atari],并应用标准 Atari wrapper,比如 frame skip、灰度缩放和 frame stacking。
Replay Buffer / 回放缓存
Implement an efficient transition store and verify sample shapes before any long training run. / 实现高效 transition 存储,并在长时间训练前先验证采样 shape 是否正确。
Q-Network / Q 网络
Build the canonical convolutional Q-network on top of stacked frames. / 在堆叠帧输入上实现经典卷积 Q 网络。
Optimization / 优化
Use TD loss, target network syncing, epsilon-greedy exploration, and gradient clipping. / 使用 TD loss、target network 同步、epsilon-greedy 探索和梯度裁剪。
Evaluation / 评估
Track smoothed rewards and run real evaluation episodes with exploration disabled. / 跟踪平滑奖励,并在关闭探索后做真实评估回合。
Common Pitfalls / 常见坑
⚠️ Wrong preprocessing / 预处理错误
Atari preprocessing is part of the algorithm, not a cosmetic detail. If frame skip or resize differs, your results are not comparable. / Atari 预处理是算法的一部分,不是装饰细节。如果 frame skip 或 resize 不一致,结果就不可比。
⚠️ Target sync too aggressive / target 同步太激进
Updating the target network too often can destabilize learning and erase the point of a delayed bootstrap target. / target network 更新太频繁,会破坏延迟 bootstrap target 的意义,训练更不稳定。
⚠️ Overtrusting one episode / 过度相信单局表现
One good episode in Pong means almost nothing. You need repeated evaluation and smoothed curves. / Pong 里单局表现几乎没有意义,必须看重复评估和奖励曲线。
Success Criteria / 完成标准
- ✅ The agent consistently improves against random or weak policies / agent 对随机或弱策略能稳定提升
- ✅ Evaluation is done with exploration largely disabled / 评估时基本关闭探索
- ✅ You can explain why replay and target network are both needed / 你能解释 replay 和 target network 为什么缺一不可