
How to Calculate KL Divergence in Python (Including Example)

by Erma Khan

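In statistics, the Kullback–Leibler (KL) divergence quantifies how much one probability distribution differs from another. It is often described as a "distance," but it is not a true distance metric: it is not symmetric, and it does not satisfy the triangle inequality.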

If we have two probability distributions, P and Q, we typically write the KL divergence using the notation KL(P || Q), which means “P’s divergence from Q.”

We calculate it using the following formula:

KL(P || Q) = Σ P(x) ln(P(x) / Q(x))

where the sum runs over every outcome x and ln is the natural logarithm.
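To see what the formula does term by term, here is a minimal pure-Python sketch. The function name kl_divergence is hypothetical and not part of any library. It assumes both inputs list probabilities for the same outcomes in the same order. Terms with P(x) = 0 are skipped, following the usual convention that 0 · ln(0) = 0.

```python
import math

def kl_divergence(p, q):
    """Hypothetical helper: KL(P || Q) in nats, computed directly from the formula."""
    if len(p) != len(q):
        raise ValueError("p and q must have the same length")
    total = 0.0
    for p_x, q_x in zip(p, q):
        if p_x == 0:
            continue  # convention: an outcome with P(x) = 0 contributes nothing
        if q_x == 0:
            return math.inf  # P puts mass where Q puts none, so the divergence is infinite
        total += p_x * math.log(p_x / q_x)  # math.log is the natural logarithm
    return total

P = [.05, .1, .2, .05, .15, .25, .08, .12]
Q = [.3, .1, .2, .1, .1, .02, .08, .1]
print(kl_divergence(P, Q))
```

Up to floating-point rounding, this should agree with the value SciPy's rel_entr gives for the same two lists.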

The KL divergence is never negative, and it equals zero only when the two distributions are identical. The larger the value, the more P diverges from Q.
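The zero and non-negativity properties can be checked numerically. The following is a vectorized NumPy sketch. It assumes every probability is strictly positive, so no special handling of zeros is needed, and the helper name kl is illustrative only.

```python
import numpy as np

P = np.array([.05, .1, .2, .05, .15, .25, .08, .12])
Q = np.array([.3, .1, .2, .1, .1, .02, .08, .1])

def kl(p, q):
    # Illustrative helper; assumes all entries of p and q are strictly positive
    return float(np.sum(p * np.log(p / q)))

print(kl(P, P))  # 0.0: a distribution has zero divergence from itself
print(kl(P, Q))  # positive, because P and Q differ
```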

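We can use the scipy.special.rel_entr() function to calculate the KL divergence between two probability distributions in Python. This function returns the individual terms P(x) ln(P(x) / Q(x)), so we sum its output to get the divergence.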
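To see what rel_entr returns on its own, here is a small sketch using two toy two-outcome distributions. The values are illustrative only. Individual terms can be negative even though their sum cannot be. The last two lines show how rel_entr handles zero probabilities.

```python
import numpy as np
from scipy.special import rel_entr

# Two toy distributions over two outcomes (illustrative values only)
p = np.array([0.5, 0.5])
q = np.array([0.25, 0.75])

terms = rel_entr(p, q)  # one term p(x) * ln(p(x) / q(x)) per outcome
print(terms)            # the second term is negative because p(x) < q(x) there
print(terms.sum())      # the sum is KL(p || q), which is never negative

# Edge cases as defined by rel_entr
print(rel_entr(0.0, 0.5))  # 0.0: an outcome with p(x) = 0 contributes nothing
print(rel_entr(0.5, 0.0))  # inf: q(x) = 0 where p(x) > 0
```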

The following example shows how to use this function in practice.

Example: Calculating KL Divergence in Python

Suppose we have the following two probability distributions in Python:

Note: The probabilities in each distribution must sum to one. In addition, Q(x) must be greater than zero wherever P(x) is greater than zero; otherwise, KL(P || Q) is infinite.

#define two probability distributions
P = [.05, .1, .2, .05, .15, .25, .08, .12]
Q = [.3, .1, .2, .1, .1, .02, .08, .1]
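Before computing the divergence, it can be worth confirming that both lists are valid inputs. Here is a minimal sketch. The helper name check_inputs is hypothetical, and the tolerance of 1e-9 is an arbitrary choice to absorb floating-point rounding in the sums.

```python
import math

def check_inputs(p, q):
    """Hypothetical helper: raise ValueError if p and q are unsuitable for KL(p || q)."""
    if len(p) != len(q):
        raise ValueError("p and q must list the same outcomes")
    for name, dist in (("p", p), ("q", q)):
        if any(x < 0 for x in dist):
            raise ValueError(f"{name} contains a negative probability")
        total = math.fsum(dist)  # accurate floating-point sum
        if not math.isclose(total, 1.0, rel_tol=1e-9):
            raise ValueError(f"{name} sums to {total}, not 1")
    # KL(p || q) is infinite if q is zero where p is positive
    if any(p_x > 0 and q_x == 0 for p_x, q_x in zip(p, q)):
        raise ValueError("q is zero where p is positive")

P = [.05, .1, .2, .05, .15, .25, .08, .12]
Q = [.3, .1, .2, .1, .1, .02, .08, .1]
check_inputs(P, Q)  # passes silently for these two distributions
```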

We can use the following code to calculate the KL divergence between the two distributions:

from scipy.special import rel_entr

#calculate (P || Q)
sum(rel_entr(P, Q))

0.589885181619163

The KL divergence of distribution P from distribution Q is about 0.590.

Because the formula uses the natural logarithm, the result is measured in nats, which is short for "natural units of information."

Thus, we would say that the KL divergence is about 0.590 nats.
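SciPy also offers scipy.stats.entropy(). When given a second distribution, it computes the same quantity, sum(pk * log(pk / qk)), using the natural logarithm by default. It differs in one way: entropy() normalizes its inputs to sum to one, whereas summing the output of rel_entr() uses the probabilities exactly as given. The sketch below assumes the same P and Q as above.

```python
from scipy.stats import entropy

P = [.05, .1, .2, .05, .15, .25, .08, .12]
Q = [.3, .1, .2, .1, .1, .02, .08, .1]

# With two arguments, entropy(pk, qk) returns KL(pk || qk) in nats
print(entropy(P, Q))
```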

Also note that the KL divergence is not symmetric: in general, KL(Q || P) is not equal to KL(P || Q). If we calculate the KL divergence of distribution Q from distribution P, we get a different value:

from scipy.special import rel_entr

#calculate (Q || P)
sum(rel_entr(Q, P))

0.497549319448034

The KL divergence of distribution Q from distribution P is about 0.498 nats.
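To compare the two directions side by side, we can swap the arguments to rel_entr. In this example, much of the gap comes from the outcome where P assigns 0.25 but Q assigns only 0.02. That single term contributes roughly 0.63 nats to KL(P || Q).

```python
from scipy.special import rel_entr

P = [.05, .1, .2, .05, .15, .25, .08, .12]
Q = [.3, .1, .2, .1, .1, .02, .08, .1]

kl_pq = rel_entr(P, Q).sum()  # KL(P || Q)
kl_qp = rel_entr(Q, P).sum()  # KL(Q || P): same lists, arguments swapped

print(f"KL(P || Q) = {kl_pq:.4f} nats")  # KL(P || Q) = 0.5899 nats
print(f"KL(Q || P) = {kl_qp:.4f} nats")  # KL(Q || P) = 0.4975 nats
```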

Note: Some definitions of the KL divergence use the base-2 logarithm instead of the natural logarithm. In that case, the divergence is measured in bits rather than nats (1 nat = 1/ln(2) ≈ 1.443 bits).
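There are two ways to get the result in bits. You can divide the value in nats by ln(2), or you can pass base=2 to scipy.stats.entropy(). Here is a short sketch using the same P and Q as in the example; both approaches should agree up to rounding.

```python
import math
from scipy.special import rel_entr
from scipy.stats import entropy

P = [.05, .1, .2, .05, .15, .25, .08, .12]
Q = [.3, .1, .2, .1, .1, .02, .08, .1]

kl_nats = rel_entr(P, Q).sum()

# Option 1: convert nats to bits by dividing by ln(2)
kl_bits = kl_nats / math.log(2)

# Option 2: let scipy.stats.entropy use base-2 logarithms directly
kl_bits_direct = entropy(P, Q, base=2)

print(kl_bits, kl_bits_direct)
```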

Additional Resources

The following tutorials explain how to perform other common operations in Python:

How to Create a Correlation Matrix in Python
How to Create a Covariance Matrix in Python
