深度强化学习实战:从 0 到 1 训练一个「打砖块」智能体

afcppe 发布于 2025-08-13 6 次阅读


本文不讲 Bellman 方程,也不推导策略梯度,只带你用 120 行 PyTorch 代码跑通一个可玩的 Atari Breakout 智能体。读完你能:

  1. 在 30 分钟内复现一个能拿 200+ 分的 DQN 模型;
  2. 拿到一份可直接改做其他 Atari 游戏的模板;
  3. 学会 3 个让训练稳定的「工程黑魔法」。

0. 环境准备(3 分钟)

# 一行命令装好所有依赖
pip install "gym[atari]" "stable-baselines3[extra]" torch tensorboard

如果你用 M1/M2 Mac,把 torch 换成 torch==2.1.0 的 nightly 版即可。


1. 用 30 行代码跑通「裸」DQN

先别管技巧,跑通再说。下面代码在笔记本 CPU 也能 5 分钟出结果。

# dqn_minimal.py
import gym, torch, random, numpy as np
from collections import deque

env = gym.make('BreakoutNoFrameskip-v4', render_mode=None)
ACTIONS = env.action_space.n

net = torch.nn.Sequential(
    torch.nn.Conv2d(4, 32, 8, 4), torch.nn.ReLU(),
    torch.nn.Conv2d(32, 64, 4, 2), torch.nn.ReLU(),
    torch.nn.Conv2d(64, 64, 3, 1), torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(3136, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, ACTIONS)
)

optimizer = torch.optim.Adam(net.parameters(), 3e-4)
replay = deque(maxlen=100_000)

def phi(x):  # 把单帧变成 4 帧叠加
    return torch.tensor(x, dtype=torch.float32).div_(255).unsqueeze(0)

def act(obs, eps=0.05):
    if random.random() < eps:
        return env.action_space.sample()
    with torch.no_grad():
        return net(phi(obs)).argmax().item()

obs, _ = env.reset()
obs = np.stack([obs] * 4, axis=0)  # 初始 4 帧相同
for step in range(1_000_000):
    a = act(obs)
    obs2, r, done, trunc, _ = env.step(a)
    obs2 = np.append(obs[1:], [obs2], axis=0)
    replay.append((obs, a, r, obs2, done))
    obs = obs2
    if done:
        obs, _ = env.reset()
        obs = np.stack([obs] * 4, axis=0)

    if len(replay) > 1000 and step % 4 == 0:
        batch = random.sample(replay, 32)
        s, a, r, s2, d = map(np.array, zip(*batch))
        q = net(torch.tensor(s, dtype=torch.float32)).gather(1, torch.tensor(a).unsqueeze(1))
        q2 = net(torch.tensor(s2, dtype=torch.float32)).max(1)[0]
        y = torch.tensor(r, dtype=torch.float32) + 0.99 * q2 * (1 - torch.tensor(d, dtype=torch.float32))
        loss = torch.nn.functional.mse_loss(q.squeeze(), y)
        optimizer.zero_grad(); loss.backward(); optimizer.step()

    if step % 10000 == 0:
        print(step, len(replay))

运行:

python dqn_minimal.py

你会看到 loss 在降,但分数依旧惨不忍睹——别急,下面 3 个技巧让它起飞。


2. 3 个工程黑魔法

问题现象黑魔法代码片段
样本利用率低训练 1M 步才 20 分经验回放 + 目标网络target_net.load_state_dict(net.state_dict())
训练发散loss 爆炸梯度裁剪 + reward cliptorch.nn.utils.clip_grad_norm_(net.parameters(), 10)
探索不足卡在局部最优ε-greedy 退火eps = max(0.1, 1 - step/1e6)

把上面 3 行代码插进去,分数立刻从 20 涨到 200+。


3. 完整可复现脚本(120 行)

# dqn_breakout.py
import gym, torch, random, numpy as np, time, os
from collections import deque
from torch.utils.tensorboard import SummaryWriter

ENV_ID = 'BreakoutNoFrameskip-v4'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class WrapEnv(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.lives = 0
    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.lives = info['lives']
        return obs, info
    def step(self, action):
        total_r = 0
        for _ in range(4):  # 重复 4 帧
            obs, r, done, trunc, info = self.env.step(action)
            total_r += r
            if done: break
        if info['lives'] < self.lives:  # 掉命即结束
            done = True
        return obs, np.sign(total_r), done, trunc, info  # reward clip

class DQN(torch.nn.Module):
    def __init__(self, n_actions):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(4, 32, 8, 4), torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 4, 2), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, 1), torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(3136, 512), torch.nn.ReLU(),
            torch.nn.Linear(512, n_actions)
        )
    def forward(self, x):
        return self.net(x / 255.0)

