Policy gradient method

From HandWiki
Short description: Class of reinforcement learning algorithms


Policy gradient methods are a class of reinforcement learning algorithms.

Policy gradient methods are a sub-class of policy optimization methods. Unlike value-based methods, which learn a value function and derive a policy from it, policy optimization methods directly learn a policy function $\pi$ that selects actions without consulting a value function. For policy gradient methods to apply, the policy $\pi_\theta$ must be parameterized by a parameter vector $\theta$ with respect to which it is differentiable.[1]

Overview

In policy-based RL, the actor is a parameterized policy function $\pi_\theta$, where $\theta$ are the parameters of the actor. The actor takes as argument the state of the environment $s$ and produces a probability distribution $\pi_\theta(\cdot \mid s)$ over actions.

If the action space is discrete, then $\sum_a \pi_\theta(a \mid s) = 1$. If the action space is continuous, then $\int_a \pi_\theta(a \mid s)\,\mathrm{d}a = 1$.
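Both normalization conditions can be checked in a minimal Python sketch. The sketch assumes a softmax over per-action logits for the discrete case and a one-dimensional Gaussian for the continuous case. These parameterizations, the function names, and the numeric logits, mean and standard deviation are choices made for this illustration. They stand in for whatever a parameterized model would output for a given state.

```python
import math


def softmax_policy(logits):
    """Discrete pi_theta(. | s): softmax over per-action logits (assumed parameterization)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gaussian_policy_density(a, mean, std):
    """Continuous pi_theta(a | s): 1-D Gaussian density (assumed parameterization)."""
    z = (a - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


# Discrete case: the action probabilities sum to one.
probs = softmax_policy([2.0, 0.5, -1.0])
discrete_total = sum(probs)

# Continuous case: approximate the integral with a midpoint Riemann sum on [-10, 10].
da = 0.001
continuous_total = sum(
    gaussian_policy_density(-10.0 + (k + 0.5) * da, mean=0.3, std=1.2) * da
    for k in range(20000)
)
```

The integration range is wide enough that the Gaussian mass outside it is negligible for these made-up parameters.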

The goal of policy optimization is to find some $\theta$ that maximizes the expected discounted episodic reward $J(\theta)$:
$$J(\theta) = \mathbb{E}_{\pi_\theta}\left[\sum_{t \in 0:T} \gamma^t R_t \,\Big|\, S_0 = s_0\right]$$
where $\gamma$ is the discount factor, $R_t$ is the reward at step $t$, $s_0$ is the starting state, and $T$ is the time horizon (which can be infinite).
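The following small Python sketch makes the objective concrete. It computes the discounted return of one episode and averages it over sampled episodes, which gives a Monte Carlo estimate of $J(\theta)$. It assumes the episodes were already generated by running $\pi_\theta$ from $s_0$ and are given as plain lists of rewards. The helper names are hypothetical.

```python
def discounted_return(rewards, gamma):
    """sum_t gamma**t * R_t for a single episode given as a list of rewards."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))


def estimate_objective(episodes, gamma):
    """Monte Carlo estimate of J(theta): mean discounted return over sampled episodes.

    Assumes every episode was generated by pi_theta starting from the same s_0.
    """
    return sum(discounted_return(ep, gamma) for ep in episodes) / len(episodes)


episodes = [[1.0, 1.0, 1.0], [0.0, 2.0]]
j_hat = estimate_objective(episodes, gamma=0.5)
```

Here the two episodes have discounted returns 1.75 and 1.0, so `j_hat` is their mean.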

The policy gradient is defined as $\nabla_\theta J(\theta)$. Different policy gradient methods estimate the policy gradient stochastically in different ways. The goal of any policy gradient method is to iteratively maximize $J(\theta)$ by gradient ascent. Since the key part of any policy gradient method is the stochastic estimation of the policy gradient, these methods are also studied under the title of "Monte Carlo gradient estimation".[2]

REINFORCE

Policy gradient

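The REINFORCE algorithm, introduced by Ronald J. Williams in 1992, was the first policy gradient method.[3] It is based on the identity for the policy gradient
$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\sum_{t \in 0:T} \nabla_\theta \ln \pi_\theta(A_t \mid S_t) \sum_{t' \in 0:T} \gamma^{t'} R_{t'} \,\Big|\, S_0 = s_0\right]$$
which can be improved via the "causality trick".[1] The action at step $t$ cannot influence rewards received before step $t$, so those rewards can be dropped from its weight, giving
$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\sum_{t \in 0:T} \nabla_\theta \ln \pi_\theta(A_t \mid S_t) \sum_{\tau \in t:T} \gamma^\tau R_\tau \,\Big|\, S_0 = s_0\right]$$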

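Lemma: The expectation of the score function is zero, conditional on any present or past state. That is, for any $0 \le i \le j \le T$ and any state $s_i$, we have
$$\mathbb{E}_{\pi_\theta}\left[\nabla_\theta \ln \pi_\theta(A_j \mid S_j) \,\middle|\, S_i = s_i\right] = 0.$$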

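Further, if $\Psi_i$ is a random variable that is independent of $A_i, S_{i+1}, A_{i+1}, \ldots$, then $\mathbb{E}_{\pi_\theta}\left[\nabla_\theta \ln \pi_\theta(A_j \mid S_j)\, \Psi_i \,\middle|\, S_i = s_i\right] = 0$.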
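The first part of the lemma can be checked numerically in the simplest setting: a single state with a softmax policy over logits $\theta$. The softmax parameterization is an assumption of this sketch. For it, the score function has the closed form $\nabla_\theta \ln \pi_\theta(a) = \mathbf{e}_a - \pi_\theta$, where $\mathbf{e}_a$ is the one-hot vector of action $a$. The sketch averages this score over actions by exact enumeration, weighting each action by its probability under $\pi_\theta$.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def score(logits, a):
    """grad_theta ln pi_theta(a) for a softmax policy: one_hot(a) - pi_theta."""
    probs = softmax(logits)
    return [(1.0 if k == a else 0.0) - p for k, p in enumerate(probs)]


def expected_score(logits):
    """E_{a ~ pi_theta}[grad_theta ln pi_theta(a)], computed by exact enumeration."""
    probs = softmax(logits)
    n = len(logits)
    scores = [score(logits, a) for a in range(n)]
    return [sum(probs[a] * scores[a][k] for a in range(n)) for k in range(n)]


expectation = expected_score([0.3, -1.2, 2.0])
```

Analytically, each entry is $\sum_a \pi_\theta(a)(\mathbb{1}[a = k] - \pi_\theta(k)) = \pi_\theta(k) - \pi_\theta(k) = 0$. The computed vector matches this up to floating-point rounding.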


Thus, we have an unbiased estimator of the policy gradient:

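$$\nabla_\theta J(\theta) \approx \frac{1}{N} \sum_{n=1}^{N} \left[\sum_{t \in 0:T} \nabla_\theta \ln \pi_\theta(A_{t,n} \mid S_{t,n}) \sum_{\tau \in t:T} \gamma^\tau R_{\tau,n}\right]$$

where the index $n$ ranges over $N$ rollout trajectories sampled using the policy $\pi_\theta$.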

The score function $\nabla_\theta \ln \pi_\theta(A_t \mid S_t)$ can be interpreted as the direction in parameter space that increases the probability of taking action $A_t$ in state $S_t$. The policy gradient is then a weighted average of all possible directions that increase the probability of taking any action in any state, weighted by reward signals. If taking a certain action in a certain state is associated with high reward, that direction is strongly reinforced. If it is associated with low reward, that direction is weakened.
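The estimator above can be sketched for a tabular softmax policy, where $\theta[s][a]$ is a logit for each state-action pair. This parameterization is an assumption, chosen because its score function has the closed form $\partial \ln \pi_\theta(a \mid s) / \partial \theta[s][b] = \mathbb{1}[a = b] - \pi_\theta(b \mid s)$. Trajectories are assumed to be lists of `(state, action, reward)` tuples already sampled from $\pi_\theta$. As in the formula, each score term is weighted by $\sum_{\tau \in t:T} \gamma^\tau R_\tau$. The function names are illustrative.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def discounted_rewards_to_go(rewards, gamma):
    """[sum_{tau >= t} gamma**tau * R_tau for each t], discounting from step 0."""
    out = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running += gamma ** t * rewards[t]
        out[t] = running
    return out


def reinforce_gradient(theta, trajectories, gamma):
    """Monte Carlo policy-gradient estimate for a tabular softmax policy.

    theta[s][a] are logits; each trajectory is a list of (state, action, reward).
    """
    grad = [[0.0] * len(row) for row in theta]
    for traj in trajectories:
        weights = discounted_rewards_to_go([r for _, _, r in traj], gamma)
        for (s, a, _), w in zip(traj, weights):
            probs = softmax(theta[s])
            for b in range(len(probs)):
                # d/dtheta[s][b] of ln pi(a|s) = 1[a == b] - pi(b|s)
                grad[s][b] += ((1.0 if b == a else 0.0) - probs[b]) * w
    n = len(trajectories)
    return [[g / n for g in row] for row in grad]
```

For example, `reinforce_gradient([[0.0, 0.0]], [[(0, 0, 1.0)]], 0.9)` pushes the logit of the rewarded action up and the other one down.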

Algorithm

The REINFORCE algorithm is a loop:

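  1. Roll out $N$ trajectories in the environment, using $\pi_{\theta_i}$ as the policy function.
  2. Compute the policy gradient estimate: $g_i \leftarrow \frac{1}{N} \sum_{n=1}^{N} \left[\sum_{t \in 0:T} \nabla_\theta \ln \pi_\theta(A_{t,n} \mid S_{t,n})\big|_{\theta = \theta_i} \sum_{\tau \in t:T} \gamma^\tau R_{\tau,n}\right]$
  3. Update the policy by gradient ascent: $\theta_{i+1} \leftarrow \theta_i + \alpha_i g_i$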

Here, αi is the learning rate at update step i.
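Here is a complete, if tiny, instance of this loop. The sketch assumes a one-state, two-action problem (a bandit) whose episodes last a single step, so the return is just the immediate reward. It uses a softmax policy over two logits and a constant learning rate in place of the schedule $\alpha_i$. The environment, reward function and hyperparameters are all made up for illustration.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_bandit(reward_fn, n_actions, iterations=200, n_rollouts=10, lr=0.5, seed=0):
    """REINFORCE on a one-step, one-state problem with a softmax policy over logits."""
    rng = random.Random(seed)
    theta = [0.0] * n_actions
    for _ in range(iterations):
        probs = softmax(theta)
        grad = [0.0] * n_actions
        for _ in range(n_rollouts):
            # 1. Roll out a (one-step) trajectory with the current policy.
            a = rng.choices(range(n_actions), weights=probs)[0]
            r = reward_fn(a)
            # 2. Accumulate score function times return: (one_hot(a) - pi) * R.
            for b in range(n_actions):
                grad[b] += ((1.0 if b == a else 0.0) - probs[b]) * r / n_rollouts
        # 3. Gradient ascent step with a constant learning rate.
        theta = [t + lr * g for t, g in zip(theta, grad)]
    return softmax(theta)


final_probs = reinforce_bandit(lambda a: 1.0 if a == 0 else 0.0, n_actions=2)
```

Only action 0 is rewarded, so every nonzero update raises its probability. As a result, `final_probs[0]` approaches 1 as training proceeds.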

Variance reduction

REINFORCE is an on-policy algorithm, meaning that the trajectories used for each update must be sampled from the current policy $\pi_\theta$. Its gradient estimate is built from sampled returns, which can vary significantly between trajectories, so the updates can have high variance. Many variants of REINFORCE have been introduced under the title of variance reduction.

REINFORCE with baseline

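A common way of reducing variance is the REINFORCE with baseline algorithm, based on the following identity:
$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\sum_{t \in 0:T} \nabla_\theta \ln \pi_\theta(A_t \mid S_t) \left(\sum_{\tau \in t:T} \gamma^\tau R_\tau - b(S_t)\right) \,\Bigg|\, S_0 = s_0\right]$$
for any function $b : \text{States} \to \mathbb{R}$. This can be proven by applying the previous lemma.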

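The algorithm uses the modified gradient estimator
$$g_i \leftarrow \frac{1}{N} \sum_{n=1}^{N} \left[\sum_{t \in 0:T} \nabla_\theta \ln \pi_\theta(A_{t,n} \mid S_{t,n})\big|_{\theta = \theta_i} \left(\sum_{\tau \in t:T} \gamma^\tau R_{\tau,n} - b_i(S_{t,n})\right)\right]$$
and the original REINFORCE algorithm is the special case where $b_i \equiv 0$.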
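The effect of a baseline can be seen exactly in a one-state example with assumed values: a softmax policy over two actions with rewards 11 and 10. The sketch enumerates both actions to compute the mean and the total variance (trace of the covariance) of the single-sample estimator $\nabla_\theta \ln \pi_\theta(a)\,(R(a) - b)$. Both baselines give the same mean, as the lemma guarantees. Subtracting $b = 10.5$ (the expected reward under the uniform policy) removes the variance that the large common reward level otherwise introduces.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def estimator_stats(theta, rewards, baseline):
    """Exact mean and total variance of g(a) = grad ln pi(a) * (R(a) - b), a ~ pi_theta."""
    probs = softmax(theta)
    n = len(probs)
    samples = []
    for a in range(n):
        score = [(1.0 if k == a else 0.0) - probs[k] for k in range(n)]
        samples.append([s * (rewards[a] - baseline) for s in score])
    mean = [sum(probs[a] * samples[a][k] for a in range(n)) for k in range(n)]
    var = sum(
        probs[a] * sum((samples[a][k] - mean[k]) ** 2 for k in range(n)) for a in range(n)
    )
    return mean, var


theta, rewards = [0.0, 0.0], [11.0, 10.0]
mean_plain, var_plain = estimator_stats(theta, rewards, baseline=0.0)
mean_base, var_base = estimator_stats(theta, rewards, baseline=10.5)
```

In this two-action case the well-chosen baseline drives the variance all the way to zero. In general it only reduces the variance.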

Actor-critic methods

If $b_i$ is chosen well, such that $b_i(S_t) \approx \mathbb{E}\left[\sum_{\tau \in t:T} \gamma^\tau R_\tau \,\middle|\, S_t\right] = \gamma^t V^{\pi_{\theta_i}}(S_t)$, this can significantly decrease the variance of the gradient estimate. That is, the baseline should be as close to $\gamma^t V^{\pi_{\theta_i}}(S_t)$ as possible, approaching the ideal of:
$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\sum_{t \in 0:T} \nabla_\theta \ln \pi_\theta(A_t \mid S_t) \left(\sum_{\tau \in t:T} \gamma^\tau R_\tau - \gamma^t V^{\pi_\theta}(S_t)\right) \,\Bigg|\, S_0 = s_0\right]$$
Note that as the policy $\pi_{\theta_i}$ updates, the value function $V^{\pi_{\theta_i}}(S_t)$ updates as well, so the baseline should also be updated. One common approach is to train a separate function that estimates the value function and use it as the baseline. This is one of the actor-critic methods, where the policy function is the actor and the value function is the critic.
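One simple way to realize a separately trained value estimate is a tabular Monte Carlo critic. The sketch below makes several assumptions. States are hashable keys, and the critic is a dictionary updated toward observed returns $G_t = \sum_{\tau \in t:T} \gamma^{\tau - t} R_\tau$ with a fixed step size. The weight fed to the actor is $\gamma^t (G_t - V(S_t))$, which equals $\sum_{\tau \in t:T} \gamma^\tau R_\tau - \gamma^t V(S_t)$ in the formula above. A practical actor-critic would typically use a neural network and a regression loss instead of a table.

```python
def discounted_returns(rewards, gamma):
    """G_t = sum_{tau >= t} gamma**(tau - t) * R_tau, computed backwards."""
    out = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    return out


def critic_update_and_weights(values, trajectory, gamma, critic_lr):
    """Return actor weights gamma**t * (G_t - V(S_t)) using the current critic, then
    move each V(S_t) toward G_t.

    `values` is a hypothetical tabular critic (dict state -> estimate), updated in place;
    `trajectory` is a list of (state, reward) pairs.
    """
    states = [s for s, _ in trajectory]
    returns = discounted_returns([r for _, r in trajectory], gamma)
    weights = [
        gamma ** t * (g - values.get(s, 0.0))
        for t, (s, g) in enumerate(zip(states, returns))
    ]
    for s, g in zip(states, returns):
        v = values.get(s, 0.0)
        values[s] = v + critic_lr * (g - v)
    return weights
```

If the same trajectory is fed in repeatedly, the critic converges to the observed returns. The actor weights then shrink toward zero.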

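The Q-function $Q^\pi$ can also be used as the critic, since
$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\sum_{0 \le t \le T} \gamma^t \nabla_\theta \ln \pi_\theta(A_t \mid S_t)\, Q^{\pi_\theta}(S_t, A_t) \,\Big|\, S_0 = s_0\right]$$
by a similar argument using the tower law.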

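Subtracting the value function as a baseline, we find that the advantage function $A^\pi(S, A) = Q^\pi(S, A) - V^\pi(S)$ can be used as the critic as well:
$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\sum_{0 \le t \le T} \gamma^t \nabla_\theta \ln \pi_\theta(A_t \mid S_t)\, A^{\pi_\theta}(S_t, A_t) \,\Big|\, S_0 = s_0\right]$$
In summary, there are many unbiased estimators for $\nabla_\theta J(\theta)$, all in the form
$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\sum_{0 \le t \le T} \nabla_\theta \ln \pi_\theta(A_t \mid S_t)\, \Psi_t \,\Big|\, S_0 = s_0\right]$$
where $\Psi_t$ is any linear combination of the following terms whose coefficients sum to one: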

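  • $\sum_{0 \le \tau \le T} \gamma^\tau R_\tau$: rarely used in practice, since the rewards received before step $t$ only add variance.
  • $\gamma^t \sum_{t \le \tau \le T} \gamma^{\tau - t} R_\tau$: used by the REINFORCE algorithm.
  • $\gamma^t \sum_{t \le \tau \le T} \gamma^{\tau - t} R_\tau - b(S_t)$: used by the REINFORCE with baseline algorithm.
  • $\gamma^t \left(R_t + \gamma V^{\pi_\theta}(S_{t+1}) - V^{\pi_\theta}(S_t)\right)$: 1-step TD learning.
  • $\gamma^t Q^{\pi_\theta}(S_t, A_t)$.
  • $\gamma^t A^{\pi_\theta}(S_t, A_t)$.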

Some more possible choices of $\Psi_t$ are as follows, with very similar proofs.

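  • $\gamma^t \left(R_t + \gamma R_{t+1} + \gamma^2 V^{\pi_\theta}(S_{t+2}) - V^{\pi_\theta}(S_t)\right)$: 2-step TD learning.
  • $\gamma^t \left(\sum_{k=0}^{n-1} \gamma^k R_{t+k} + \gamma^n V^{\pi_\theta}(S_{t+n}) - V^{\pi_\theta}(S_t)\right)$: $n$-step TD learning.
  • $\gamma^t \sum_{n=1}^{\infty} (1 - \lambda) \lambda^{n-1} \left(\sum_{k=0}^{n-1} \gamma^k R_{t+k} + \gamma^n V^{\pi_\theta}(S_{t+n}) - V^{\pi_\theta}(S_t)\right)$: TD($\lambda$) learning, also known as GAE (generalized advantage estimation).[4] This is obtained as an exponentially decaying average of the $n$-step TD terms, with weights $(1 - \lambda)\lambda^{n-1}$ that sum to one.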
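The TD($\lambda$)/GAE term can be computed without the infinite sum, using the backward recursion $\hat{A}_t = \delta_t + \gamma\lambda \hat{A}_{t+1}$ on the 1-step TD residuals $\delta_t = R_t + \gamma V(S_{t+1}) - V(S_t)$. The Python sketch below assumes a finite trajectory whose value list has one extra entry for the bootstrap value $V(S_T)$, which is 0 for a terminal state. It returns the bracketed advantage without the $\gamma^t$ prefactor, which implementations commonly leave out. A direct evaluation of the exponentially weighted sum is included to cross-check the recursion. Past the end of the data every $n$-step term is the same, so their combined weight is lumped into one term. All function names are hypothetical.

```python
def td_residuals(rewards, values, gamma):
    """delta_t = R_t + gamma * V(S_{t+1}) - V(S_t); len(values) == len(rewards) + 1."""
    return [rewards[t] + gamma * values[t + 1] - values[t] for t in range(len(rewards))]


def gae(rewards, values, gamma, lam):
    """GAE / TD(lambda) advantages via A_t = delta_t + gamma * lam * A_{t+1}."""
    deltas = td_residuals(rewards, values, gamma)
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages


def n_step_advantage(rewards, values, gamma, t, n):
    """sum_{k<n} gamma^k R_{t+k} + gamma^n V(S_{t+n}) - V(S_t), n capped at the data length."""
    steps = min(n, len(rewards) - t)
    target = sum(gamma ** k * rewards[t + k] for k in range(steps))
    target += gamma ** steps * values[t + steps]
    return target - values[t]


def lambda_weighted_advantage(rewards, values, gamma, lam, t):
    """Direct form: sum_n (1 - lam) * lam**(n - 1) * (n-step advantage)."""
    m = len(rewards) - t  # all n-step terms with n >= m coincide
    head = sum(
        (1 - lam) * lam ** (n - 1) * n_step_advantage(rewards, values, gamma, t, n)
        for n in range(1, m)
    )
    return head + lam ** (m - 1) * n_step_advantage(rewards, values, gamma, t, m)
```

Setting `lam=0` recovers the 1-step TD residuals. Setting `lam=1` gives the Monte Carlo return minus $V(S_t)$.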

Natural policy gradient

The natural policy gradient method is a variant of the policy gradient method, proposed by Sham Kakade in 2001.[5] Unlike standard policy gradient methods, which depend on the choice of parameters θ (making updates coordinate-dependent), the natural policy gradient aims to provide a coordinate-free update, which is geometrically "natural".

Motivation

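Standard policy gradient updates $\theta_{i+1} = \theta_i + \alpha \nabla_\theta J(\theta_i)$ solve a constrained optimization problem:
$$\begin{cases} \max_{\theta_{i+1}} J(\theta_i) + (\theta_{i+1} - \theta_i)^T \nabla_\theta J(\theta_i) \\ \|\theta_{i+1} - \theta_i\| \le \alpha \|\nabla_\theta J(\theta_i)\| \end{cases}$$
The objective (linearized improvement) is geometrically meaningful, but the Euclidean constraint $\|\theta_{i+1} - \theta_i\|$ introduces coordinate dependence. To address this, the natural policy gradient replaces the Euclidean constraint with a Kullback–Leibler divergence (KL) constraint:
$$\begin{cases} \max_{\theta_{i+1}} J(\theta_i) + (\theta_{i+1} - \theta_i)^T \nabla_\theta J(\theta_i) \\ \bar{D}_{KL}(\pi_{\theta_{i+1}} \| \pi_{\theta_i}) \le \epsilon \end{cases}$$
where the KL divergence between two policies is averaged over the state distribution under policy $\pi_{\theta_i}$. That is,
$$\bar{D}_{KL}(\pi_{\theta_{i+1}} \| \pi_{\theta_i}) := \mathbb{E}_{s \sim \pi_{\theta_i}}\left[D_{KL}\left(\pi_{\theta_{i+1}}(\cdot \mid s) \,\|\, \pi_{\theta_i}(\cdot \mid s)\right)\right]$$
This ensures updates are invariant to invertible affine parameter transformations.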

Fisher information approximation

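For small $\epsilon$, the KL divergence is approximated by the Fisher information metric:
$$\bar{D}_{KL}(\pi_{\theta_{i+1}} \| \pi_{\theta_i}) \approx \frac{1}{2} (\theta_{i+1} - \theta_i)^T F(\theta_i) (\theta_{i+1} - \theta_i)$$
where $F(\theta)$ is the Fisher information matrix of the policy, defined as:
$$F(\theta) = \mathbb{E}_{s, a \sim \pi_\theta}\left[\nabla_\theta \ln \pi_\theta(a \mid s) \left(\nabla_\theta \ln \pi_\theta(a \mid s)\right)^T\right]$$
This turns the problem into maximizing a linear objective under a quadratic constraint, which yields the natural policy gradient update:
$$\theta_{i+1} = \theta_i + \alpha F(\theta_i)^{-1} \nabla_\theta J(\theta_i)$$
The step size $\alpha$ is typically adjusted to maintain the KL constraint, with $\alpha \approx \sqrt{\dfrac{2\epsilon}{\nabla_\theta J(\theta_i)^T F(\theta_i)^{-1} \nabla_\theta J(\theta_i)}}$.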
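The following small worked instance rests on assumptions made for this sketch. It uses a one-dimensional Gaussian policy $\mathcal{N}(\mu, \sigma^2)$, treated as a single state, with parameters $\theta = (\mu, \log\sigma)$. For this parameterization the Fisher information matrix has the closed form $F(\theta) = \operatorname{diag}(1/\sigma^2, 2)$. Given a policy gradient $g$ (here a made-up vector), the code solves $F x = g$ and scales the step so that the quadratic KL model equals $\epsilon$. It then compares the result with the exact KL divergence between the new and old Gaussians.

```python
import math


def gaussian_fisher(log_std):
    """Fisher matrix of N(mu, sigma^2) w.r.t. theta = (mu, log sigma): diag(1/sigma^2, 2)."""
    return [[math.exp(-2.0 * log_std), 0.0], [0.0, 2.0]]


def solve_2x2(F, g):
    """Solve F x = g for a 2x2 matrix by Cramer's rule."""
    det = F[0][0] * F[1][1] - F[0][1] * F[1][0]
    return [
        (g[0] * F[1][1] - F[0][1] * g[1]) / det,
        (F[0][0] * g[1] - F[1][0] * g[0]) / det,
    ]


def natural_gradient_step(theta, grad, eps):
    """theta + alpha * F^{-1} g with alpha = sqrt(2 eps / (g^T F^{-1} g))."""
    x = solve_2x2(gaussian_fisher(theta[1]), grad)  # x = F^{-1} g
    g_dot_x = grad[0] * x[0] + grad[1] * x[1]  # g^T F^{-1} g
    alpha = math.sqrt(2.0 * eps / g_dot_x)
    return [theta[0] + alpha * x[0], theta[1] + alpha * x[1]]


def gaussian_kl(new, old):
    """Exact KL(N_new || N_old) for 1-D Gaussians given as (mu, log sigma)."""
    (mu1, ls1), (mu0, ls0) = new, old
    var1, var0 = math.exp(2.0 * ls1), math.exp(2.0 * ls0)
    return ls0 - ls1 + (var1 + (mu1 - mu0) ** 2) / (2.0 * var0) - 0.5


theta_old = [0.0, 0.0]
theta_new = natural_gradient_step(theta_old, grad=[1.0, 0.5], eps=0.01)
kl = gaussian_kl(theta_new, theta_old)
```

By construction, the quadratic Fisher model of the step equals 0.01 exactly. The exact `kl` is close to 0.01 but not equal to it, because the Fisher form is only a second-order Taylor expansion of the KL divergence.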

Inverting F(θ) is computationally intensive, especially for high-dimensional parameters (e.g., neural networks). Practical implementations often use approximations.

Trust Region Policy Optimization (TRPO)

Trust Region Policy Optimization (TRPO) is a policy gradient method that extends the natural policy gradient approach by enforcing a trust region constraint on policy updates.[6] Developed by Schulman et al. in 2015, TRPO improves upon the natural policy gradient method.

The natural gradient step is theoretically optimal only if the objective is truly linear and the KL divergence truly quadratic in the parameters, but these are only local approximations. TRPO's line search and KL constraint attempt to restrict the solution to a "trust region" in which these approximations do not break down. This makes TRPO more robust in practice.

Formulation

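Like natural policy gradient, TRPO iteratively updates the policy parameters $\theta$ by solving a constrained optimization problem specified coordinate-free:
$$\begin{cases} \max_\theta L(\theta, \theta_i) \\ \bar{D}_{KL}(\pi_\theta \| \pi_{\theta_i}) \le \epsilon \end{cases}$$
where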

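  • $L(\theta, \theta_i) = \mathbb{E}_{s, a \sim \pi_{\theta_i}}\left[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_i}(a \mid s)} A^{\pi_{\theta_i}}(s, a)\right]$ is the surrogate advantage, measuring the performance of $\pi_\theta$ relative to the old policy $\pi_{\theta_i}$.
  • $\epsilon$ is the trust region radius.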

Note that in general, other surrogate advantages are possible:
$$L(\theta, \theta_i) = \mathbb{E}_{s, a \sim \pi_{\theta_i}}\left[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_i}(a \mid s)} \Psi^{\pi_{\theta_i}}(s, a)\right]$$
where $\Psi$ is any of the choices of $\Psi_t$ listed above. Indeed, OpenAI recommended using the Generalized Advantage Estimate instead of the plain advantage $A^{\pi_\theta}$.

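The surrogate advantage $L(\theta, \theta_i)$ is designed to align with the policy gradient $\nabla_\theta J(\theta)$. Specifically, at $\theta = \theta_i$, the gradient $\nabla_\theta L(\theta, \theta_i)$ equals the policy gradient derived from the advantage function:
$$\nabla_\theta J(\theta)\big|_{\theta = \theta_i} = \mathbb{E}_{(s, a) \sim \pi_{\theta_i}}\left[\nabla_\theta \ln \pi_\theta(a \mid s)\big|_{\theta = \theta_i} A^{\pi_{\theta_i}}(s, a)\right] = \nabla_\theta L(\theta, \theta_i)\big|_{\theta = \theta_i}$$
However, when $\theta \neq \theta_i$, this is not necessarily true. Thus it is a "surrogate" of the real objective.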

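As with natural policy gradient, for small policy updates TRPO approximates the surrogate advantage and KL divergence using Taylor expansions around $\theta_i$. The constant term of the first expansion vanishes because $L(\theta_i, \theta_i) = 0$:
$$L(\theta, \theta_i) \approx g^T (\theta - \theta_i), \qquad \bar{D}_{KL}(\pi_\theta \| \pi_{\theta_i}) \approx \frac{1}{2} (\theta - \theta_i)^T F (\theta - \theta_i)$$
where:

  • $g = \nabla_\theta L(\theta, \theta_i)\big|_{\theta = \theta_i}$ is the policy gradient.
  • $F = \nabla_\theta^2 \bar{D}_{KL}(\pi_\theta \| \pi_{\theta_i})\big|_{\theta = \theta_i}$ is the Fisher information matrix.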

This reduces the problem to a quadratic optimization, yielding the natural policy gradient update:
$$\theta_{i+1} = \theta_i + \sqrt{\frac{2\epsilon}{g^T F^{-1} g}}\, F^{-1} g.$$
So far, this is essentially the same as the natural gradient method. However, TRPO improves upon it with two modifications:

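  • Use the conjugate gradient method to solve for $x$ in $F x = g$ iteratively, without explicit matrix inversion.
  • Use backtracking line search to ensure the trust-region constraint is satisfied. Specifically, it backtracks the step size to ensure both the KL constraint and policy improvement. That is, it tests each of the following candidate solutions
$$\theta_{i+1} = \theta_i + \sqrt{\frac{2\epsilon}{x^T F x}}\, x, \quad \theta_i + \alpha \sqrt{\frac{2\epsilon}{x^T F x}}\, x, \quad \theta_i + \alpha^2 \sqrt{\frac{2\epsilon}{x^T F x}}\, x, \quad \ldots$$
until it finds one that both satisfies the KL constraint $\bar{D}_{KL}(\pi_{\theta_{i+1}} \| \pi_{\theta_i}) \le \epsilon$ and results in a higher surrogate advantage, $L(\theta_{i+1}, \theta_i) \ge L(\theta_i, \theta_i)$. Here, $\alpha \in (0, 1)$ is the backtracking coefficient.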
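Both modifications can be sketched in a few lines of Python. The conjugate gradient routine only needs a function `fvp(v)` that returns the Fisher-vector product $F v$. In practice this product is usually obtained with automatic differentiation rather than by forming $F$; here it is an assumed callback. The line search then tries the shrinking candidates listed above. `surrogate(theta)` and `mean_kl(theta)` are hypothetical callables standing for $L(\theta, \theta_i)$ and $\bar{D}_{KL}(\pi_\theta \| \pi_{\theta_i})$. Returning the old parameters when no candidate is accepted is a choice of this sketch.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def conjugate_gradient(fvp, g, iters=10, tol=1e-12):
    """Approximately solve F x = g using only Fisher-vector products fvp(v) = F v."""
    x = [0.0] * len(g)
    r = list(g)  # residual g - F x, with x starting at 0
    p = list(r)  # current search direction
    rr = dot(r, r)
    for _ in range(iters):
        if rr < tol:
            break
        fp = fvp(p)
        step = rr / dot(p, fp)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * fpi for ri, fpi in zip(r, fp)]
        rr_new = dot(r, r)
        p = [ri + (rr_new / rr) * pi for ri, pi in zip(r, p)]
        rr = rr_new
    return x


def backtracking_line_search(theta, x, fvp, surrogate, mean_kl, eps,
                             alpha=0.5, max_backtracks=10):
    """Try theta + alpha**k * sqrt(2 eps / x^T F x) * x for k = 0, 1, ... and return the
    first candidate that satisfies the KL constraint without lowering the surrogate."""
    full_step = math.sqrt(2.0 * eps / dot(x, fvp(x)))
    baseline = surrogate(theta)
    for k in range(max_backtracks):
        coef = (alpha ** k) * full_step
        candidate = [t + coef * xi for t, xi in zip(theta, x)]
        if mean_kl(candidate) <= eps and surrogate(candidate) >= baseline:
            return candidate
    return list(theta)  # nothing accepted: keep the old parameters


# Toy usage with an explicit 2x2 "Fisher" matrix (made-up numbers).
F = [[4.0, 1.0], [1.0, 3.0]]


def fvp(v):
    return [F[0][0] * v[0] + F[0][1] * v[1], F[1][0] * v[0] + F[1][1] * v[1]]


g = [1.0, 2.0]
x = conjugate_gradient(fvp, g)
```

For a symmetric positive-definite $n \times n$ matrix, conjugate gradient reaches the solution in at most $n$ iterations in exact arithmetic, so this 2x2 example converges after two steps.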

Proximal Policy Optimization (PPO)

A further improvement is proximal policy optimization (PPO), which avoids even computing $F(\theta)$ and $F(\theta)^{-1}$ via a first-order approximation using clipped probability ratios.[7]

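Specifically, instead of maximizing the surrogate advantage
$$\max_\theta L(\theta, \theta_t) = \mathbb{E}_{s, a \sim \pi_{\theta_t}}\left[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_t}(a \mid s)} A^{\pi_{\theta_t}}(s, a)\right]$$
under a KL divergence constraint, it inserts the constraint directly into the surrogate advantage:
$$\max_\theta \mathbb{E}_{s, a \sim \pi_{\theta_t}}\left[\begin{cases} \min\left(\frac{\pi_\theta(a \mid s)}{\pi_{\theta_t}(a \mid s)}, 1 + \epsilon\right) A^{\pi_{\theta_t}}(s, a) & \text{if } A^{\pi_{\theta_t}}(s, a) > 0 \\ \max\left(\frac{\pi_\theta(a \mid s)}{\pi_{\theta_t}(a \mid s)}, 1 - \epsilon\right) A^{\pi_{\theta_t}}(s, a) & \text{if } A^{\pi_{\theta_t}}(s, a) < 0 \end{cases}\right]$$
PPO then maximizes this clipped surrogate by stochastic gradient ascent, as usual.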

In words, consider gradient ascent on the new surrogate advantage at some state-action pair $(s, a)$. If the advantage is positive, $A^{\pi_{\theta_t}}(s, a) > 0$, then the gradient directs $\theta$ towards increasing the probability of performing action $a$ in state $s$. However, as soon as $\theta$ has changed so much that $\pi_\theta(a \mid s) \ge (1 + \epsilon) \pi_{\theta_t}(a \mid s)$, the gradient stops pointing in that direction. The case $A^{\pi_{\theta_t}}(s, a) < 0$ is symmetric, with the bound $\pi_\theta(a \mid s) \le (1 - \epsilon) \pi_{\theta_t}(a \mid s)$. Thus, PPO avoids pushing the parameter update too hard and avoids changing the policy too much.

More precisely, updating $\theta_t$ to $\theta_{t+1}$ takes multiple update steps on the same batch of data. PPO initializes $\theta = \theta_t$, then repeatedly applies a gradient-based optimizer (such as Adam) to update $\theta$ until the surrogate advantage has stabilized. It then assigns $\theta_{t+1} \leftarrow \theta$ and repeats the process with newly collected data.

During this inner loop, the first update to $\theta$ does not hit the $1 - \epsilon, 1 + \epsilon$ bounds. As $\theta$ moves further and further away from $\theta_t$, it eventually starts hitting them. For each sample that hits a bound, the corresponding gradient becomes zero, and thus PPO avoids updating $\theta$ too far away from $\theta_t$.
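The clipping behaviour of the inner loop can be observed in a toy setting built from assumed ingredients. There is one state, a softmax policy over two logits, and a fixed batch of two samples: one action with advantage $+1$ and the other with advantage $-1$. Repeated gradient-ascent steps on the batch-averaged clipped surrogate push the probability ratio up until it crosses $1 + \epsilon$, overshooting by at most one step. After that, every per-sample gradient is zero and $\theta$ stops changing. Plain gradient ascent with a fixed learning rate stands in for an optimizer such as Adam.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def clipped_surrogate(theta, old_probs, batch, eps):
    """Batch average of the piecewise clipped objective; batch holds (action, advantage)."""
    probs = softmax(theta)
    total = 0.0
    for a, adv in batch:
        ratio = probs[a] / old_probs[a]
        total += min(ratio, 1 + eps) * adv if adv > 0 else max(ratio, 1 - eps) * adv
    return total / len(batch)


def clipped_surrogate_grad(theta, old_probs, batch, eps):
    """Gradient of clipped_surrogate with respect to the logits theta."""
    probs = softmax(theta)
    grad = [0.0] * len(theta)
    for a, adv in batch:
        ratio = probs[a] / old_probs[a]
        # Zero gradient once the ratio has left [1 - eps, 1 + eps] on the side
        # the advantage pushes it towards.
        if not ((adv > 0 and ratio < 1 + eps) or (adv < 0 and ratio > 1 - eps)):
            continue
        for k in range(len(theta)):
            # d ratio / d theta_k = ratio * (1[k == a] - pi(k)) for a softmax policy.
            grad[k] += adv * ratio * ((1.0 if k == a else 0.0) - probs[k]) / len(batch)
    return grad


def ppo_inner_loop(theta_old, batch, eps=0.2, lr=0.1, steps=100):
    """Repeated gradient-ascent steps on one fixed batch, starting from theta_old."""
    old_probs = softmax(theta_old)
    theta = list(theta_old)
    for _ in range(steps):
        g = clipped_surrogate_grad(theta, old_probs, batch, eps)
        theta = [t + lr * gk for t, gk in zip(theta, g)]
    return theta


batch = [(0, 1.0), (1, -1.0)]
theta_new = ppo_inner_loop([0.0, 0.0], batch)
ratio_0 = softmax(theta_new)[0] / 0.5
```

Only a handful of the 100 steps move $\theta$ at all. Once both ratios are outside the clipping interval, the remaining steps leave it unchanged.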

This is important because the surrogate loss is estimated from state-action pairs $(s, a)$ sampled by running the policy $\pi_{\theta_t}$, whereas the policy gradient is an on-policy quantity. As $\theta$ changes, these samples become more and more off-policy for $\pi_\theta$, which is why keeping $\theta$ proximal to $\theta_t$ is necessary.

If there is a reference policy $\pi_{\text{ref}}$ that the trained policy should not diverge too far from, then an additional KL divergence penalty can be added:
$$-\beta\, \mathbb{E}_{s, a \sim \pi_{\theta_t}}\left[\log \frac{\pi_\theta(a \mid s)}{\pi_{\text{ref}}(a \mid s)}\right]$$
where $\beta$ adjusts the strength of the penalty. This has been used in training language models with reinforcement learning from human feedback.[8] The KL divergence penalty term can be estimated with lower variance using the equivalent form (see f-divergence for details):[9]
$$-\beta\, \mathbb{E}_{s, a \sim \pi_{\theta_t}}\left[\log \frac{\pi_\theta(a \mid s)}{\pi_{\text{ref}}(a \mid s)} + \frac{\pi_{\text{ref}}(a \mid s)}{\pi_\theta(a \mid s)} - 1\right]$$
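The two forms of the penalty can be compared on a small categorical example. In the sketch below, the probabilities are made up, and the estimator names `k1` and `k3` follow an informal convention rather than anything in the text. The code evaluates both per-sample quantities and their exact expectations under $\pi_\theta$. Both expectations equal $D_{KL}(\pi_\theta \| \pi_{\text{ref}})$. The plain log-ratio is negative for some samples, whereas the second form is non-negative for every sample, because $x - 1 - \log x \ge 0$ for all $x > 0$.

```python
import math


def k1(p_theta, p_ref):
    """Per-sample log(pi_theta / pi_ref)."""
    return math.log(p_theta / p_ref)


def k3(p_theta, p_ref):
    """Per-sample log(pi_theta / pi_ref) + pi_ref / pi_theta - 1."""
    ratio = p_ref / p_theta
    return ratio - 1.0 - math.log(ratio)


pi_theta = [0.7, 0.2, 0.1]
pi_ref = [0.5, 0.3, 0.2]

exact_kl = sum(p * math.log(p / q) for p, q in zip(pi_theta, pi_ref))
mean_k1 = sum(p * k1(p, q) for p, q in zip(pi_theta, pi_ref))
mean_k3 = sum(p * k3(p, q) for p, q in zip(pi_theta, pi_ref))
per_sample_k1 = [k1(p, q) for p, q in zip(pi_theta, pi_ref)]
per_sample_k3 = [k3(p, q) for p, q in zip(pi_theta, pi_ref)]
```

The expectations agree because $\mathbb{E}_{\pi_\theta}[\pi_{\text{ref}}/\pi_\theta - 1] = \sum_a \pi_{\text{ref}}(a) - 1 = 0$.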

Group Relative Policy Optimization (GRPO)

Group Relative Policy Optimization (GRPO) is a minor variant of PPO that omits the value function estimator $V$. Instead, for each state $s$, it samples multiple actions $a_1, \ldots, a_G$ from the policy $\pi_{\theta_t}$, then calculates the group-relative advantage[9]
$$A^{\pi_{\theta_t}}(s, a_j) = \frac{r(s, a_j) - \mu}{\sigma}$$
where $\mu, \sigma$ are the mean and standard deviation of $r(s, a_1), \ldots, r(s, a_G)$. That is, it is the standard score of the rewards.
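The group-relative advantage takes only a few lines of Python. This sketch assumes the $G$ rewards for one state are given as a list and uses the population standard deviation. It also adds a small constant to $\sigma$, so that a group with identical rewards yields zero advantages instead of a division by zero. The last two details are assumptions of the sketch, not part of the definition above.

```python
import math


def group_relative_advantages(rewards, eps=1e-8):
    """Standard score of each reward within one group of G sampled actions."""
    g = len(rewards)
    mu = sum(rewards) / g
    sigma = math.sqrt(sum((r - mu) ** 2 for r in rewards) / g)  # population std (assumption)
    return [(r - mu) / (sigma + eps) for r in rewards]  # eps avoids 0 / 0 (assumption)


advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
```

With binary rewards such as these, correct answers get a positive advantage and incorrect ones a negative advantage of the same size.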

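Then, it maximizes the PPO objective, averaged over all actions:
$$\max_\theta \frac{1}{G} \sum_{i=1}^{G} \mathbb{E}_{(s, a_1, \ldots, a_G) \sim \pi_{\theta_t}}\left[\begin{cases} \min\left(\frac{\pi_\theta(a_i \mid s)}{\pi_{\theta_t}(a_i \mid s)}, 1 + \epsilon\right) A^{\pi_{\theta_t}}(s, a_i) & \text{if } A^{\pi_{\theta_t}}(s, a_i) > 0 \\ \max\left(\frac{\pi_\theta(a_i \mid s)}{\pi_{\theta_t}(a_i \mid s)}, 1 - \epsilon\right) A^{\pi_{\theta_t}}(s, a_i) & \text{if } A^{\pi_{\theta_t}}(s, a_i) < 0 \end{cases}\right]$$
Intuitively, each policy update step in GRPO makes the policy more likely to respond to each state with an action that performed relatively better than other actions tried at that state, and less likely to respond with one that performed relatively worse.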

As before, the KL penalty term can be applied to encourage the trained policy to stay close to a reference policy. GRPO was first proposed in the context of training reasoning language models by researchers at DeepSeek.[9]

Policy Optimization and the Mirror Descent perspective (MDPO)

Methods like TRPO, PPO and natural policy gradient share a common idea. The policy should be updated in the direction of the policy gradient, but the update should be done in a safe and stable manner, typically measured by some distance from the policy before the update.

A similar notion of update stability is found in proximal convex optimization techniques like Mirror Descent.[10] There, $\mathbf{x}$, the proposed minimizer of $f$ over some constraint set $\mathcal{C}$, is iteratively updated in the direction of the gradient $\nabla f$. The update carries a proximity penalty with respect to the current iterate $\mathbf{x}_t$, measured by some Bregman divergence $B_\omega$. This can be formalized as:
$$\mathbf{x}_{t+1} \in \arg\min_{\mathbf{x} \in \mathcal{C}} \nabla f(\mathbf{x}_t)^T (\mathbf{x} - \mathbf{x}_t) + \frac{1}{\eta_t} B_\omega(\mathbf{x}, \mathbf{x}_t)$$
where $\eta_t$ controls the proximity between consecutive iterates, similar to the learning rate in gradient descent.
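On the probability simplex, take the negative entropy $\omega(\mathbf{x}) = \sum_i x_i \log x_i$ as the mirror map. Its Bregman divergence $B_\omega$ is the KL divergence, and the update above then has the closed form $x_{t+1,i} \propto x_{t,i} \exp(-\eta_t \nabla_i f(\mathbf{x}_t))$, known as the exponentiated-gradient update. The sketch below applies it with a constant step size to an assumed test function $f(\mathbf{x}) = \tfrac{1}{2}\|\mathbf{x} - \mathbf{c}\|^2$ whose minimizer $\mathbf{c}$ lies inside the simplex.

```python
import math


def entropic_mirror_descent(grad_f, x0, eta, steps):
    """Mirror descent on the simplex with the negative-entropy mirror map (KL Bregman
    divergence): x_{t+1} is proportional to x_t * exp(-eta * grad_f(x_t))."""
    x = list(x0)
    for _ in range(steps):
        g = grad_f(x)
        w = [xi * math.exp(-eta * gi) for xi, gi in zip(x, g)]
        total = sum(w)
        x = [wi / total for wi in w]  # renormalize onto the simplex
    return x


c = [0.2, 0.3, 0.5]  # assumed minimizer, inside the simplex


def grad_f(x):
    """Gradient of f(x) = 0.5 * ||x - c||^2."""
    return [xi - ci for xi, ci in zip(x, c)]


x_final = entropic_mirror_descent(grad_f, [1 / 3, 1 / 3, 1 / 3], eta=0.5, steps=500)
```

Every iterate stays a valid probability vector without any projection step. This is the practical appeal of choosing the KL divergence as the proximity term on the simplex.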

This leads to reconsidering the policy update procedure as an optimization procedure aimed at finding an optimal policy in the (non-convex) optimization landscape of the underlying Markov decision process (MDP). This optimization viewpoint of using the policy gradient is termed Mirror Descent Policy Optimization (MDPO).[11][12] When the KL divergence is the chosen Bregman divergence, it leads to the following update:
$$\pi_{t+1} \in \arg\max_\pi \mathbb{E}_{s \sim \pi_t,\, a \sim \pi(\cdot \mid s)}\left[A^{\pi_t}(s, a)\right] - \frac{1}{\eta_t} D_{KL}(\pi \| \pi_t)$$
With a parameterized policy $\pi_\theta$, the MDPO loss becomes:
$$\max_\theta L(\theta, \theta_t) = \mathbb{E}_{s, a \sim \pi_{\theta_t}}\left[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_t}(a \mid s)} A^{\pi_{\theta_t}}(s, a)\right] - \frac{1}{\eta_t} D_{KL}(\pi_\theta \| \pi_{\theta_t})$$
This objective can be used together with other common techniques, like the clipping done in PPO. In fact, a KL divergence penalty also appears in the original PPO paper.[7] This suggests the MDPO perspective as a theoretical unification of the main derivation concepts behind many concurrent policy gradient techniques.
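For a tabular policy at a single state, the KL-regularized step has a closed-form solution, $\pi_{t+1}(a \mid s) \propto \pi_t(a \mid s) \exp(\eta_t A^{\pi_t}(s, a))$. This is the multiplicative form of entropic mirror descent, with the sign flipped because the advantage is maximized rather than minimized. The sketch below treats one state in isolation with made-up probabilities and advantages. It computes this update and checks it against randomly drawn alternative policies on the per-state objective $\sum_a \pi(a) A(a) - \frac{1}{\eta} D_{KL}(\pi \| \pi_t)$.

```python
import math
import random


def mdpo_update(pi_old, adv, eta):
    """Maximizer of sum_a pi(a) adv(a) - KL(pi || pi_old) / eta over the simplex."""
    w = [p * math.exp(eta * a) for p, a in zip(pi_old, adv)]
    total = sum(w)
    return [wi / total for wi in w]


def mdpo_objective(pi, pi_old, adv, eta):
    """Per-state MDPO objective: expected advantage minus KL(pi || pi_old) / eta."""
    kl = sum(p * math.log(p / q) for p, q in zip(pi, pi_old) if p > 0)
    return sum(p * a for p, a in zip(pi, adv)) - kl / eta


pi_old = [0.5, 0.3, 0.2]
adv = [0.4, -0.2, -0.7]
pi_new = mdpo_update(pi_old, adv, eta=1.0)
```

A larger $\eta_t$ weakens the KL penalty and moves more probability toward high-advantage actions, mirroring the role of a learning rate.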


References

  1. Sutton, Richard S.; McAllester, David; Singh, Satinder; Mansour, Yishay (1999). "Policy Gradient Methods for Reinforcement Learning with Function Approximation". Advances in Neural Information Processing Systems 12. MIT Press. https://proceedings.neurips.cc/paper_files/paper/1999/hash/464d828b85b0bed98e80ade0a5c43b0f-Abstract.html
  2. Mohamed, Shakir; Rosca, Mihaela; Figurnov, Michael; Mnih, Andriy (2020). "Monte Carlo Gradient Estimation in Machine Learning". Journal of Machine Learning Research 21 (132): 1–62. ISSN 1533-7928. https://www.jmlr.org/papers/v21/19-346.html
  3. Williams, Ronald J. (May 1992). "Simple statistical gradient-following algorithms for connectionist reinforcement learning". Machine Learning 8 (3–4): 229–256. doi:10.1007/BF00992696. ISSN 0885-6125. http://link.springer.com/10.1007/BF00992696
  4. Schulman, John; Moritz, Philipp; Levine, Sergey; Jordan, Michael; Abbeel, Pieter (2018-10-20). "High-Dimensional Continuous Control Using Generalized Advantage Estimation". arXiv:1506.02438 [cs.LG].
  5. Kakade, Sham M. (2001). "A Natural Policy Gradient". Advances in Neural Information Processing Systems 14. MIT Press. https://proceedings.neurips.cc/paper_files/paper/2001/hash/4b86abe48d358ecf194c56c69108433e-Abstract.html
  6. Schulman, John; Levine, Sergey; Moritz, Philipp; Jordan, Michael; Abbeel, Pieter (2015-07-06). "Trust region policy optimization". Proceedings of the 32nd International Conference on International Conference on Machine Learning 37: 1889–1897. Lille, France: JMLR.org. https://dl.acm.org/doi/10.5555/3045118.3045319
  7. Schulman, John; Wolski, Filip; Dhariwal, Prafulla; Radford, Alec; Klimov, Oleg (2017-08-28). "Proximal Policy Optimization Algorithms". arXiv:1707.06347 [cs.LG].
  8. Stiennon, Nisan; Ouyang, Long; Wu, Jeffrey; Ziegler, Daniel; Lowe, Ryan; Voss, Chelsea; Radford, Alec; Amodei, Dario et al. (2020). "Learning to summarize with human feedback". Advances in Neural Information Processing Systems 33. https://proceedings.neurips.cc/paper/2020/hash/1f89885d556929e98d3ef9b86448f951-Abstract.html
  9. Shao, Zhihong; Wang, Peiyi; Zhu, Qihao; Xu, Runxin; Song, Junxiao; Bi, Xiao; Zhang, Haowei; Zhang, Mingchuan; Li, Y. K. (2024-04-27). "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models". arXiv:2402.03300 [cs.CL].
  10. Nemirovsky, Arkadi; Yudin, David (1983). Problem Complexity and Method Efficiency in Optimization. John Wiley & Sons.
  11. Shani, Lior; Efroni, Yonathan; Mannor, Shie (2020-04-03). "Adaptive Trust Region Policy Optimization: Global Convergence and Faster Rates for Regularized MDPs". Proceedings of the AAAI Conference on Artificial Intelligence 34 (4): 5668–5675. doi:10.1609/aaai.v34i04.6021. ISSN 2374-3468. https://doi.org/10.1609/aaai.v34i04.6021
  12. Tomar, Manan; Shani, Lior; Efroni, Yonathan; Ghavamzadeh, Mohammad (2020-05-20). "Mirror Descent Policy Optimization". arXiv:2005.09814v5 [cs.LG].