Efficient Kullback-Leibler divergence of two Gaussian distributions in R

Author

Erik-Jan van Kesteren

Introduction

In this document, I create a small R function for computing the Kullback-Leibler divergence between two (multivariate) normal (Gaussian) distributions. The general formula for the divergence is as follows:

D_{KL}(P \parallel Q) = \int_{-\infty}^{\infty} p(x) \log \left(\frac{p(x)}{q(x)}\right)\, dx

If we assume that both p(x) and q(x) are multivariate normal, then all the difficult integration drops out and we are left with the following formula, adapted from this excellent answer on StackOverflow:

D_{KL}(P \parallel Q) = \frac{1}{2}\left[\log|\Sigma_q|-\log|\Sigma_p| - d + \text{tr} \{ \Sigma_q^{-1}\Sigma_p \} + (\mu_q - \mu_p)^T \Sigma_q^{-1}(\mu_q - \mu_p)\right]

where d indicates the dimensionality of the distributions, and |\cdot| indicates the determinant.
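As a sanity check on the closed form, the following minimal sketch transcribes the formula literally into R and compares it, in the univariate case, with numerical integration of the general definition. The helper name kl_gauss_naive and the example parameters are my own illustration; they are assumptions and not part of any package.

```r
# Literal, unoptimised transcription of the closed-form expression
# (kl_gauss_naive is a hypothetical helper used only for this check)
kl_gauss_naive <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  Sigma_p <- as.matrix(Sigma_p)
  Sigma_q <- as.matrix(Sigma_q)
  d       <- length(mu_p)
  delta   <- mu_q - mu_p
  Omega_q <- solve(Sigma_q)
  out <- log(det(Sigma_q)) - log(det(Sigma_p)) - d +
    sum(diag(Omega_q %*% Sigma_p)) +
    t(delta) %*% Omega_q %*% delta
  as.numeric(out) / 2
}

# Univariate example: P = N(0, 1) and Q = N(1, 2^2)
closed_form <- kl_gauss_naive(0, 1, 1, 4)

# The same quantity by numerically integrating p(x) * log(p(x) / q(x))
integrand <- function(x) {
  dnorm(x, 0, 1) * (dnorm(x, 0, 1, log = TRUE) - dnorm(x, 1, 2, log = TRUE))
}
numeric_kl <- integrate(integrand, -Inf, Inf)$value

# Both routes should agree up to the integration error
all.equal(closed_form, numeric_kl, tolerance = 1e-4)
```

The integration route only works comfortably in one dimension, which is exactly why the closed form is so convenient for multivariate distributions.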

Of course, a function already exists for computing this divergence in R. Specifically, it exists as rags2ridges::KLdiv(). However, when installing the {rags2ridges} package, several dependencies need to be pulled from Bioconductor as they are not on CRAN, and the package in general has a lot of functionality we might not need if we just want to compute the KL-divergence.

The goal of this post is thus to create a lightweight version of the KL-divergence function that is fast, has low memory requirements, and has no dependencies beyond what’s included in base R. Below, I show how I achieved a speedup of more than an order of magnitude compared to the existing KLdiv function.

Example inputs

We will be using the following example inputs:

set.seed(45)

# we use 25-dimensional normal distributions
P <- 25

# mean vectors
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)

# covariance matrices
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]

The existing function returns the following value for D_{KL}(P \parallel Q) with these inputs (note that the ordering is a bit weird; KL-divergence is not symmetric so order is important):

rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
[1] 15.40342

Computing log-determinants

The first two terms in our equation are the log-determinants of the covariance matrices:

\log|\Sigma_q|-\log|\Sigma_p|

There are several ways to compute a log-determinant in R; here, we benchmark three methods, once on the 25-dimensional S_1 we created earlier, and another time on a random 1000-dimensional covariance matrix:

bench::mark(
  logdet = log(det(S_1)),
  determ = determinant(S_1)$modulus,
  cholesky = 2*sum(log(diag(chol(S_1)))),
  check = FALSE
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 logdet       13.1µs   14.2µs    61407.    4.93KB    12.3 
2 determ       10.9µs   11.7µs    76030.    4.93KB     7.60
3 cholesky     14.6µs   15.6µs    57334.   11.98KB    11.5 

set.seed(45)
S1000 <- rWishart(1, 1000, diag(1000))[,,1]
bench::mark(
  logdet = log(det(S1000)),
  determ = determinant(S1000)$modulus,
  cholesky = 2*sum(log(diag(chol(S1000)))),
  check = FALSE
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 logdet        178ms    184ms      5.43    7.63MB     2.72
2 determ        171ms    186ms      5.37    7.63MB     2.68
3 cholesky      138ms    144ms      6.92    7.65MB     2.31

For the 25-dimensional matrix, determinant() performs slightly better than log(det()), while for the 1000-dimensional matrix the two are practically tied. The Cholesky version is slowest for the low-dimensional matrix, but it is actually fastest for the 1000-dimensional one. This is especially interesting as we may be able to reuse the Cholesky decomposition to speed up later steps.
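The Cholesky route rests on a simple identity: if \Sigma = R^T R with R upper triangular, then |\Sigma| is the squared product of the diagonal of R, so \log|\Sigma| = 2\sum_i \log r_{ii}. Here is a small sketch confirming that the three methods agree; the 5-dimensional example matrix is an arbitrary choice of mine.

```r
set.seed(1)
S <- rWishart(1, 10, diag(5))[,,1]  # small example covariance matrix

R <- chol(S)                        # upper triangular factor, S = t(R) %*% R

logdet_direct  <- log(det(S))
logdet_modulus <- as.numeric(determinant(S, logarithm = TRUE)$modulus)
logdet_chol    <- 2 * sum(log(diag(R)))

# All three should be equal up to floating point error
all.equal(logdet_direct, logdet_chol)
all.equal(logdet_modulus, logdet_chol)
```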

Computing the trace

The second difficult term to compute is the trace part:

\text{tr} \{ \Sigma_q^{-1}\Sigma_p \}

Note that here we need the inverse of \Sigma_q, which we will need at a later stage as well. Here, we again have several options for computing the trace. One trick is adapted from the lavaan package, from the code here, with the additional knowledge that covariance matrices are symmetric.

bench::mark(
  naive = sum(diag(solve(S_1) %*% S_0)),
  solve = sum(diag(solve(S_1, S_0))),
  lavaan = sum(solve(S_1) * S_0)
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 naive        44.5µs   47.2µs    19589.      21KB     17.1
2 solve        32.4µs     35µs    26598.    11.2KB     13.3
3 lavaan       33.9µs   36.2µs    25071.    20.6KB     22.6

Here, the naive method is definitely the worst, and the other two methods are almost at the same level. Again, the last identity is especially interesting, because we want to precompute \Sigma_q^{-1} anyway for the next part.
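The reason the elementwise version works is that \text{tr}\{AB\} = \sum_i \sum_j a_{ij} b_{ji}, and for a symmetric B we have b_{ji} = b_{ij}, so the trace is just the sum of the elementwise product. A minimal sketch with arbitrary example matrices (my own, not the ones from the benchmark):

```r
set.seed(2)
A <- solve(rWishart(1, 10, diag(4))[,,1])  # plays the role of Sigma_q^{-1}
B <- rWishart(1, 10, diag(4))[,,1]         # plays the role of Sigma_p (symmetric)

trace_full <- sum(diag(A %*% B))  # forms the full product, uses only its diagonal
trace_elem <- sum(A * B)          # elementwise product, valid because B is symmetric

all.equal(trace_full, trace_elem)

# For a non-symmetric second matrix, the general identity needs a transpose
C <- matrix(rnorm(16), 4, 4)
all.equal(sum(diag(A %*% C)), sum(A * t(C)))
```

The elementwise version skips the matrix product altogether and only needs one pass over the two matrices.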

Note that if we assume that we have the Cholesky decomposition of \Sigma_q precomputed from the previous part, these computations, especially the last one, become even faster:

S_1_c <- chol(S_1)

bench::mark(
  naive = sum(diag(chol2inv(S_1_c) %*% S_0)),
  solve = sum(diag(forwardsolve(S_1_c, backsolve(S_1_c, S_0, transpose = TRUE), upper.tri = TRUE))),
  lavaan = sum(chol2inv(S_1_c) * S_0)
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 naive        16.8µs     18µs    50526.   10.34KB     20.2
2 solve        24.1µs   25.8µs    34050.      26KB     13.6
3 lavaan        7.1µs    7.4µs   116485.    9.86KB     35.0

Quadratic form

The last computationally intensive part of the equation is the following quadratic form:

(\mu_q - \mu_p)^T \Sigma_q^{-1}(\mu_q - \mu_p)

Here, we first compute the mean differences, which we call delta. Then, we again need the inverse of the covariance matrix of q(x), so we will assume that it is pre-computed from the previous step.

delta <- mean_1 - mean_0
Omega_1 <- chol2inv(S_1_c)

bench::mark(
  naive = t(delta) %*% Omega_1 %*% delta,
  cross = crossprod(delta, Omega_1 %*% delta),
  cross2 = crossprod(delta, crossprod(Omega_1, delta))
)
# A tibble: 3 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 naive         4.6µs    4.9µs   173750.      496B        0
2 cross         1.5µs    1.7µs   500959.      248B        0
3 cross2        1.9µs    2.1µs   441552.      248B        0

The second version, using only one crossprod(), is the fastest here. It is roughly three times as fast as the naïve implementation with standard matrix products.
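As an aside, the quadratic form can also be computed without forming the inverse at all, by solving one triangular system with the Cholesky factor. This is an alternative I did not benchmark above, sketched here with arbitrary example inputs; since we need the explicit inverse for the trace anyway, it is not necessarily a win in this setting.

```r
set.seed(3)
S     <- rWishart(1, 10, diag(5))[,,1]
delta <- rnorm(5)

R <- chol(S)   # S = t(R) %*% R, so solve(S) = solve(R) %*% t(solve(R))

# Reference: explicit inverse, as in the benchmark above
quad_inv <- c(crossprod(delta, chol2inv(R) %*% delta))

# Alternative: solve t(R) %*% z = delta, then the quadratic form is sum(z^2)
z        <- backsolve(R, delta, transpose = TRUE)
quad_tri <- sum(z^2)

all.equal(quad_inv, quad_tri)
```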

Putting it all together

Combining all the above sections, we can make two versions of the KL-divergence function: one which precomputes and uses the Cholesky decomposition of \Sigma_q, and one which does not. We can compare them for our example distributions:

kldiv_base <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  logdet_p <- determinant(Sigma_p)$modulus
  logdet_q <- determinant(Sigma_q)$modulus
  d <- length(mu_p)
  Omega_q <- solve(Sigma_q)
  trace <- sum(Omega_q * Sigma_p)
  delta <- mu_q - mu_p
  quad <- crossprod(delta, Omega_q %*% delta)
  return(c(logdet_q - logdet_p - d + trace + quad) / 2)
}

kldiv_chol <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  chol_q <- chol(Sigma_q)
  logdet_p <- determinant(Sigma_p)$modulus
  logdet_q <- 2*sum(log(diag(chol_q)))
  d <- length(mu_p)
  Omega_q <- chol2inv(chol_q)
  trace <- sum(Omega_q * Sigma_p)
  delta <- mu_q - mu_p
  quad <- crossprod(delta, Omega_q %*% delta)
  return(c(logdet_q - logdet_p - d + trace + quad) / 2)
}

bench::mark(
  base = kldiv_base(mean_0, mean_1, S_0, S_1),
  chol = kldiv_chol(mean_0, mean_1, S_0, S_1)
)
# A tibble: 2 × 6
  expression      min   median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl>
1 base         60.7µs     66µs    13807.     100KB     10.4
2 chol         38.8µs   42.3µs    20974.     103KB     12.6

As you can see, the Cholesky version is quite a bit faster: its median run time is more than 1/3 lower. But does it also give the correct answer?

kldiv_chol(mean_0, mean_1, S_0, S_1)
[1] 15.40342

It does! So let’s use that version to compare to the original, existing rags2ridges::KLdiv() function.

bench::mark(
  ours = kldiv_chol(mean_0, mean_1, S_0, S_1),
  theirs = rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0),
  relative = TRUE
)
# A tibble: 2 × 6
  expression   min median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <dbl>  <dbl>     <dbl>     <dbl>    <dbl>
1 ours         1      1        13.0      1        2.22
2 theirs      13.8   13.7       1        5.12     1   

We’ve achieved quite a dramatic speedup, roughly 13 to 14x. Additionally, we have allocated 5x less memory. This speedup is maintained for 2-dimensional distributions:

# generate parameters
set.seed(45)
P <- 2
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]

bench::mark(
  ours = kldiv_chol(mean_0, mean_1, S_0, S_1),
  theirs = rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0),
  relative = TRUE
)
# A tibble: 2 × 6
  expression   min median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <dbl>  <dbl>     <dbl>     <dbl>    <dbl>
1 ours         1      1        15.6       NaN     1.28
2 theirs      16.0   15.8       1         NaN     1   

A positive side-effect is that, because the new function works with log-determinants throughout, it also handles larger, 250-dimensional distributions, where the old function returns NaN:

# generate parameters
set.seed(45)
P <- 250
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]

kldiv_chol(mean_0, mean_1, S_0, S_1)
[1] 128.2745
rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0)
[1] NaN
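One plausible explanation for the NaN is numerical overflow: the raw determinant of a covariance matrix of this size exceeds the largest representable double, so any implementation that computes log(det(.)) for both matrices ends up with Inf - Inf. I have not inspected how rags2ridges::KLdiv() computes its log-determinants, so the link to its result is an assumption; the sketch below only demonstrates the numerical effect on matrices generated the same way as above.

```r
set.seed(45)
P   <- 250
S_a <- rWishart(1, P * 2, diag(P))[,,1]
S_b <- rWishart(1, P * 2, diag(P))[,,1]

det(S_a)                        # overflows to Inf
log(det(S_a)) - log(det(S_b))   # Inf - Inf gives NaN

# Working on the log scale keeps everything finite
logdet_a <- as.numeric(determinant(S_a)$modulus)
logdet_b <- 2 * sum(log(diag(chol(S_b))))
logdet_a - logdet_b
```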

Clean-up and document

What remains is to tidy up the function, make it more robust, and to document it properly. Allow me to present the fast and furious Gaussian Kullback-Leibler divergence:

#' Kullback-Leibler divergence between two Gaussians
#' 
#' This function computes $D_{KL}(p \parallel q)$, where $p(x)$ 
#' and $q(x)$ are two multivariate normal distributions.
#' 
#' @param mu_p the mean vector of p
#' @param mu_q the mean vector of q
#' @param Sigma_p the covariance matrix of p
#' @param Sigma_q the covariance matrix of q
#' 
#' @return Kullback-Leibler divergence from p to q 
#' (numeric scalar)
kldiv <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  chol_q   <- chol(Sigma_q)
  # as.matrix() lets a scalar variance through as a 1 x 1 matrix
  logdet_p <- determinant(as.matrix(Sigma_p))$modulus
  logdet_q <- 2*sum(log(diag(chol_q)))
  d        <- length(mu_p)
  Omega_q  <- chol2inv(chol_q)
  trace    <- sum(Omega_q * Sigma_p)
  delta    <- mu_q - mu_p
  quad     <- crossprod(delta, Omega_q %*% delta)
  return(c(logdet_q - logdet_p - d + trace + quad) / 2)
}
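If you want the function to fail early with a readable message, a few input checks can be run before calling it. The helper below is a hypothetical sketch of my own (the name check_gaussian_inputs and the choice of checks are assumptions), and the named conditions in stopifnot() require R 4.0 or newer. Positive-definiteness is not checked here because chol() itself raises an error when it fails.

```r
# Hypothetical helper: basic sanity checks before calling kldiv()
check_gaussian_inputs <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  Sigma_p <- as.matrix(Sigma_p)
  Sigma_q <- as.matrix(Sigma_q)
  d <- length(mu_p)
  stopifnot(
    "mean vectors must have equal length" = length(mu_q) == d,
    "Sigma_p must be d x d"               = all(dim(Sigma_p) == d),
    "Sigma_q must be d x d"               = all(dim(Sigma_q) == d),
    "Sigma_p must be symmetric"           = isSymmetric(Sigma_p),
    "Sigma_q must be symmetric"           = isSymmetric(Sigma_q)
  )
  invisible(TRUE)
}

# Passes silently for well-formed inputs
ok <- check_gaussian_inputs(c(0, 0), c(1, 1), diag(2), 2 * diag(2))

# Reports the length mismatch instead of failing somewhere inside the algebra
bad <- try(check_gaussian_inputs(c(0, 0), c(1, 1, 1), diag(2), diag(2)),
           silent = TRUE)
```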

Symmetric version

There is also a symmetric version of the Kullback-Leibler divergence, also known as Jeffreys divergence:

D_{J}(P, Q) = D_{KL}(P \parallel Q) + D_{KL}(Q \parallel P)

We can build on the speedups from before by also precomputing the Cholesky decomposition for \Sigma_p. There are some additional properties of the symmetric divergence which are nice for computation:

  • the log-determinant terms cancel out, so they do not need to be computed at all
  • we can take d out of the inner sum because it enters the equation twice (then we don’t need to multiply and divide by 2)
  • Because \text{Tr}\{A\}+\text{Tr}\{B\} = \text{Tr}\{A+B\} we can compute the trace at once
  • Because delta is used only in quadratic forms, its sign does not matter and we don’t need to compute it twice (\delta^TA\delta = (-\delta)^TA(-\delta))
  • Then, we can combine the quadratic forms as follows: x^TAx + x^TBx = x^T(A+B)x

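Written out, the properties above combine as follows when the two directed divergences are added. This is my own expansion of the earlier closed form, using \delta = \mu_q - \mu_p as shorthand:

```latex
D_{J}(P, Q)
  = \frac{1}{2}\left[\text{tr}\{\Sigma_q^{-1}\Sigma_p\} + \text{tr}\{\Sigma_p^{-1}\Sigma_q\}
      - 2d + \delta^T\Sigma_q^{-1}\delta + \delta^T\Sigma_p^{-1}\delta\right]
  = \frac{1}{2}\left[\text{tr}\{\Sigma_q^{-1}\Sigma_p + \Sigma_p^{-1}\Sigma_q\}
      + \delta^T\left(\Sigma_p^{-1} + \Sigma_q^{-1}\right)\delta\right] - d
```

The log-determinant terms appear once with each sign and therefore vanish, which leaves only the traces, the quadratic forms, and the constant d.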
#' Jeffreys divergence between two Gaussians
#' 
#' This function computes Jeffreys divergence, the symmetric 
#' version of the Kullback-Leibler divergence: 
#' $D_{J}(p, q) = D_{KL}(p \parallel q) + D_{KL}(q \parallel p)$, 
#' where $p(x)$ and $q(x)$ are two multivariate normal distributions.
#' 
#' @param mu_p the mean vector of p
#' @param mu_q the mean vector of q
#' @param Sigma_p the covariance matrix of p
#' @param Sigma_q the covariance matrix of q
#' 
#' @return Jeffreys divergence between p and q 
#' (numeric scalar)
jefdiv <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  chol_p   <- chol(Sigma_p)
  chol_q   <- chol(Sigma_q)
  d        <- length(mu_p)
  Omega_p  <- chol2inv(chol_p)
  Omega_q  <- chol2inv(chol_q)
  trace    <- sum(Omega_p * Sigma_q + Omega_q * Sigma_p)
  delta    <- mu_q - mu_p
  quad     <- delta %*% (Omega_p + Omega_q) %*% delta
  return(c(trace + quad) / 2 - d)
}
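A quick way to convince ourselves that the shortcuts are valid is to check that the result equals the sum of the two directed KL-divergences and does not depend on the argument order. The sketch below uses compact restatements of the two functions (with as.matrix() for Sigma_p) so that it runs on its own; the example inputs are arbitrary.

```r
# Compact restatements of the two functions so that this check is self-contained
kldiv <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  chol_q   <- chol(Sigma_q)
  logdet_p <- determinant(as.matrix(Sigma_p))$modulus
  logdet_q <- 2 * sum(log(diag(chol_q)))
  Omega_q  <- chol2inv(chol_q)
  delta    <- mu_q - mu_p
  quad     <- crossprod(delta, Omega_q %*% delta)
  c(logdet_q - logdet_p - length(mu_p) + sum(Omega_q * Sigma_p) + quad) / 2
}
jefdiv <- function(mu_p, mu_q, Sigma_p, Sigma_q) {
  Omega_p <- chol2inv(chol(Sigma_p))
  Omega_q <- chol2inv(chol(Sigma_q))
  delta   <- mu_q - mu_p
  trace   <- sum(Omega_p * Sigma_q + Omega_q * Sigma_p)
  quad    <- delta %*% (Omega_p + Omega_q) %*% delta
  c(trace + quad) / 2 - length(mu_p)
}

set.seed(4)  # arbitrary example inputs
d    <- 6
mu_a <- rnorm(d)
mu_b <- rnorm(d)
S_a  <- rWishart(1, 2 * d, diag(d))[,,1]
S_b  <- rWishart(1, 2 * d, diag(d))[,,1]

jef   <- jefdiv(mu_a, mu_b, S_a, S_b)
kl_ab <- kldiv(mu_a, mu_b, S_a, S_b)
kl_ba <- kldiv(mu_b, mu_a, S_b, S_a)

all.equal(jef, kl_ab + kl_ba)                 # sum of the two directed divergences
all.equal(jef, jefdiv(mu_b, mu_a, S_b, S_a))  # symmetric in its arguments
```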

When comparing the performance to that of the {rags2ridges} package, we see that the speed-up is slightly bigger than with the asymmetric KL-divergence, around 17x.

set.seed(45)
P <- 25
mean_0 <- rnorm(P)
mean_1 <- rnorm(P)
S_0 <- rWishart(1, P*2, diag(P))[,,1]
S_1 <- rWishart(1, P*2, diag(P))[,,1]

bench::mark(
  ours = jefdiv(mean_0, mean_1, S_0, S_1),
  theirs = rags2ridges::KLdiv(mean_1, mean_0, S_1, S_0, TRUE), 
  relative = TRUE
)
# A tibble: 2 × 6
  expression   min median `itr/sec` mem_alloc `gc/sec`
  <bch:expr> <dbl>  <dbl>     <dbl>     <dbl>    <dbl>
1 ours         1      1        16.4      1        2.85
2 theirs      17.2   16.6       1        1.42     1   

Julia version

We can compare this to the notoriously fast Julia programming language. The implementation there needs less customization and can follow the mathematical formula almost literally. Thus, we “simply” write out the function as follows:

using MKL
using LinearAlgebra
using BenchmarkTools

# kldiv function
function kl_div(μ_p::Vector, μ_q::Vector, Σ_p::AbstractMatrix, Σ_q::AbstractMatrix)
    d   = length(μ_p)
    Ω_q = inv(Σ_q)
    tr  = sum(Ω_q .* Σ_p)
    δ   = μ_q .- μ_p
    qfm = dot(δ, Ω_q, δ)

    return (logdet(Σ_q) - logdet(Σ_p) - d + tr + qfm) / 2
end
kl_div (generic function with 1 method)

Let’s check that it has the same output:

# get example inputs from R (the @rget macro is provided by RCall.jl)
using RCall
mean_0 = @rget mean_0;
mean_1 = @rget mean_1;
S_0 = @rget S_0;
S_1 = @rget S_1;

kl_div(mean_0, mean_1, S_0, S_1)
15.403424894704695

Now, let’s benchmark this function:

@benchmark kl_div($mean_0, $mean_1, $S_0, $S_1)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  17.400 μs …  3.772 ms  ┊ GC (min … max): 0.00% … 95.55%
 Time  (median):     19.600 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   26.648 μs ± 89.504 μs  ┊ GC (mean ± σ):  8.68% ±  2.60%

  ▆██▆▄▂▁         ▁▂▂▂▁                ▁▂▃▂▂                  ▂
  ████████▅▅▃▁▁▁▄▇██████▇▆▆▅▅▅▅▆▄▅▅▁▃▆▄███████▇▆▅▅▄▄▃▅▅▆▇▅▆▆▆ █
  17.4 μs      Histogram: log(frequency) by time      77.2 μs <

 Memory estimate: 33.88 KiB, allocs estimate: 10.

The Julia function is even faster than the optimized R function, completing in less than 20 microseconds (median), whereas the median of the R Cholesky version (Section 6) is around 40 microseconds.