Having covered the three basic concepts behind TRPO in Part 1 (the MM algorithm, trust regions, and importance sampling), we now explain TRPO in detail in this final part.
In this section, we define the objective we want to optimize. We start with the math, which is harder to follow, and then explain the intuition, which is pretty simple.
What do we optimize? Optimizing the policy π’ can be reformulated as:
We change the notation to match the TRPO and PPO papers: the reward function is now written J instead of η. π’ is the variable we want to optimize and π is the old policy. J(π), the expected reward of the old policy π, is a constant, so the change above does not alter the solution. In short, maximizing the reward of a policy is the same as maximizing its improvement over a fixed reference policy.
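Written out in the new notation, with J taken to be the expected discounted return (a standard definition, assumed here), the reformulation reads:

```latex
\max_{\pi'} J(\pi') \;\Longleftrightarrow\; \max_{\pi'} \big[\, J(\pi') - J(\pi) \,\big],
\qquad
J(\pi) = \mathbb{E}_{\tau \sim \pi}\Big[\sum_{t=0}^{\infty} \gamma^{t} r_t\Big]
```

Subtracting the constant J(π) shifts the objective but leaves the maximizer unchanged.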
Next, we need the lower bound function required by the MM algorithm. Its first term is 𝓛, defined as:
Don’t be scared by 𝓛. d is the discounted future state distribution; if γ=1, d is just the state visitation frequency under the policy π. A is the advantage function (a.k.a. calibrated expected rewards: the expected reward of an action measured relative to the average). We can simply view 𝓛 as using importance sampling to estimate the expected advantage of the new policy π’ from samples collected with the old policy π.
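To make the importance-sampling view concrete, here is a minimal sketch that estimates 𝓛 (up to its constant 1/(1−γ) factor) by averaging importance-weighted advantages over samples drawn with the old policy. The function name `surrogate_advantage` and the four-sample batch are hypothetical, chosen only for illustration.

```python
import numpy as np

def surrogate_advantage(logp_new, logp_old, advantages):
    """Monte Carlo estimate of L_pi(pi'), dropping the constant 1/(1-gamma).

    The samples (s, a) were collected with the old policy pi, so each advantage
    A^pi(s, a) is reweighted by the importance ratio pi'(a|s) / pi(a|s).
    """
    ratio = np.exp(np.asarray(logp_new) - np.asarray(logp_old))
    return float(np.mean(ratio * np.asarray(advantages)))

# Hypothetical batch of 4 samples collected with the old policy.
logp_old = np.log([0.5, 0.2, 0.3, 0.5])
advantages = np.array([1.0, -0.5, 0.2, -0.7])

# pi' = pi: every ratio is 1, so L is the mean advantage (0 for this batch).
print(surrogate_advantage(logp_old, logp_old, advantages))

# Raise the probability of the first (positive-advantage) action: L increases.
logp_new = logp_old.copy()
logp_new[0] = np.log(0.6)
print(surrogate_advantage(logp_new, logp_old, advantages))
```

Working in log probabilities and exponentiating the difference is a common way to form the ratio without dividing small probabilities directly.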
Appendix A of the TRPO paper provides a two-page proof that establishes the following bound:
where D_KL is the KL-divergence, which measures the difference between two data distributions p and q. The KL-divergence is defined as:
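For discrete distributions the definition is a short sum, D_KL(p‖q) = Σₓ p(x) log(p(x)/q(x)). The sketch below (the helper name `kl_divergence` is ours) also shows that the KL-divergence is zero for identical distributions and is not symmetric.

```python
import numpy as np

def kl_divergence(p, q):
    """D_KL(p || q) = sum_x p(x) * log(p(x) / q(x)) for discrete distributions.

    Terms with p(x) = 0 contribute 0; q(x) must be > 0 wherever p(x) > 0.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

p = [0.5, 0.5]
q = [0.9, 0.1]
print(kl_divergence(p, p))  # identical distributions -> 0.0
print(kl_divergence(p, q))  # about 0.511
print(kl_divergence(q, p))  # about 0.368: KL is not symmetric
```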
With some tweaking, this gives the final lower bound M for our objective function.
To summarize, we use the MM algorithm to maximize:
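For reference, one standard way to write the bound and the resulting MM update follows Theorem 1 of the TRPO paper. Presentations differ in constants and in whether the maximum or the average KL over states is used, so treat this as a representative form rather than the exact figure:

```latex
J(\pi') - J(\pi) \;\ge\; M(\pi') = \mathcal{L}_{\pi}(\pi') - C\, D_{KL}^{\max}(\pi \,\|\, \pi'),
\qquad
C = \frac{4\epsilon\gamma}{(1-\gamma)^2},\quad \epsilon = \max_{s,a}\lvert A^{\pi}(s,a)\rvert

\mathcal{L}_{\pi}(\pi') = \frac{1}{1-\gamma}\,
\mathbb{E}_{s \sim d^{\pi},\, a \sim \pi}\!\left[\frac{\pi'(a \mid s)}{\pi(a \mid s)}\, A^{\pi}(s,a)\right],
\qquad
D_{KL}^{\max}(\pi \,\|\, \pi') = \max_{s} D_{KL}\big(\pi(\cdot \mid s) \,\|\, \pi'(\cdot \mid s)\big)

\pi_{k+1} = \arg\max_{\pi'} \Big[\, \mathcal{L}_{\pi_k}(\pi') - C\, D_{KL}^{\max}(\pi_k \,\|\, \pi') \,\Big]
```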
The inequality below is important because it establishes an upper bound on the error of our objective estimate. This, in turn, defines a trust region: a neighborhood of the current policy within which we can trust the result.
In fact, by Lagrangian duality, our objective is mathematically equivalent to the following problem, which uses a trust region constraint.
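A common way to write the KL-constrained problem, using the average KL over visited states (as TRPO does in practice) in place of the maximum, is:

```latex
\max_{\pi'} \; \mathcal{L}_{\pi}(\pi')
\quad \text{subject to} \quad
\bar{D}_{KL}(\pi \,\|\, \pi') = \mathbb{E}_{s \sim d^{\pi}}\big[D_{KL}\big(\pi(\cdot \mid s) \,\|\, \pi'(\cdot \mid s)\big)\big] \le \delta
```

Roughly speaking, for each trust region size δ there is a penalty weight C that plays the role of the Lagrange multiplier of this constraint, which is the sense in which the penalized and constrained forms coincide.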
Recap: Objective function
We have now established the objective function we optimize in each MM iteration. The first form is called the KL-penalized objective and the second the KL-constrained objective.
Guaranteed monotonic improvement
The power of TRPO, PPO, and the natural policy gradient builds on the concept of guaranteed monotonic improvement: in theory, the policy update in each TRPO iteration produces a policy that is at least as good as the previous one. Let’s walk through a simple proof. From the lower bound M, we can prove
The R.H.S. term below equals zero when π’ = π. Because the new policy maximizes the R.H.S., its value there is at least this zero, so the L.H.S. is always greater than or equal to 0,
i.e., the new policy is never worse than the old one. In fact, the real objective improves by at least as much as the lower bound approximation does.
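The argument can be written compactly as follows, assuming the lower bound J(π’) − J(π) ≥ M(π’) from above. 𝓛 vanishes at π’ = π because every importance ratio is 1 and a policy's advantage averages to zero under its own actions (E_{a∼π}[A^π(s,a)] = 0), and the KL-divergence of a policy with itself is zero:

```latex
M_k(\pi') = \mathcal{L}_{\pi_k}(\pi') - C\, D_{KL}^{\max}(\pi_k \,\|\, \pi')

M_k(\pi_k) = \underbrace{\mathcal{L}_{\pi_k}(\pi_k)}_{=\,0} - C \cdot \underbrace{D_{KL}^{\max}(\pi_k \,\|\, \pi_k)}_{=\,0} = 0

\pi_{k+1} = \arg\max_{\pi'} M_k(\pi')
\;\Longrightarrow\;
J(\pi_{k+1}) - J(\pi_k) \;\ge\; M_k(\pi_{k+1}) \;\ge\; M_k(\pi_k) = 0
```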
Can we interpret the equation intuitively? We can approximate the expected advantage locally around the current policy, but the accuracy decreases as the new policy and the current policy diverge from each other. However, we can establish an upper bound for the error. Therefore, we can guarantee a policy improvement as long as we optimize the local approximation within a trusted region. Outside this region, all bets are off: even if a policy there has a better calculated value, its error range is too large to guarantee an improvement. With this guarantee inside the trust region, we can iteratively improve the policy toward the optimum. So even though the mathematical proof takes a while, the reasoning is pretty simple.
KL penalized vs. KL constrained
Mathematically, the KL-penalized and KL-constrained objectives are equivalent if we have unlimited computational resources. In practice, however, they behave differently.
C becomes very large when γ (the discount factor) is close to one, and the corresponding gradient step becomes too small. One solution is to turn both C and δ into tunable hyperparameters.
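To see how quickly C grows, the sketch below evaluates the constant C = 4εγ/(1−γ)² from the TRPO bound for a few discount factors. Setting ε (the largest absolute advantage) to 1 is an arbitrary assumption for illustration. With a quadratic model gᵀd − (C/2)dᵀHd of the penalized objective, the best step is H⁻¹g/C, so the step length shrinks like 1/C.

```python
def penalty_coefficient(gamma, eps=1.0):
    """C = 4 * eps * gamma / (1 - gamma)**2 from the TRPO lower bound.

    eps stands for max_{s,a} |A(s, a)|; eps = 1.0 is an arbitrary choice here.
    """
    return 4.0 * eps * gamma / (1.0 - gamma) ** 2

for gamma in (0.9, 0.99, 0.999):
    C = penalty_coefficient(gamma)
    # Maximizing g^T d - (C / 2) d^T H d gives d = H^{-1} g / C,
    # so the step length scales as 1 / C.
    print(f"gamma={gamma}: C={C:.4g}, relative step size={1.0 / C:.3g}")
```

Going from γ = 0.9 to γ = 0.999 raises C from roughly 360 to roughly four million, which is why a fixed, theory-derived C produces steps that are far too small.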
In practice, δ is much easier to tune than C. δ imposes a hard constraint that limits the worst-case scenarios in the policy space: it restrains policy changes that could turn destructive. Tuning C is much harder; empirical results show that it cannot be a fixed value and needs to be adaptive. Therefore, the trust region constraint is more popular. As a footnote, this can be fixed if certain enhancements are made to KL penalty methods (as in PPO).
We have defined the optimization problems we want to solve. However, solving them is not easy. Next, we look at how the natural policy gradient method solves the problem analytically and how TRPO addresses its weaknesses.
Natural Policy Gradient
Natural Policy Gradient solves the following objective function analytically.
To solve it, we use a Taylor series to expand both terms above up to second order. The second-order term of 𝓛 is much smaller than that of the KL-divergence and is ignored.
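Dropping the terms that vanish at the current parameters θ_k (𝓛, the KL-divergence, and the gradient of the KL-divergence are all zero there), it becomes: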
where g is the policy gradient (the same one we learned in Policy Gradient) and H measures the sensitivity (curvature) of the policy with respect to the model parameters θ. Our objective becomes:
This is a quadratic optimization problem and can be solved analytically:
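Assuming the KL-constrained form with the average KL, the expansions and the standard closed-form solution are shown below. The optimal direction is H⁻¹g, scaled so that the constraint is exactly tight:

```latex
\mathcal{L}_{\theta_k}(\theta) \approx g^{\top}(\theta - \theta_k),
\qquad
\bar{D}_{KL}(\theta_k \,\|\, \theta) \approx \tfrac{1}{2}(\theta - \theta_k)^{\top} H\, (\theta - \theta_k)

\max_{\Delta\theta} \; g^{\top}\Delta\theta
\quad \text{s.t.} \quad \tfrac{1}{2}\Delta\theta^{\top} H\, \Delta\theta \le \delta
\qquad\Longrightarrow\qquad
\Delta\theta = \sqrt{\frac{2\delta}{g^{\top} H^{-1} g}}\; H^{-1} g
```

Substituting Δθ = t·H⁻¹g into the constraint gives ½t²·gᵀH⁻¹g = δ, which is where the square-root step size comes from.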
The term below maps the changes we want in the policy space into the corresponding parameter space.
It provides a solution that works regardless of how we parameterize the policy model (proof). H takes different values under different parameterizations, but it generates parameter changes (Δθ) that map to the same policy change. Therefore, our solution is model invariant. Gradient descent (or ascent), by contrast, treats the parameter space as locally flat, which corresponds to the familiar Euclidean distance (the square root of the sum of squared differences). If we replace H above with the identity matrix, i.e., measure the step with the Euclidean distance in parameter space, the parameter updates are no longer model invariant: the calculated parameter changes for different models correspond to different policy changes. This is the theoretical reason why the natural policy gradient is better than plain gradient ascent.
The natural policy gradient is a second-order optimization method that is more accurate and works regardless of how the model is parameterized (model invariant).
H is a Hessian matrix, whose generic form is:
Here, H is the Hessian of the KL-divergence between the old and the new policy, evaluated at the current parameters. It equals the expected negative curvature (second-order derivative) of the log probability of the policy, and it has a special name: the Fisher Information Matrix (FIM). Much of the literature uses F instead of H to denote the FIM. In this article, both H and F refer to the FIM.
F can also be computed with the expression below:
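As a sanity check that the two expressions agree, the sketch below computes the FIM both ways for the simplest case we can write in closed form: a state-independent categorical policy π = softmax(θ). That policy and the helper names are assumptions for illustration. For this policy ∇log π(a) = one_hot(a) − p, and both forms reduce to diag(p) − ppᵀ.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def fisher_outer_product(theta):
    """F = E_{a~pi}[grad log pi(a) grad log pi(a)^T] for pi = softmax(theta)."""
    p = softmax(theta)
    scores = np.eye(len(p)) - p        # row a holds grad log pi(a) = one_hot(a) - p
    return (scores.T * p) @ scores     # sum_a p_a * s_a s_a^T

def fisher_neg_hessian(theta):
    """F = -E_{a~pi}[Hessian of log pi(a)]; for softmax that Hessian is
    -(diag(p) - p p^T) for every action a, so the expectation is trivial."""
    p = softmax(theta)
    return np.diag(p) - np.outer(p, p)

theta = np.array([0.3, -1.0, 2.0])
F1 = fisher_outer_product(theta)
F2 = fisher_neg_hessian(theta)
print(np.allclose(F1, F2))  # True: both give diag(p) - p p^T
```

Note that the rows of this F sum to zero, so it is singular: adding the same constant to every logit does not change the policy. Practical implementations therefore usually add a small damping term before solving with F.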
The following is the pseudocode. We sample trajectories from the current policy and use them to compute the advantage functions. Then we compute the policy gradient and use the equation above to compute the FIM.
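Below is a minimal runnable sketch of this loop under strongly simplifying assumptions: a hypothetical 3-armed bandit (a single state, so each "trajectory" is one action), a softmax policy over arms, reward minus a mean baseline as the advantage estimate, and a small damping term added to the FIM so it can be solved against. The names `ARM_REWARDS` and `npg_step` are illustrative only.

```python
import numpy as np

# Hypothetical 3-armed bandit (a single state): expected reward of each arm.
ARM_REWARDS = np.array([1.0, 0.0, 0.5])

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def npg_step(theta, rng, n_samples=500, delta=0.01, damping=1e-3):
    """One natural policy gradient update for a softmax policy over arms."""
    p = softmax(theta)
    # 1. Sample actions with the current policy and observe noisy rewards.
    actions = rng.choice(len(p), size=n_samples, p=p)
    rewards = ARM_REWARDS[actions] + rng.normal(0.0, 0.1, size=n_samples)
    # 2. Advantage estimate: reward minus a mean baseline.
    adv = rewards - rewards.mean()
    # 3. Policy gradient g = mean of grad log pi(a) * A, where grad log pi(a) = one_hot(a) - p.
    scores = np.eye(len(p))[actions] - p
    g = scores.T @ adv / n_samples
    # 4. Empirical FIM from outer products of the scores, damped to be invertible.
    F = scores.T @ scores / n_samples + damping * np.eye(len(p))
    # 5. Natural direction F^{-1} g via a linear solve; scale the step so the
    #    quadratic KL approximation 0.5 * d^T F d equals delta.
    x = np.linalg.solve(F, g)
    step_size = np.sqrt(2.0 * delta / max(g @ x, 1e-12))
    return theta + step_size * x

rng = np.random.default_rng(0)
theta = np.zeros(3)
for _ in range(50):
    theta = npg_step(theta, rng)
print(np.round(softmax(theta), 3))  # most probability mass should move to arm 0
```

`np.linalg.solve` avoids forming F⁻¹ explicitly, but it still works with the full P × P matrix, which does not scale to deep networks.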
The caveat of natural policy gradient
Finding the inverse of H is expensive when the policy has many parameters, in particular for deep networks: inverting a P × P matrix takes O(P³) time and storing it takes O(P²) memory. Also, the inverse is often numerically unstable (errors are easily magnified by imprecision in the data).
Truncated Natural Policy Gradient
Instead of finding the inverse of the FIM, we want to compute the following combined term directly:
where x can be found by solving:
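We convert this linear equation into the optimization of a quadratic function. Because F is symmetric and (assumed) positive definite, the solution of Fx = g is exactly the minimizer of ½xᵀFx − gᵀx, whose gradient Fx − g vanishes there: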
In short, we can transform our problem into optimizing the quadratic function below:
To optimize it, we apply the conjugate gradient method. The concept is very similar to gradient ascent, but it converges in far fewer iterations.
In gradient ascent, we always follow the steepest gradient. In our example, we want to ascend from the yellow spot to the red one. Let’s say our first move goes to the right according to the gradient contour. To follow the steepest gradient again, the second move may go up and slightly to the left. This raises a question: why does the second move go left? It seems to undo some of the progress made by the first move, even though every move gets closer to the red spot.
If the objective function is quadratic, we can avoid this inefficiency using the conjugate gradient method. If the model has N parameters, we can find the optimal point in at most N ascent steps. In the first move, we follow the steepest gradient direction d and settle at the optimal point along that direction (the peak of the objective function in the direction of d). Each subsequent search direction dj must then be A-orthogonal (conjugate) to all previous directions di, where A is the matrix of the quadratic function (here, F). Mathematically, this means
We can think of these vectors as being perpendicular to each other after a transformation by A.
Conjugate gradient is about finding, at every step, a search direction that is A-orthogonal to all previous directions. This ensures that we never undo progress made previously. Then we move along the new direction to the optimal point on that search line.
These vectors (directions) are linearly independent and span the N-dimensional space, so we can solve the problem in at most N steps. The explanation sounds complicated, but the algorithm is much cheaper than computing the inverse of F: each step needs only one product of F with a vector. Now, we can compute x = F⁻¹g
in at most N iterations. Truncated Natural Policy Gradient (TNPG) uses the conjugate gradient method below (the underlined red line) in place of computing the inverse of the FIM.
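Here is a minimal sketch of the conjugate gradient method for solving Fx = g. It touches F only through a caller-supplied matrix-vector product `Avp(v)`, never forming or inverting the matrix. The function and parameter names are our own, and the 5 × 5 matrix in the usage example is a hypothetical stand-in for a FIM.

```python
import numpy as np

def conjugate_gradient(Avp, b, n_iters=10, tol=1e-12):
    """Solve A x = b for symmetric positive-definite A using only Avp(v) = A @ v."""
    x = np.zeros_like(b, dtype=float)
    r = np.array(b, dtype=float)   # residual b - A x (x starts at 0)
    d = r.copy()                   # first direction: steepest direction of the quadratic
    rr = r @ r
    for _ in range(n_iters):
        if rr < tol:               # residual is already negligible
            break
        Ad = Avp(d)
        alpha = rr / (d @ Ad)      # exact line search: optimum along direction d
        x += alpha * d
        r -= alpha * Ad
        rr_new = r @ r
        # New direction: the residual, corrected to be A-orthogonal (conjugate)
        # to the previous search directions.
        d = r + (rr_new / rr) * d
        rr = rr_new
    return x

rng = np.random.default_rng(0)
M = rng.normal(size=(5, 5))
F = M @ M.T + 5.0 * np.eye(5)     # hypothetical symmetric positive-definite "FIM"
g = rng.normal(size=5)
x = conjugate_gradient(lambda v: F @ v, g, n_iters=5)   # at most N = 5 steps
print(np.allclose(F @ x, g))      # True
```

Capping `n_iters` well below N gives an approximate solution; stopping the iteration early in this way is what the "truncated" in TNPG refers to.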
Trust region policy optimization (TRPO)
Finally, we put everything together for TRPO. TRPO applies the conjugate gradient method to the natural policy gradient. But that is not enough. If something seems too good to be true, it may not be. The trust region for the natural policy gradient is very small, so we relax it to a larger, tunable value. In addition, the quadratic approximation we made also reduces accuracy. These factors may reintroduce problems into the policy updates, and in some training iterations performance may degrade. One mitigation is to verify the new policy before committing the change. To do that, we check that
- the KL-divergence between the old policy and the new policy θ is ≤ δ, and
- 𝓛(θ) ≥ 0.
If the verification fails, we shrink the natural policy gradient step by a factor of α (0 < α < 1), repeatedly, until the new parameters meet the requirements above. The following describes this backtracking line search:
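A minimal sketch of the backtracking line search: try the full step, then steps scaled by α, α², and so on, and accept the first one that passes both checks. The callables `surrogate` and `kl` are hypothetical placeholders for however a real implementation estimates 𝓛(θ) and the KL-divergence from a batch; the one-dimensional toy problem below is invented purely for illustration.

```python
import numpy as np

def backtracking_line_search(theta, full_step, surrogate, kl, delta,
                             alpha=0.5, max_backtracks=10):
    """Return (new_theta, accepted). Shrinks the step by alpha until
    kl(new_theta) <= delta and surrogate(new_theta) >= 0; if no scaled step
    passes, the update is rejected and the old parameters are kept."""
    for i in range(max_backtracks):
        theta_new = theta + (alpha ** i) * full_step
        if kl(theta_new) <= delta and surrogate(theta_new) >= 0:
            return theta_new, True
    return theta, False

# Hypothetical 1-D toy: surrogate t * (2 - t) peaks at t = 1, and KL = t^2 / 2.
theta0 = np.array([0.0])
full_step = np.array([4.0])            # deliberately too large
surrogate = lambda th: float(th[0] * (2.0 - th[0]))
kl = lambda th: 0.5 * float(th[0] ** 2)
theta_new, accepted = backtracking_line_search(theta0, full_step, surrogate, kl, delta=0.1)
print(theta_new, accepted)             # step 4 -> 2 -> 1 -> 0.5 -> 0.25 (accepted)
```

The steps 4, 2, 1, and 0.5 all violate KL ≤ 0.1; the step 0.25 gives KL ≈ 0.031 and a positive surrogate, so it is accepted.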
TRPO combines the Truncated Natural Policy Gradient (using conjugate gradient) with the backtracking line search:
TRPO minimizes a quadratic function to compute F⁻¹g approximately, without forming the inverse of F.
But for each model parameter update, we still need F (or at least its products with vectors) for the current policy. This hurts scalability:
- Computing F every time for the current policy model is expensive, and
- It requires a large batch of rollouts to approximate F accurately.
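Conjugate gradient only needs products F·v, and those can be computed without ever building F. The sketch below assumes the outer-product form of the FIM, F = mean of ∇log π ∇log πᵀ over a batch of per-sample score vectors; TRPO implementations often obtain the same product instead by differentiating the KL-divergence twice with automatic differentiation. The function name and the random scores are hypothetical.

```python
import numpy as np

def fisher_vector_product(scores, v, damping=0.0):
    """Compute (F + damping * I) @ v with F = mean_i scores[i] scores[i]^T,
    without materializing the (P x P) matrix F.

    scores: array of shape (N, P), one grad log pi vector per sample.
    Cost is O(N * P) instead of O(N * P^2) for building F explicitly.
    """
    projections = scores @ v                       # shape (N,): s_i^T v
    return scores.T @ projections / scores.shape[0] + damping * v

rng = np.random.default_rng(1)
scores = rng.normal(size=(1000, 6))                # hypothetical score vectors
v = rng.normal(size=6)
F = scores.T @ scores / scores.shape[0]            # explicit F, only for comparison
print(np.allclose(fisher_vector_product(scores, v), F @ v))   # True
```

The product is cheap, but it must be recomputed for every conjugate gradient step and every policy update, and it still depends on a large batch of rollouts to be accurate, which is the scalability problem described above.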
In general, TRPO is less sample efficient than policy gradient methods trained with first-order optimizers like Adam. Because of this scalability issue, TRPO is not practical for large deep networks. PPO and ACKTR were introduced to address these problems.
It has been a long journey, but congratulations, you made it! I hope you now have a far deeper understanding of TRPO.