告别表格,用神经网络玩转策略梯度:从REINFORCE算法到PyTorch实战
从表格到神经网络策略梯度实战与REINFORCE算法深度解析在强化学习的演进历程中策略表示方式经历了从离散表格到连续函数的关键跨越。传统表格法在面对高维状态空间时捉襟见肘而神经网络等函数近似器的引入不仅解决了维度灾难问题更开启了端到端策略学习的新纪元。本文将带您深入策略梯度的核心原理并通过PyTorch实战演示如何构建智能体解决经典控制问题。1. 策略表示从离散到连续的范式转移1.1 表格法的局限与突破传统表格策略表示将每个状态-动作对的概率存储在二维矩阵中这种方法的优势是直观且易于理解。例如在简单的网格世界中我们可以直接通过坐标索引获取策略# 表格策略示例 policy_table { (0,0): {up:0.6, right:0.4}, (0,1): {down:0.8, left:0.2} }但当状态空间增大时表格法暴露出三大致命缺陷存储瓶颈状态数量呈指数增长时内存需求迅速膨胀泛化困难相似状态无法共享经验每个状态需单独学习探索低效难以自动发现状态间的潜在关联1.2 神经网络策略的架构设计现代深度强化学习采用神经网络参数化策略π(a|s;θ)其典型架构包含输入层状态特征向量如CartPole中的位置、速度等隐藏层3-5层全连接网络使用ReLU激活函数输出层Softmax激活确保动作概率归一化import torch.nn as nn class PolicyNetwork(nn.Module): def __init__(self, state_dim, action_dim): super().__init__() self.fc1 nn.Linear(state_dim, 64) self.fc2 nn.Linear(64, 64) self.fc3 nn.Linear(64, action_dim) def forward(self, x): x torch.relu(self.fc1(x)) x torch.relu(self.fc2(x)) return torch.softmax(self.fc3(x), dim-1)这种参数化表示具有自动特征提取能力相近状态自然获得相似策略极大提升了学习效率。2. 策略梯度定理数学基础与直观理解2.1 目标函数的构建艺术策略梯度方法的核心是优化策略参数θ使目标函数J(θ)最大化。实践中常用的目标函数包括目标函数类型数学表达式适用场景初始状态价值J(θ)vπ(s0)明确初始状态的任务平均单步奖励J(θ)Σdπ(s)rπ(s)即时奖励敏感的任务折扣状态价值J(θ)Σdπ(s)vπ(s)长期回报重要的任务其中dπ(s)表示策略π下的稳态状态分布。2.2 策略梯度定理的推导通过对数技巧和蒙特卡洛采样我们得到策略梯度的通用表达式∇J(θ) [∇logπ(a|s;θ) * Qπ(s,a)]这一优美公式揭示更新方向沿Q值增长方向调整策略更新幅度与动作概率成反比保证探索性数学性质无偏但高方差需配合基线降低方差提示实际实现时通常会减去状态值V(s)作为基线保持期望不变同时降低方差3. REINFORCE算法原始而强大的蒙特卡洛方法3.1 算法流程与实现细节REINFORCE作为最基础的策略梯度算法其完整流程包含采样完整轨迹τ(s0,a0,r1,...,sT)计算各时刻的回报GtΣγ^(k-t)rk估计策略梯度∇J(θ)≈ΣGt∇logπ(at|st;θ)参数更新θ←θα∇J(θ)PyTorch实现核心代码如下def train_episode(env, policy, optimizer, gamma0.99): states, actions, rewards [], [], [] state env.reset() # 轨迹采样 while True: probs policy(torch.FloatTensor(state)) action torch.multinomial(probs, 1).item() next_state, reward, done, _ env.step(action) states.append(state) actions.append(action) rewards.append(reward) state next_state if done: break # 计算回报 returns [] G 0 for r in reversed(rewards): G r gamma * G returns.insert(0, G) # 策略更新 optimizer.zero_grad() loss 0 for s, a, G in zip(states, actions, returns): prob policy(torch.FloatTensor(s))[a] loss -torch.log(prob) * G loss.backward() optimizer.step() return sum(rewards)3.2 训练技巧与调参经验经过大量实验我们总结出以下实用技巧奖励标准化减去均值除以标准差稳定训练熵正则化添加βH(π)项防止策略过早收敛学习率衰减从1e-3开始每万步减半批量训练并行多个环境收集样本提升效率在CartPole环境中典型训练曲线表现为前1000步随机探索奖励波动大1000-5000步快速上升期策略明显改善5000步后收敛到最优奖励保持最大值4. 超越REINFORCE策略梯度的进阶方向4.1 方差缩减技术原始REINFORCE的高方差问题可通过以下方法缓解技术实现方式效果提升基线减法使用状态值函数V(s)作为基线30-50%优势函数A(s,a)Q(s,a)-V(s)50-70%广义优势估计GAE(λ)平衡偏差与方差70-90%4.2 信任域与自然梯度为保障策略更新的稳定性现代方法引入PPO通过剪切概率比限制更新幅度TRPO求解带约束的优化问题自然梯度考虑参数空间的几何结构这些方法在MuJoCo等复杂环境中展现出显著优势训练效率可提升2-3倍。5. 实战CartPole从零构建智能体5.1 环境配置与超参数设置使用Gymnasium创建环境并初始化关键参数import gymnasium as gym env gym.make(CartPole-v1) config { hidden_size: 64, learning_rate: 1e-3, gamma: 0.99, entropy_coef: 0.01, num_episodes: 3000 }5.2 完整训练流程结合前述技术的完整训练脚本结构policy PolicyNetwork(env.observation_space.shape[0], env.action_space.n) optimizer torch.optim.Adam(policy.parameters(), lrconfig[learning_rate]) for ep in range(config[num_episodes]): # 采样轨迹 states, actions, rewards [], [], [] state, _ env.reset() while True: probs policy(torch.FloatTensor(state)) action torch.multinomial(probs, 1).item() next_state, reward, terminated, truncated, _ env.step(action) done terminated or truncated # 存储转移 states.append(state) actions.append(action) rewards.append(reward) state next_state if done: break # 计算回报与优势 returns compute_returns(rewards, config[gamma]) advantages compute_advantages(returns, states) # 策略更新 update_policy(policy, optimizer, states, actions, advantages) # 定期测试与保存 if ep % 100 0: test_performance(env, policy)5.3 典型问题排查指南当训练出现问题时可依次检查梯度消失检查网络初始化适当增大初始方差过早收敛增加熵正则项系数振荡剧烈减小学习率或增大批量大小性能停滞尝试更复杂的网络结构在实现过程中我发现使用Tanh激活函数比ReLU在策略网络中表现更稳定特别是在训练初期。另一个实用技巧是在前1000步保持较高探索率之后逐步降低这种课程学习策略能显著提升最终性能。