This post skips the Bellman equation and the policy-gradient derivation. Instead, it walks you through roughly 120 lines of PyTorch that train a playable Atari Breakout agent. By the end you will be able to:
- get a DQN that reaches 200+ points up and running in about 30 minutes of hands-on work;
- reuse the script as a template for other Atari games;
- apply 3 pieces of engineering "black magic" that keep training stable.
0. Environment setup (3 minutes)
# one command installs every dependency (the accept-rom-license extra pulls in the Atari ROMs)
pip install "gym[atari,accept-rom-license]" "stable-baselines3[extra]" torch tensorboard
If you are on an M1/M2 Mac, swap torch for the torch==2.1.0 nightly build.
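Optionally, a quick smoke test along the lines below (my own check_env.py, not part of the original setup) confirms that the Breakout ROM loads and that PyTorch sees your device before you commit to a long run. It assumes the gym 0.26+ reset/step API used throughout this post.
# check_env.py (optional smoke test, assumes the gym 0.26+ API)
import gym, torch
env = gym.make('BreakoutNoFrameskip-v4')
obs, info = env.reset()
print('raw frame shape:', obs.shape)   # (210, 160, 3) uint8 before any preprocessing
print('n actions:', env.action_space.n)
print('device:', 'cuda' if torch.cuda.is_available() else 'cpu')
env.close()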
1. A bare-bones DQN in under 50 lines
Forget the tricks for now; the goal is just to get something running. The script below produces output within about 5 minutes, even on a laptop CPU.
# dqn_minimal.py
import gym, torch, random, numpy as np
from collections import deque

# AtariPreprocessing gives grayscale 84x84 frames with frame-skip 4,
# which is what the conv stack (and the 3136-unit Linear) below expects.
env = gym.wrappers.AtariPreprocessing(gym.make('BreakoutNoFrameskip-v4', render_mode=None))
ACTIONS = env.action_space.n
net = torch.nn.Sequential(
    torch.nn.Conv2d(4, 32, 8, 4), torch.nn.ReLU(),
    torch.nn.Conv2d(32, 64, 4, 2), torch.nn.ReLU(),
    torch.nn.Conv2d(64, 64, 3, 1), torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(3136, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, ACTIONS)
)
optimizer = torch.optim.Adam(net.parameters(), 3e-4)
replay = deque(maxlen=100_000)

def phi(x):  # normalize a 4-frame stack and add a batch dimension
    return torch.tensor(x, dtype=torch.float32).div_(255).unsqueeze(0)

def act(obs, eps=0.05):
    if random.random() < eps:
        return env.action_space.sample()
    with torch.no_grad():
        return net(phi(obs)).argmax().item()

obs, _ = env.reset()
obs = np.stack([obs] * 4, axis=0)  # start with the same frame repeated 4 times
for step in range(1_000_000):
    a = act(obs)
    obs2, r, done, trunc, _ = env.step(a)
    obs2 = np.append(obs[1:], [obs2], axis=0)  # slide the 4-frame window
    replay.append((obs, a, r, obs2, done))
    obs = obs2
    if done or trunc:
        obs, _ = env.reset()
        obs = np.stack([obs] * 4, axis=0)
    if len(replay) > 1000 and step % 4 == 0:
        batch = random.sample(replay, 32)
        s, a, r, s2, d = map(np.array, zip(*batch))
        q = net(torch.tensor(s, dtype=torch.float32).div_(255)).gather(1, torch.tensor(a, dtype=torch.int64).unsqueeze(1))
        with torch.no_grad():  # the TD target must not receive gradients
            q2 = net(torch.tensor(s2, dtype=torch.float32).div_(255)).max(1)[0]
        y = torch.tensor(r, dtype=torch.float32) + 0.99 * q2 * (1 - torch.tensor(d, dtype=torch.float32))
        loss = torch.nn.functional.mse_loss(q.squeeze(), y)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
    if step % 10000 == 0:
        print(step, len(replay))
Run it:
python dqn_minimal.py
You will see the loss go down, but the score will still be dismal. Don't worry: the three tricks below make it take off.
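The minimal script only prints the step count and replay size, so the score itself is invisible. If you want to watch it, a small helper along these lines (my own sketch, reusing the same preprocessing and 4-frame stacking as dqn_minimal.py) plays a few greedy episodes and reports the mean score:
# eval helper (a sketch, not part of dqn_minimal.py)
import gym, torch, numpy as np

def evaluate(net, episodes=3):
    env = gym.wrappers.AtariPreprocessing(gym.make('BreakoutNoFrameskip-v4'))
    scores = []
    for _ in range(episodes):
        obs, _ = env.reset()
        stack, total, finished = np.stack([obs] * 4, axis=0), 0.0, False
        while not finished:
            with torch.no_grad():
                x = torch.tensor(stack, dtype=torch.float32).div_(255).unsqueeze(0)
                a = net(x).argmax().item()
            obs, r, done, trunc, _ = env.step(a)
            stack = np.append(stack[1:], [obs], axis=0)
            total += r
            finished = done or trunc
        scores.append(total)
    env.close()
    return float(np.mean(scores))
Calling evaluate(net) every 50k steps or so is enough to confirm the dismal scores mentioned above.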
2. Three pieces of engineering black magic
| Problem | Symptom | Black magic | Code snippet |
|---|---|---|---|
| Poor sample efficiency | barely 20 points after 1M steps | experience replay + a periodically synced target network | target_net.load_state_dict(net.state_dict()) |
| Training diverges | the loss explodes | gradient clipping + reward clipping | torch.nn.utils.clip_grad_norm_(net.parameters(), 10) |
| Too little exploration | stuck in a local optimum | ε-greedy annealing | eps = max(0.1, 1 - step/1e6) |
Wire these three snippets into the training loop (a standalone sketch follows below) and the score climbs from roughly 20 to 200+.
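To make the mechanics concrete before the full script, here is a self-contained sketch of all three tricks running on dummy tensors; the tiny Linear layer stands in for the Q-network and the random batches stand in for replayed transitions. The real wiring appears in dqn_breakout.py in the next section.
# three_tricks.py: the three tricks in isolation, on dummy data (shapes only)
import copy, torch

net = torch.nn.Linear(4, 2)                      # stand-in for the Q-network
target_net = copy.deepcopy(net)                  # trick 1: a frozen target copy
opt = torch.optim.Adam(net.parameters(), 2.5e-4)

for step in range(1, 30_001):
    eps = max(0.1, 1 - step / 1e6)               # trick 3: anneal ε from 1.0 towards 0.1

    s, s2 = torch.randn(32, 4), torch.randn(32, 4)
    a = torch.randint(0, 2, (32,))               # random actions stand in for replayed ones
    r = torch.sign(torch.randn(32))              # trick 2a: rewards clipped to {-1, 0, +1}
    q = net(s).gather(1, a.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        y = r + 0.99 * target_net(s2).max(1)[0]  # bootstrap from the frozen copy, no gradients
    loss = torch.nn.functional.mse_loss(q, y)

    opt.zero_grad(); loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), 10)   # trick 2b: clip gradient norm
    opt.step()

    if step % 10_000 == 0:
        target_net.load_state_dict(net.state_dict())        # periodic target sync
        print(f'step={step}, eps={eps:.3f}, loss={loss.item():.4f}')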
3. The full reproducible script (120 lines)
# dqn_breakout.py
import gym, torch, random, numpy as np, time, os
from collections import deque
from torch.utils.tensorboard import SummaryWriter

ENV_ID = 'BreakoutNoFrameskip-v4'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class WrapEnv(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.lives = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.lives = info['lives']
        return obs, info

    def step(self, action):
        total_r = 0
        for _ in range(4):                    # repeat the action for 4 frames
            obs, r, done, trunc, info = self.env.step(action)
            total_r += r
            if done:
                break
        if info['lives'] < self.lives:        # treat losing a life as episode end
            done = True
        return obs, np.sign(total_r), done, trunc, info   # reward clipped to {-1, 0, +1}

class DQN(torch.nn.Module):
    def __init__(self, n_actions):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(4, 32, 8, 4), torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 4, 2), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, 1), torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(3136, 512), torch.nn.ReLU(),
            torch.nn.Linear(512, n_actions)
        )

    def forward(self, x):
        return self.net(x / 255.0)

def train():
    # grayscale + resize to 84x84; frame_skip=1 here because WrapEnv repeats the action itself
    env = WrapEnv(gym.wrappers.AtariPreprocessing(gym.make(ENV_ID, render_mode=None), frame_skip=1))
    n_actions = env.action_space.n
    net, tgt = DQN(n_actions).to(device), DQN(n_actions).to(device)
    tgt.load_state_dict(net.state_dict())
    opt = torch.optim.Adam(net.parameters(), 2.5e-4)
    replay = deque(maxlen=100_000)
    writer = SummaryWriter('runs/breakout')
    obs, _ = env.reset()
    obs = np.stack([obs] * 4, 0)
    eps, step, ep, ep_r = 1.0, 0, 0, 0
    while step < 2_000_000:
        if random.random() < eps:
            a = env.action_space.sample()
        else:
            with torch.no_grad():
                a = net(torch.tensor(obs, device=device).unsqueeze(0)).argmax().item()
        obs2, r, done, trunc, _ = env.step(a)
        obs2 = np.append(obs[1:], [obs2], 0)
        replay.append((obs, a, r, obs2, done))
        obs = obs2
        ep_r += r
        step += 1
        if done or trunc:
            obs, _ = env.reset()
            obs = np.stack([obs] * 4, 0)
            writer.add_scalar('reward', ep_r, ep)
            ep_r, ep = 0, ep + 1
        eps = max(0.1, 1 - step / 1e6)
        if len(replay) > 32_000 and step % 4 == 0:
            batch = random.sample(replay, 32)
            s, a, r, s2, d = map(np.array, zip(*batch))
            s, s2 = torch.tensor(s, device=device), torch.tensor(s2, device=device)
            a = torch.tensor(a, device=device, dtype=torch.int64)
            r = torch.tensor(r, device=device, dtype=torch.float32)
            d = torch.tensor(d, device=device, dtype=torch.float32)
            q = net(s).gather(1, a.unsqueeze(1)).squeeze()
            with torch.no_grad():
                q2 = tgt(s2).max(1)[0]
            y = r + 0.99 * q2 * (1 - d)
            loss = torch.nn.functional.mse_loss(q, y)
            opt.zero_grad(); loss.backward(); torch.nn.utils.clip_grad_norm_(net.parameters(), 10); opt.step()
            if step % 1000 == 0:
                writer.add_scalar('loss', loss.item(), step)
        if step % 10_000 == 0:
            tgt.load_state_dict(net.state_dict())
            torch.save(net.state_dict(), 'breakout.pt')
            print(f'step={step}, eps={eps:.2f}')

if __name__ == '__main__':
    train()
Train for 2M steps (about 2 hours on a single RTX 3060) and the score settles at 300+.
4. Visualization & tuning
tensorboard --logdir runs
| Hyperparameter | Working value | Tuning rule of thumb |
|---|---|---|
| lr | 2.5e-4 | start larger, then shrink; anything that keeps the loss from exploding is fine |
| replay size | 100k | go up to 1M if memory allows |
| target sync | every 10k steps | make it smaller if training diverges |
| frame skip | 4 | 8 for speed, 2 if accuracy suffers |
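These knobs are hard-coded in dqn_breakout.py; a small optional refactor along these lines (the flag names are my own, not part of the original script) makes them sweepable from the command line:
# optional sketch: expose the table's knobs via argparse (flag names are illustrative)
import argparse

def parse_args():
    p = argparse.ArgumentParser(description='DQN Breakout hyperparameters')
    p.add_argument('--lr', type=float, default=2.5e-4)
    p.add_argument('--replay-size', type=int, default=100_000)
    p.add_argument('--target-sync', type=int, default=10_000, help='steps between target syncs')
    p.add_argument('--frame-skip', type=int, default=4)
    return p.parse_args()

if __name__ == '__main__':
    print(parse_args())
Thread the parsed values into train() in place of the literals and each row of the table becomes a one-flag experiment.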
5. Run the trained model live
# play.py
import gym, torch, numpy as np
from dqn_breakout import DQN, WrapEnv

# same preprocessing as training; note that WrapEnv ends the episode on each lost life
env = WrapEnv(gym.wrappers.AtariPreprocessing(
    gym.make('BreakoutNoFrameskip-v4', render_mode='human'), frame_skip=1))
net = DQN(env.action_space.n)
net.load_state_dict(torch.load('breakout.pt', map_location='cpu'))
net.eval()

obs, _ = env.reset()
obs = np.stack([obs] * 4, 0)
while True:
    with torch.no_grad():
        a = net(torch.tensor(obs, dtype=torch.float32).unsqueeze(0)).argmax().item()
    frame, _, done, trunc, _ = env.step(a)
    obs = np.append(obs[1:], [frame], 0)   # slide the 4-frame window
    if done or trunc:
        obs, _ = env.reset()
        obs = np.stack([obs] * 4, 0)
6. Next step: point the template at another game
- Change ENV_ID to PongNoFrameskip-v4; about 1M steps is enough for 20+ points.
- For CartPole, set the output layer to n_actions=2 and swap the conv trunk for a small MLP (CartPole observations are 4 numbers, not image frames).
- Want continuous control? Swap DQN for SAC (a one-liner with stable_baselines3.SAC, sketched below).
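For reference, the stable-baselines3 "one-liner" looks like this; Pendulum-v1 is just my example of a continuous-control task and is not part of this post's Atari setup.
# sac_pendulum.py: a sketch of the stable-baselines3 one-liner for continuous control
import gym
from stable_baselines3 import SAC

env = gym.make('Pendulum-v1')
model = SAC('MlpPolicy', env, verbose=1)   # SAC needs a continuous action space
model.learn(total_timesteps=50_000)
model.save('sac_pendulum')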
7. Wrap-up
| Stage | Goal | Time |
|---|---|---|
| Get it running | the 120-line script | 30 minutes |
| Tune | the 3 tricks | 1 hour |
| Transfer | swap in another game | 10 minutes |
Deep reinforcement learning is not sorcery, just engineering: get it running first, then tune, then transfer. Have fun!