Conservative Q-Learning for Offline Reinforcement Learning

We develop a conservative Q-learning (CQL) algorithm such that the expected value of a policy under the learned Q-function lower-bounds its true value. A lower bound on the Q-values prevents the over-estimation that is common in offline RL settings due to OOD actions and function approximation error. We start by focusing on the policy evaluation step of CQL, which can be used by itself as an off-policy evaluation procedure or integrated into a complete offline RL algorithm.

Conservative Off-Policy Evaluation

Concepts

  • $V^\pi(s)$: the value of a target policy $\pi$ at state $s$
  • $D$: an offline dataset generated by the behavior policy $\pi_\beta(a\vert s)$

First, we add a penalty that minimizes the expected Q-value under a particular distribution of state-action pairs $\mu(s,a)$. Since off-policy Q-function training queries the Q-function at unseen actions but only at states observed in the dataset, we restrict $\mu(s,a)$ to match the state marginal of the offline dataset $D$, i.e. $\mu(s,a)=d^{\pi_\beta}(s)\mu(a\vert s)$. This gives rise to the following iterative update for training the Q-function, as a function of a trade-off factor $\alpha$: \(\begin{aligned} \hat{Q}^{k+1} \leftarrow \argmin_{Q} \alpha \mathbb{E}_{s\sim D,a\sim \mu(a\vert s)}\left[Q(s,a)\right]+\frac{1}{2}\mathbb{E}_{s,a\sim D} \left[(Q(s,a)-\hat{\mathcal{B}}^{\pi}\hat{Q}^k(s,a))^2\right] \end{aligned}\)
In the iteration above, the resulting Q-function $\hat{Q}^\pi:=\lim_{k\rightarrow \infty}\hat{Q}^k$ lower-bounds $Q^{\pi}$ at all $(s,a)$. However, we can substantially tighten this bound if we are only interested in estimating $V^\pi(s)$. If we only require that the expected value of $\hat{Q}^{\pi}$ under $\pi(a\vert s)$ lower-bound $V^\pi(s)$, we can improve the bound by introducing an additional Q-value maximization term under the data distribution $\hat{\pi}_\beta(a\vert s)$, resulting in the iterative update below (the new term is the second expectation): \(\begin{aligned} \hat{Q}^{k+1} \leftarrow \argmin_{Q} \alpha \left(\mathbb{E}_{s\sim D,a\sim \mu(a\vert s)}\left[Q(s,a)\right]-\mathbb{E}_{s\sim D, a\sim \hat{\pi}_{\beta}(a\vert s)}\left[Q(s,a)\right]\right)+\frac{1}{2}\mathbb{E}_{s,a\sim D} \left[(Q(s,a)-\hat{\mathcal{B}}^{\pi}\hat{Q}^k(s,a))^2\right] \end{aligned}\)
When $\mu(a\vert s)$ equals $\pi(a\vert s)$, the expected Q-value at each state is guaranteed to lower-bound the true value function: \(\mathbb{E}_{\pi(a\vert s)}[\hat{Q}^\pi (s,a)] \leq V^\pi (s)\). A point-wise lower bound is no longer guaranteed, since Q-values at actions that are likely under \(\hat{\pi}_{\beta}\) may be overestimated. However, the authors also prove that distributions other than \(\hat{\pi}_{\beta}\) in the maximization term likewise fail to guarantee a point-wise lower bound.
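To make the update above concrete, here is a minimal PyTorch-style sketch of one gradient step of this conservative evaluation objective for discrete actions. The names (`q_net`, `target_q_net`, `mu`, `alpha`) are illustrative assumptions rather than the authors' implementation; dropping the `q_under_data` term recovers the first (point-wise) variant of the update.

```python
import torch
import torch.nn.functional as F

def conservative_evaluation_step(q_net, target_q_net, batch, pi, mu,
                                 alpha, gamma, optimizer):
    """One gradient step of the conservative policy-evaluation update (discrete actions).

    batch: tensors 's', 'a', 'r', 's2' sampled from the offline dataset D.
    pi(s) -> target-policy action probabilities, shape [B, num_actions].
    mu(s) -> penalty-distribution probabilities mu(a|s), shape [B, num_actions].
    """
    s, a, r, s2 = batch["s"], batch["a"], batch["r"], batch["s2"]
    q_all = q_net(s)                                    # Q(s, .) for all actions
    q_sa = q_all.gather(1, a.unsqueeze(1)).squeeze(1)   # Q(s, a) at dataset actions

    # Empirical Bellman backup: r + gamma * E_{a' ~ pi}[Q_target(s', a')]
    with torch.no_grad():
        backup = r + gamma * (pi(s2) * target_q_net(s2)).sum(dim=1)
    bellman_error = 0.5 * F.mse_loss(q_sa, backup)

    # Push Q down under mu(a|s) and up at dataset actions (the pi_beta-hat term),
    # which is what tightens the bound to hold in expectation under pi.
    q_under_mu = (mu(s) * q_all).sum(dim=1).mean()
    q_under_data = q_sa.mean()
    loss = alpha * (q_under_mu - q_under_data) + bellman_error

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```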

Theoretical analysis
First, we use concentration properties of the empirical Bellman operator \(\hat{\mathcal{B}}^{\pi}\) to control the sampling error it introduces. Formally, for all $s, a \in D$, with probability \(\geq 1-\delta\), \(\vert \hat{\mathcal{B}}^{\pi}-\mathcal{B}^{\pi} \vert (s, a) \leq \frac{C_{r,T,\delta}}{\sqrt{\vert D(s,a)\vert }}\), where $C_{r,T,\delta}$ is a constant dependent on the concentration properties (variance) of $r(s, a)$ and $T(s^{'}\vert s, a)$ and on $\delta \in (0, 1)$, and $\frac{1}{\sqrt{\vert D\vert}}$ denotes a vector of inverse square-root counts for each state-action pair (with a suitable correction when the count is zero). Now, we show that the conservative Q-function learned by iterating the equation above lower-bounds the true Q-function.

Theorem 1.1
For any \(\mu(a\vert s)\) with \(\texttt{supp}\, \mu \subset \texttt{supp}\, \hat{\pi}_\beta\), with probability $\geq 1- \delta$, $\hat{Q}^\pi$ (the Q-function obtained by iterating the update above) satisfies: \(\forall s \in D, a, \quad \hat{Q}^\pi(s,a) \leq Q^\pi(s,a)-\alpha \left[(I-\gamma P^\pi)^{-1}\frac{\mu}{\hat{\pi}_\beta}\right](s,a) + \left[(I-\gamma P^\pi)^{-1} \frac{C_{r,T,\delta}R_{max}}{(1-\gamma)\sqrt{\vert D\vert }}\right](s,a)\)

Thus, if $\alpha$ is sufficiently large, then \(\hat{Q}^\pi (s,a) \leq Q^\pi (s,a)\) for all $s\in D$ and all $a$. When \(\hat{\mathcal{B}}^\pi=\mathcal{B}^\pi\) (i.e. no sampling error), any $\alpha > 0$ guarantees \(\hat{Q}^\pi(s,a) \leq Q^\pi(s,a), \forall s \in D, a\in A\).

Next, we show that the update above lower-bounds the expected value under the policy $\pi$ when $\mu=\pi$. Here \(\frac{1}{\sqrt{\vert D\vert }}\) denotes a vector of inverse square roots of state-only counts, with the same correction for zero counts as for state-action pairs above. The following theorem also clarifies why this is only an expected, not a point-wise, lower bound.

Theorem 1.2
(The Q-value update above results in a tighter lower bound.) The value of the policy under the Q-function obtained from the update above, \(\hat{V}^\pi(s)=\mathbb{E}_{\pi(a\vert s)}[\hat{Q}^\pi (s,a)]\), lower-bounds the true value of the policy obtained via exact policy evaluation, \(V^\pi(s)=\mathbb{E}_{\pi(a\vert s)}[Q^\pi (s,a)]\), when \(\mu = \pi\), according to: \(\forall s\in D, \quad \hat{V}^\pi(s)\leq V^\pi(s)-\alpha\left[(I-\gamma P^\pi)^{-1}\mathbb{E}_\pi\left[\frac{\pi}{\hat{\pi}_\beta}-1\right]\right](s)+\left[(I-\gamma P^\pi)^{-1}\frac{C_{r,T,\delta}R_{max}}{(1-\gamma)\sqrt{\vert D\vert }}\right](s)\) Thus, if \(\alpha > \frac{C_{r,T,\delta}R_{max}}{(1-\gamma)}\cdot \max_{s\in D}\frac{1}{\sqrt{\vert D(s)\vert }}\cdot \left[\sum_a \pi(a\vert s)\left(\frac{\pi(a\vert s)}{\hat{\pi}_\beta(a\vert s)}-1\right)\right]^{-1}\), then \(\forall s\in D, \hat{V}^{\pi}(s)\leq V^{\pi}(s)\) with probability \(\geq 1-\delta\). When \(\hat{\mathcal{B}}^{\pi}=\mathcal{B}^{\pi}\), then any $\alpha>0$ guarantees \(\hat{V}^{\pi}(s)\leq V^{\pi}(s), \forall s\in D\).
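As a toy illustration of this threshold (numbers chosen arbitrarily, not taken from the paper), consider a state with two actions where $\pi(\cdot\vert s) = (0.8, 0.2)$ and $\hat{\pi}_\beta(\cdot\vert s) = (0.5, 0.5)$: \(\begin{aligned} \sum_a \pi(a\vert s)\left(\frac{\pi(a\vert s)}{\hat{\pi}_\beta(a\vert s)}-1\right) = 0.8\left(\frac{0.8}{0.5}-1\right)+0.2\left(\frac{0.2}{0.5}-1\right) = 0.48-0.12 = 0.36, \end{aligned}\) so the theorem asks for \(\alpha > \frac{C_{r,T,\delta}R_{max}}{(1-\gamma)\sqrt{\vert D(s)\vert}}\cdot\frac{1}{0.36}\). The further $\pi$ deviates from $\hat{\pi}_\beta$, the larger this divergence term and the smaller the $\alpha$ required; conversely, with fewer samples at $s$ a larger $\alpha$ is needed.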

Note that the analysis above also extends to the function-approximation setting, where the Q-function is represented by a linear function approximator or a non-linear neural network.

Conservative Q-Learning for Offline RL

We now have a Q-function $\hat{Q}^\pi$ that lower-bounds the value of a policy $\pi$, obtained by solving the evaluation update above with $\mu=\pi$. How should we utilize this for policy optimization? We could alternate between performing full off-policy evaluation for each policy iterate, $\hat{\pi}^k$, and one step of policy improvement. However, this can be computationally expensive. Alternatively, since the policy $\hat{\pi}^k$ is typically derived from the Q-function, we could instead choose $\mu(a\vert s)$ to approximate the policy that would maximize the current Q-function iterate, giving rise to an online algorithm defined by the following min-max optimization problem: \(\begin{aligned} \min_{Q}\max_{\mu}\ &\alpha\left(\mathbb{E}_{s\sim D,a\sim \mu(a\vert s)}[Q(s,a)]-\mathbb{E}_{s\sim D,a\sim \hat{\pi}_\beta(a\vert s)}[Q(s,a)]\right)\\ &+\frac{1}{2}\mathbb{E}_{s,a,s^{'}\sim D}\left[\left(Q(s,a)-\hat{\mathcal{B}}^{\pi_k}\hat{Q}^k(s,a)\right)^2\right]+\mathcal{R}(\mu) \qquad (CQL(\mathcal{R})). \end{aligned}\)

To demonstrate the generality of the CQL family of optimization problems, we discuss two specific instances within this family that are of special interest. If we choose $\mathcal{R}(\mu)$ to be the KL divergence against a prior distribution $\rho(a\vert s)$, i.e. $\mathcal{R}(\mu)=-D_{KL}(\mu\Vert\rho)$, the inner maximization yields $\mu(a\vert s)\propto \rho(a\vert s)\cdot \exp(Q(s,a))$. First, if $\rho = \text{Unif}(a)$, the first term above corresponds to a soft-maximum of the Q-values at any state $s$, giving rise to the following variant, called $CQL(\mathcal{H})$: \(\begin{aligned} \min_{Q}\ \alpha\,\mathbb{E}_{s\sim D}\left[\log\sum_{a}\exp(Q(s,a))-\mathbb{E}_{a\sim \hat{\pi}_\beta(a\vert s)}[Q(s,a)]\right]+\frac{1}{2}\mathbb{E}_{s,a,s^{'}\sim D}\left[(Q-\hat{\mathcal{B}}^{\pi_k}\hat{Q}^k)^2\right] \end{aligned}\) Second, if $\rho(a\vert s)$ is chosen to be the previous policy iterate $\hat{\pi}^{k-1}$ (the variant called $CQL(\rho)$), the first term is replaced by an exponentially weighted average of Q-values of actions sampled from $\hat{\pi}^{k-1}(a\vert s)$. Empirically, this variant can be more stable in high-dimensional action spaces, where estimating $\log\sum_a \exp$ via sampling suffers from high variance. In the appendix of the paper, an additional variant of CQL is discussed that draws connections to distributionally robust optimization.
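Mirroring the evaluation sketch above, here is a minimal sketch of the $CQL(\mathcal{H})$ loss for discrete actions, assuming a Q-network that outputs one value per action; names and defaults are illustrative and not taken from the authors' code.

```python
import torch
import torch.nn.functional as F

def cql_h_loss(q_net, target_q_net, policy, batch, alpha=1.0, gamma=0.99):
    """CQL(H) for discrete actions:
    alpha * E_s[ logsumexp_a Q(s, a) - E_{a~D}[Q(s, a)] ] + 1/2 * Bellman error.
    """
    s, a, r, s2 = batch["s"], batch["a"], batch["r"], batch["s2"]
    q_all = q_net(s)                                     # [B, num_actions]
    q_data = q_all.gather(1, a.unsqueeze(1)).squeeze(1)  # Q at dataset actions

    # Standard backup toward the target network under the current policy iterate.
    with torch.no_grad():
        backup = r + gamma * (policy(s2) * target_q_net(s2)).sum(dim=1)
    bellman = 0.5 * F.mse_loss(q_data, backup)

    # Conservative term: soft-maximum over all actions (the rho = Unif(a) case)
    # minus the Q-value at actions actually seen in the data.
    conservative = (torch.logsumexp(q_all, dim=1) - q_data).mean()

    return alpha * conservative + bellman
```

With continuous actions, the $\log\sum_a \exp$ term cannot be computed exactly and must be approximated with sampled actions, which is the high-variance estimation issue mentioned above and the reason the $CQL(\rho)$ variant can be more stable there.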

Theoretical analysis of CQL.
Next, we theoretically analyze CQL to show that the policy updates derived above are indeed “conservative”, in the sense that each successive policy iterate is optimized against a lower bound on its value. For clarity, we state the results in the absence of finite-sample error, but sampling error can be incorporated in the same way as in Theorems 1.1 and 1.2, as discussed in the appendix of the paper. The theorem below shows that any variant of the CQL family learns Q-value estimates that lower-bound the actual Q-function under the action distribution defined by the policy $\pi^k$, under mild regularity conditions (slow updates on the policy).

Theorem 2.1

(CQL learns lower-bounded Q-values.) Let \(\pi_{\hat{Q}^k}(a\vert s)\propto \exp(\hat{Q}^k(s,a))\) and assume that \(D_{TV}(\hat{\pi}^{k+1},\pi_{\hat{Q}^k})\leq \epsilon\) (i.e. $\hat{\pi}^{k+1}$ changes slowly w.r.t. $\hat{Q}^k$). Then, the value of $\hat{\pi}^{k+1}$ under the learned Q-function lower-bounds its actual value, \(\hat{V}^{k+1}(s)\leq V^{k+1}(s)\ \forall s\), if \(\mathbb{E}_{\pi_{\hat{Q}^k}(a\vert s)}\left[\frac{\pi_{\hat{Q}^k}(a\vert s)}{\hat{\pi}_{\beta}(a\vert s)}-1\right] \geq \max_{a:\ \hat{\pi}_{\beta}(a\vert s)>0} \left(\frac{\pi_{\hat{Q}^k}(a\vert s)}{\hat{\pi}_{\beta}(a\vert s)}\right)\cdot \epsilon\). The LHS of this inequality is the amount of conservatism induced in the value $\hat{V}^{k+1}$ in iteration $k+1$ of the CQL update if the learned policy were equal to the soft-optimal policy for $\hat{Q}^k$, i.e. \(\hat{\pi}^{k+1} = \pi_{\hat{Q}^k}\). However, since the actual policy $\hat{\pi}^{k+1}$ may differ, the RHS is the maximal amount of potential overestimation due to this difference. To obtain a lower bound, we require the amount of underestimation to be larger, which holds when $\epsilon$ is small, i.e. the policy changes slowly.
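As a sanity check, this condition can be evaluated directly for discrete actions; the sketch below uses made-up distributions purely for illustration.

```python
import numpy as np

def cql_lower_bound_condition(pi_q, pi_beta, eps):
    """Check the Theorem 2.1 condition at a single state with discrete actions.

    pi_q:    probabilities of the soft-optimal policy pi_{Q^k}(.|s)
    pi_beta: probabilities of the (estimated) behavior policy at s
    eps:     total-variation distance between the learned policy and pi_q
    """
    ratio = pi_q / pi_beta
    lhs = np.sum(pi_q * (ratio - 1.0))       # conservatism induced in the value
    rhs = np.max(ratio[pi_beta > 0]) * eps   # maximal potential overestimation
    return lhs >= rhs

# A peaked soft-optimal policy vs. a broad behavior policy: the condition holds
# as long as the policy changes slowly (small eps).
pi_q = np.array([0.7, 0.2, 0.1])
pi_beta = np.array([0.4, 0.3, 0.3])
print(cql_lower_bound_condition(pi_q, pi_beta, eps=0.05))  # True
```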

Our final result shows that the CQL Q-function update is “gap-expanding”: the difference in Q-values between in-distribution actions and over-optimistically erroneous out-of-distribution actions is larger than the corresponding difference under the actual Q-function. This implies that the policy \(\pi^k(a\vert s)\propto \exp(\hat{Q}^k(s,a))\) is constrained to stay closer to the dataset distribution $\pi_\beta(a\vert s)$, so the CQL update implicitly prevents the detrimental effects of OOD actions and distribution shift, which have been a major concern in offline RL.

Theorem 2.2

At any iteration $k$, CQL expands the difference in expected Q-values under the behavior policy $\pi_\beta(a\vert s)$ and under $\mu_k$, such that for large enough values of $\alpha_k$ we have \(\forall s,\ \mathbb{E}_{\pi_\beta(a\vert s)}[\hat{Q}^k(s,a)]-\mathbb{E}_{\mu_k(a\vert s)}[\hat{Q}^k(s,a)] > \mathbb{E}_{\pi_\beta(a\vert s)}[Q^k(s,a)]-\mathbb{E}_{\mu_k(a\vert s)}[Q^k(s,a)]\). When function approximation or sampling error causes OOD actions to have higher learned Q-values, CQL backups are expected to be more robust, in that the policy is updated using Q-values that prefer in-distribution actions. As shown empirically in the appendix of the paper, prior offline RL methods that do not explicitly constrain or regularize the Q-function may not enjoy such robustness properties.

Safe Policy Improvement Guarantees

In this section, we show that the CQL Q-function update procedure actually optimizes a well-defined objective, and we provide a safe policy improvement result. To begin, we define the empirical return of any policy $\pi$, $J(\pi, \hat{M})$, as the discounted return of $\pi$ in the empirical MDP $\hat{M}$ that is induced by the transitions observed in the dataset $D$, i.e. \(\hat{M} = \{(s, a, r, s^{'}) \in D\}\). $J(\pi, M)$ refers to the expected discounted return attained by a policy $\pi$ in the actual underlying MDP, $M$. In Theorem 3.1, we first show that CQL optimizes a well-defined penalized RL empirical objective. All proofs can be found in the appendix of the paper.

Theorem 3.1

Let $\hat{Q}^\pi$ be the fixed point of the CQL evaluation update above; then \(\pi^{*}(a\vert s):=\argmax_{\pi}\mathbb{E}_{s\sim \rho(s)}[\hat{V}^\pi(s)]\) is equivalently obtained by solving \(\begin{aligned} \pi^{*}(a\vert s)\leftarrow \argmax_{\pi}\ & J(\pi,\hat{M})-\alpha\frac{1}{1-\gamma}\mathbb{E}_{s\sim d^\pi_{\hat{M}}(s)}\left[D_{CQL}(\pi,\hat{\pi}_\beta)(s)\right],\\ \text{where}\quad D_{CQL}(\pi,\hat{\pi}_\beta)(s) &:=\sum_{a}\pi(a\vert s)\cdot \left(\frac{\pi(a\vert s)}{\hat{\pi}_\beta(a\vert s)}-1\right). \end{aligned}\)
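To unpack the penalty, here is a short sketch of $D_{CQL}$ for discrete action distributions (a toy illustration, not the paper's code). Note that it is the same divergence-like quantity that set the $\alpha$ threshold in Theorem 1.2: it is zero when $\pi=\hat{\pi}_\beta$ and grows as $\pi$ places mass on actions that are rare under $\hat{\pi}_\beta$.

```python
import numpy as np

def d_cql(pi, pi_beta):
    """D_CQL(pi, pi_beta)(s) = sum_a pi(a|s) * (pi(a|s) / pi_beta(a|s) - 1)."""
    pi = np.asarray(pi, dtype=float)
    pi_beta = np.asarray(pi_beta, dtype=float)
    return float(np.sum(pi * (pi / pi_beta - 1.0)))

print(d_cql([0.5, 0.5], [0.5, 0.5]))  # 0.0  -> no penalty when pi matches the data
print(d_cql([0.9, 0.1], [0.5, 0.5]))  # 0.64 -> penalty grows as pi drifts away
```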
Intuitively, Theorem 3.1 says that CQL optimizes the return of a policy in the empirical MDP $\hat{M}$, while also ensuring that the learned policy $\pi$ does not drift too far from the behavior policy \(\hat{\pi}_\beta\), via a penalty that depends on $D_{CQL}$. This follows from the CQL update procedure as well as its gap-expanding property. Building on this theorem and the analysis of Constrained Policy Optimization, we show that CQL provides a $\zeta$-safe policy improvement over \(\hat{\pi}_\beta\).

Theorem 3.2

Let $\pi^{*}(a\vert s)$ be the policy obtained by optimizing the objective in Theorem 3.1. Then $\pi^{*}(a\vert s)$ is a $\zeta$-safe policy improvement over \(\hat{\pi}_\beta\) in the actual MDP $M$, i.e. \(J(\pi^{*}, M) \ge J(\hat{\pi}_\beta, M) - \zeta\) with high probability $1 - \delta$, where $\zeta$ is given by \(\begin{aligned} \zeta = 2\left(\frac{C_{r,\delta}}{1-\gamma}+\frac{\gamma R_{max}C_{T,\delta}}{(1-\gamma)^2}\right)\mathbb{E}_{s\sim d^{\pi^{*}}_{\hat{M}}(s)}\left[\frac{\sqrt{\vert A\vert }}{\sqrt{\vert D(s)\vert }}\sqrt{D_{CQL}(\pi^{*},\hat{\pi}_\beta)(s)+1}\right]-\left(J(\pi^{*},\hat{M})-J(\hat{\pi}_\beta,\hat{M})\right), \end{aligned}\) where the improvement term satisfies \(J(\pi^{*},\hat{M})-J(\hat{\pi}_\beta,\hat{M})\geq \alpha\frac{1}{1-\gamma}\mathbb{E}_{s\sim d^{\pi^{*}}_{\hat{M}}(s)}[D_{CQL}(\pi^{*},\hat{\pi}_\beta)(s)]\).
The expression for $\zeta$ in Theorem 3.2 consists of two terms: the first term captures the decrease in policy performance in $M$ due to the mismatch between $\hat{M}$ and $M$, also referred to as sampling error. The second term captures the increase in policy performance due to CQL in the empirical MDP $\hat{M}$. The policy $\pi^{*}$ obtained by optimizing against the CQL Q-function improves upon the behavior policy \(\hat{\pi}_\beta\) for suitably chosen values of $\alpha$. When sampling error is small, i.e. $\vert D(s)\vert$ is large, smaller values of $\alpha$ are enough to provide an improvement over the behavior policy.