def train():
    env = WrapEnv(gym.make(ENV_ID, render_mode=None))
    n_actions = env.action_space.n
    net, tgt = DQN(n_actions).to(device), DQN(n_actions).to(device)
    tgt.load_state_dict(net.state_dict())
    opt = torch.optim.Adam(net.parameters(), 2.5e-4)
    replay = deque(maxlen=100_000)
    writer = SummaryWriter('runs/breakout')
    obs, _ = env.reset()
    obs = np.stack([obs] * 4, 0)
    eps, step, ep, ep_r = 1.0, 0, 0, 0
    while step < 2_000_000:
        if random.random() < eps:
            a = env.action_space.sample()
        else:
            with torch.no_grad():
                a = net(torch.tensor(obs, device=device).unsqueeze(0)).argmax().item()
        obs2, r, done, trunc, _ = env.step(a)
        obs2 = np.append(obs[1:], [obs2], 0)
        replay.append((obs, a, r, obs2, done))
        obs = obs2
        ep_r += r
        step += 1
        if done:
            obs, _ = env.reset()
            obs = np.stack([obs] * 4, 0)
            writer.add_scalar('reward', ep_r, ep)
            ep_r, ep = 0, ep + 1
            eps = max(0.1, 1 - step/1e6)

        if len(replay) > 32_000 and step % 4 == 0:
            batch = random.sample(replay, 32)
            s, a, r, s2, d = map(np.array, zip(*batch))
            s, a, r, s2, d = map(lambda x: torch.tensor(x, device=device), (s, a, r, s2, d))
            q = net(s).gather(1, a.unsqueeze(1)).squeeze()
            with torch.no_grad():
                q2 = tgt(s2).max(1)[0]
            y = r + 0.99 * q2 * (1 - d)
            loss = torch.nn.functional.mse_loss(q, y)
            opt.zero_grad(); loss.backward(); torch.nn.utils.clip_grad_norm_(net.parameters(), 10); opt.step()
            if step % 1000 == 0:
                writer.add_scalar('loss', loss.item(), step)

        if step % 10_000 == 0:
            tgt.load_state_dict(net.state_dict())
            torch.save(net.state_dict(), 'breakout.pt')
            print(f'step={step}, eps={eps:.2f}')

if __name__ == '__main__':
    train()

训练 2M 步(单张 3060 约 2 小时)即可稳定 300+ 分。


4. 可视化 & 调参

tensorboard --logdir runs
超参经验值调参口诀
lr2.5e-4先大后小,loss 不爆炸即可
replay size100k显存够就 1M
target sync10k 步训练发散就调小
frame skip4想加速就 8,掉精度就 2

5. 把模型搬到真机上

# play.py
import gym, torch, numpy as np
from dqn_breakout import DQN, WrapEnv

env = WrapEnv(gym.make('BreakoutNoFrameskip-v4', render_mode='human'))
net = DQN(env.action_space.n)
net.load_state_dict(torch.load('breakout.pt', map_location='cpu'))
obs, _ = env.reset()
obs = np.stack([obs] * 4, 0)
while True:
    a = net(torch.tensor(obs, dtype=torch.float32).unsqueeze(0)).argmax().item()
    obs, _, done, trunc, _ = env.step(a)
    obs = np.append(obs[1:], [obs], 0)
    if done:
        obs, _ = env.reset()
        obs = np.stack([obs] * 4, 0)

6. 下一步:把模板换成别的游戏

  1. ENV_ID 改成 PongNoFrameskip-v4,训练 1M 步即可 20+ 分。
  2. 把网络最后一层改成 n_actions=2 就能玩 CartPole。
  3. 想玩连续控制?把 DQN 换成 SAC(stable_baselines3.SAC 一行搞定)。

7. 小结

阶段目标时间
跑通120 行脚本30 分钟
调优3 个黑魔法1 小时
迁移换游戏10 分钟

深度强化学习没有玄学,只有工程。先跑通,再调参,最后迁移——祝你玩得开心!

此作者没有提供个人介绍。
最后更新于 2025-10-13