Bayesian Inference and Graphical Models
Expectation-Maximization
EM for Gaussian mixture models
In this section, we'll develop an approach to estimating model parameters when some of the random variables involved in the model's Bayes net are not observed. We'll begin with the Gaussian mixture model and develop an intuitive version of the method, and then we'll introduce the general version and apply it to a hidden Markov model.
Recall that the Gaussian mixture model (GMM) is a Bayes net with just two random variables: a discrete random variable $Z$ and a random vector $X$. To draw an observation from this model, we draw $Z$ from a given discrete distribution on the set of component labels (here $\{0, 1\}$), and then we draw $X$ from a multivariate normal distribution with mean $\mu_Z$ and covariance $\Sigma_Z$.
Let's generate observations from a made-up GMM.
using Distributions, Plots, Random, LinearAlgebra, Statistics
include("data-gymnasia/ellipse.jl")
Random.seed!(123)
n = 100
α = 0.4                                          # probability that Z = 1
𝒩₀ = MvNormal([1.0,1.0],[2.0 1.0; 1.0 2.0])      # component for Z = 0
𝒩₁ = MvNormal([3.0,7.0],[1.5 0; 0 0.5])          # component for Z = 1
X₁ = zeros(n)
X₂ = zeros(n)
Z = zeros(Bool,n)
for i=1:n
    Z[i] = rand(Bernoulli(α))
    X₁[i], X₂[i] = Z[i] ? rand(𝒩₁) : rand(𝒩₀)
end
scatter(X₁, X₂, color = :gray, legend = false)
Let's think about how we could set about recovering the parameters of this model if all we had were the observations shown in the scatter plot.
One simple idea would be to write down the log likelihood of the data and hand it to an optimization algorithm to find parameters which maximize it. The problem is the lack of a reasonable way to write down the log likelihood when we have missing values in the Bayes net. Instead, let's develop an iterative approach that starts with a bad guess and works to improve it. We begin with arbitrary values for the parameters:
α = 0.6
μ₀ = [3.0,3.0]
μ₁ = [1.0,6.0]
Σ₀ = 1.0*Matrix(I, 2, 2)
Σ₁ = 1.0*Matrix(I, 2, 2)
mixtureplot(X₁,X₂,μ₀,Σ₀,μ₁,Σ₁)
Conceptually, we'd like to fit the blue distribution to the points in roughly the lower half of the figure, leaving the remaining upper points to be fit by the orange distribution. To this end, we come up with a score for each point indicating how much it seems to belong to the blue distribution or the orange distribution, based on our current parameter estimates. More precisely, let's compute for each point its conditional probability of having $Z = 1$ (orange) given the value of the point.
We'll compute these values (the conditional probability of being orange) for each point, and store the result in a vector called Π:
Π = [α*pdf(MvNormal(μ₁,Σ₁),[x₁,x₂]) /
      ((1-α)*pdf(MvNormal(μ₀,Σ₀),[x₁,x₂]) + α*pdf(MvNormal(μ₁,Σ₁),[x₁,x₂]))
     for (x₁,x₂) in zip(X₁,X₂)]
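The comprehension above is Bayes' rule applied to each point. As a sketch in symbols, writing $f_0$ and $f_1$ for the two Gaussian densities under the current parameter estimates and $\pi_i$ for the $i$th entry of Π (this notation is an assumption made here for illustration):

```latex
\pi_i
  = \mathbb{P}(Z_i = 1 \mid X_i = x_i)
  = \frac{\alpha\, f_1(x_i)}{(1-\alpha)\, f_0(x_i) + \alpha\, f_1(x_i)},
\qquad i = 1, \ldots, n.
```

The numerator is the joint density of seeing component 1 together with the point $x_i$, and the denominator is the marginal density of $x_i$ under the mixture.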
We can visualize the result of this computation by coloring each point according to its blueness/orangeness value in Π.
Next, we can fit a multivariate Gaussian to the points we colored blue, and another to the points we colored orange. However, rather than performing the discontinuous operation of snapping each point to "orange" or "blue", we maintain the real-valued nature of the blueness/orangeness of each point, and instead compute a weighted mean and a weighted covariance for each color, using the entries of Π as weights:
# re-estimate the parameters using the current weights Π
α = sum(Π)/n
μ₀ = [(1 .- Π) ⋅ X₁, (1 .- Π) ⋅ X₂] / sum(1 .- Π)
μ₁ = [Π ⋅ X₁, Π ⋅ X₂] / sum(Π)
Σ₀ = Matrix(Hermitian(sum((1-π)*([x₁,x₂] - μ₀) * ([x₁,x₂] - μ₀)' for (x₁,x₂,π) in zip(X₁,X₂,Π))/sum(1 .- Π)))
Σ₁ = Matrix(Hermitian(sum(π*([x₁,x₂] - μ₁) * ([x₁,x₂] - μ₁)' for (x₁,x₂,π) in zip(X₁,X₂,Π))/sum(Π)))
# recompute the weights under the new parameters
Π = [α*pdf(MvNormal(μ₁,Σ₁),[x₁,x₂]) /
      ((1-α)*pdf(MvNormal(μ₀,Σ₀),[x₁,x₂]) + α*pdf(MvNormal(μ₁,Σ₁),[x₁,x₂]))
     for (x₁,x₂) in zip(X₁,X₂)];
mixtureplot(X₁,X₂,μ₀,Σ₀,μ₁,Σ₁,Π)
If you run the cell above a few times, you'll see that it pretty quickly settles on a particular choice for the model parameters.
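Instead of re-running the cell by hand, the two steps can be wrapped in a loop. The following is a minimal sketch, not the course's own code: the function names `responsibilities`, `weighted_gaussian`, and `em_gmm` are hypothetical, the data are stored as an n×2 matrix rather than as two vectors, and the fixed iteration count stands in for a proper convergence check.

```julia
using Distributions, LinearAlgebra, Random, Statistics

# Weights: conditional probability that each row of X came from component 1.
function responsibilities(X, α, μ₀, Σ₀, μ₁, Σ₁)
    f₀, f₁ = MvNormal(μ₀, Σ₀), MvNormal(μ₁, Σ₁)
    [α * pdf(f₁, x) / ((1 - α) * pdf(f₀, x) + α * pdf(f₁, x)) for x in eachrow(X)]
end

# Weighted sample mean and weighted sample covariance of the rows of X.
function weighted_gaussian(X, w)
    μ = vec(sum(w .* X, dims = 1)) / sum(w)
    D = X .- μ'                            # centered rows
    Σ = (D' * (w .* D)) / sum(w)
    μ, Matrix(Hermitian(Σ))                # force exact symmetry for MvNormal
end

# Alternate the two steps a fixed number of times (assumed stopping rule).
function em_gmm(X, α, μ₀, Σ₀, μ₁, Σ₁; iterations = 50)
    for _ in 1:iterations
        Π = responsibilities(X, α, μ₀, Σ₀, μ₁, Σ₁)
        α = mean(Π)
        μ₀, Σ₀ = weighted_gaussian(X, 1 .- Π)
        μ₁, Σ₁ = weighted_gaussian(X, Π)
    end
    (α = α, μ₀ = μ₀, Σ₀ = Σ₀, μ₁ = μ₁, Σ₁ = Σ₁)
end

# Synthetic data from the same made-up mixture used earlier in this section.
Random.seed!(123)
n = 100
Z = rand(Bernoulli(0.4), n)
𝒩₀ = MvNormal([1.0, 1.0], [2.0 1.0; 1.0 2.0])
𝒩₁ = MvNormal([3.0, 7.0], [1.5 0.0; 0.0 0.5])
X = permutedims(reduce(hcat, [z ? rand(𝒩₁) : rand(𝒩₀) for z in Z]))

# Start from the same arbitrary initial guess as above.
fit = em_gmm(X, 0.6, [3.0, 3.0], Matrix(1.0I, 2, 2), [1.0, 6.0], Matrix(1.0I, 2, 2))
```

Returning a named tuple keeps the five fitted parameters together, so they can be inspected as `fit.α`, `fit.μ₀`, and so on.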
The General EM Algorithm
Now let's consider a general Bayesian network, with some variables hidden and others observed. How can we generalize the approach we developed for Gaussian mixture models?
The first step in our Gaussian mixture model algorithm was to compute for each data point the conditional distribution of the hidden variable $Z_i$ given the observed data $X_i$. That step is already general: for any collection of $Z$'s and $X$'s, we can compute the conditional distribution of the $Z$'s given the $X$'s.
The second step is trickier to generalize: for the GMM, we used the conditional probabilities for each value of $Z_i$ as weights and chose model parameters which fit the data in a way that accounted for those weights. In the general case, we'll use the conditional distribution of the $Z$'s given the $X$'s as a probability measure with respect to which we will compute the expected log likelihood, and then choose model parameters which maximize that quantity.
More explicitly, we iterate the following steps to convergence:
1. Using current values for the parameters, work out the likelihood as a function of the observed values of the $X$'s (which we will write using lowercase $x$'s) as well as values for the $Z$'s which we'll pretend we also observed (we write these values as lowercase $z$'s).
2. We will uppercase the $z$'s to treat them as random variables, and we'll calculate the expectation of the log likelihood function with respect to the conditional distribution of the $Z$'s given the $X$'s.
3. We maximize the expected log likelihood computed in the second step and update the parameter values to these new optimizing values.
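The three steps above can be summarized in symbols. This is a sketch in standard EM notation, none of which is taken from the text: $\theta$ stands for all model parameters, $\theta^{(t)}$ for their current values, $L(\theta; x, z)$ for the likelihood from the first step, and $Q$ for the expected log likelihood.

```latex
Q\big(\theta \mid \theta^{(t)}\big)
  = \mathbb{E}_{Z \sim p(\,\cdot\, \mid x,\ \theta^{(t)})}
    \big[ \log L(\theta;\, x, Z) \big],
\qquad
\theta^{(t+1)} = \operatorname*{arg\,max}_{\theta}\; Q\big(\theta \mid \theta^{(t)}\big).
```

Note that the current parameters $\theta^{(t)}$ enter only through the conditional distribution of the $Z$'s, while the maximization is over the $\theta$ that appears inside the log likelihood.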
Let's begin by showing that this algorithm is equivalent to the one we introduced previously in the Gaussian mixture case. Suppose $Z_i \in \{0, 1\}$ with $\mathbb{P}(Z_i = 1) = \alpha$. The likelihood of $x_i$ and $z_i$ is $\alpha f_1(x_i)$ if $z_i = 1$ and $(1-\alpha) f_0(x_i)$ if $z_i = 0$ (where $f_j$ is the Gaussian density with mean $\mu_j$ and covariance $\Sigma_j$, for $j \in \{0, 1\}$).
We can write this in a single expression by saying that the likelihood of $x_i$ and $z_i$ is $[\alpha f_1(x_i)]^{z_i} [(1-\alpha) f_0(x_i)]^{1 - z_i}$. So the overall likelihood given all of the "observed" data is the product of these expressions:
$$\prod_{i=1}^{n} [\alpha f_1(x_i)]^{z_i}\, [(1-\alpha) f_0(x_i)]^{1 - z_i}.$$
Since the log of a product is a sum of logs, we get
$$\sum_{i=1}^{n} \Big[ z_i \log\big(\alpha f_1(x_i)\big) + (1 - z_i) \log\big((1-\alpha) f_0(x_i)\big) \Big]$$
for the log likelihood. Finally, taking the expectation, we can use linearity of expectation to look at one term at a time:
$$\mathbb{E}\Big[ Z_i \log\big(\alpha f_1(x_i)\big) + (1 - Z_i) \log\big((1-\alpha) f_0(x_i)\big) \Big].$$
Note that the distribution of the random variable $Z_i$ is supported at just two points (0 and 1). Therefore, we can compute this expectation by multiplying each possible value of the random variable by the probability that that value occurs.
Remember that the probability measure we decided to use for the $Z$'s is their conditional distribution given the specified $x$'s. So let's define $\pi_i$ to be the conditional probability of $Z_i = 1$ given $X_i = x_i$ (for each $i$ from 1 to $n$). Then we get an expected log likelihood of
$$\sum_{i=1}^{n} \Big[ \pi_i \log\big(\alpha f_1(x_i)\big) + (1 - \pi_i) \log\big((1-\alpha) f_0(x_i)\big) \Big].$$
We can differentiate to maximize the terms involving $\alpha$, and we find that $\alpha$ works out to be the mean of the $\pi_i$ values ($\alpha = \frac{1}{n}\sum_{i=1}^{n} \pi_i$). The remaining terms (involving $\mu_j$ and $\Sigma_j$) take the form of a weighted maximum likelihood estimation problem for the normal distribution, and the optimizing parameters for that problem are the weighted sample mean and weighted sample covariance. (We derived this result in the statistics course in the case where the weights are uniform, and here we'll take the generalization of that calculation for granted.)
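For reference, here is a sketch of the resulting update formulas, written so that they can be compared with the Julia cell from the Gaussian mixture section. The notation is assumed here: $\pi_i$ is the conditional probability that point $x_i$ belongs to component 1, and the hats mark the new parameter values.

```latex
% Setting the derivative with respect to \alpha to zero:
\frac{\partial}{\partial \alpha} \sum_{i=1}^{n}
  \big[ \pi_i \log \alpha + (1 - \pi_i) \log(1 - \alpha) \big]
  = \frac{\sum_i \pi_i}{\alpha} - \frac{\sum_i (1 - \pi_i)}{1 - \alpha} = 0
\;\Longrightarrow\;
\hat{\alpha} = \frac{1}{n} \sum_{i=1}^{n} \pi_i .

% Weighted sample mean and covariance for component 1:
\hat{\mu}_1 = \frac{\sum_i \pi_i\, x_i}{\sum_i \pi_i},
\qquad
\hat{\Sigma}_1 = \frac{\sum_i \pi_i\, (x_i - \hat{\mu}_1)(x_i - \hat{\mu}_1)^{\mathsf{T}}}{\sum_i \pi_i}.

% Component 0 uses the complementary weights:
\hat{\mu}_0 = \frac{\sum_i (1 - \pi_i)\, x_i}{\sum_i (1 - \pi_i)},
\qquad
\hat{\Sigma}_0 = \frac{\sum_i (1 - \pi_i)\, (x_i - \hat{\mu}_0)(x_i - \hat{\mu}_0)^{\mathsf{T}}}{\sum_i (1 - \pi_i)}.
```

These are the same quantities computed as `α`, `μ₀`, `μ₁`, `Σ₀`, and `Σ₁` in the iterative cell, which is why the two algorithms agree.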
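Let's now apply the general algorithm to a hidden Markov model whose hidden states $Z_1, \ldots, Z_n$ form a Markov chain on $\{0, 1\}$ and whose observations are $X_1, \ldots, X_n$.
The probability that we see $Z_1 = z_1$ is 1/2 (regardless of the value of $z_1$). Then the probability that we see $Z_2 = z_2$ given that $Z_1 = z_1$ is the transition probability from $z_1$ to $z_2$, which depends only on whether the chain stays in the same state or switches.
Likewise, we get a similar factor for $Z_3$, another for $Z_4$, and so on. Putting these factors together, we get the likelihood of the hidden path, accounting for all of the $Z$'s.
The conditional probability that we see a value of $X_1$ which is really close to $x_1$ (given the $Z$ values) is proportional to the Gaussian density with mean $z_1$ and variance $\sigma^2$ evaluated at $x_1$, and similarly for $X_2$ and so on. All together, we get the joint likelihood as the product of the transition factors and these Gaussian factors.
Taking the log turns this product into a sum.
Finally, we replace the $z$'s with $Z$'s (reflecting that we are going to treat them as random variables for purposes of computing an expected value) and take the expectation with respect to the conditional distribution of the $Z$'s given the $X$'s.
Our goal is to find new values of $q$ and $\sigma^2$ which maximize this expected log likelihood.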
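The displayed formulas for this derivation can be sketched as follows. This is a reconstruction under stated assumptions rather than the original's notation: $q$ is taken to be the probability that the chain stays in its current state (so it switches with probability $1-q$), each $X_j$ is taken to be normal with mean $Z_j$ and variance $\sigma^2$, and the labels $A$, $B$, and $C$ are placeholders chosen here.

```latex
% Joint likelihood of a hidden path z and the observations x:
L(q, \sigma^2;\, x, z)
  = \frac{1}{2}
    \prod_{j=2}^{n} q^{\mathbf{1}[z_j = z_{j-1}]}\,(1-q)^{\mathbf{1}[z_j \neq z_{j-1}]}
    \prod_{j=1}^{n} \frac{1}{\sqrt{2\pi\sigma^2}}
      \exp\!\Big(-\frac{(x_j - z_j)^2}{2\sigma^2}\Big).

% Expected log likelihood, with the z's replaced by random Z's:
\mathbb{E}\big[\log L\big]
  = \log\tfrac{1}{2} + A \log q + B \log(1-q)
    - \frac{n}{2}\log(2\pi\sigma^2) - \frac{C}{2\sigma^2},

% where the expectations are conditional on the observed x's:
A = \mathbb{E}\Big[\sum_{j=2}^{n} \mathbf{1}[Z_j = Z_{j-1}]\Big],
\quad
B = \mathbb{E}\Big[\sum_{j=2}^{n} \mathbf{1}[Z_j \neq Z_{j-1}]\Big],
\quad
C = \mathbb{E}\Big[\sum_{j=1}^{n} (x_j - Z_j)^2\Big].

% Setting the derivatives with respect to q and \sigma^2 to zero:
\hat{q} = \frac{A}{A + B},
\qquad
\hat{\sigma}^2 = \frac{C}{n}.
```

Under these assumptions the maximization itself is elementary; all of the difficulty lies in computing the three conditional expectations.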
Therefore, the only remaining task is estimating the three expected values on which the maximizing parameter values depend.
Give verbal descriptions of these three quantities.
Solution. We can describe the first as the expected number of times the Markov chain stays in the same state, the second as the expected number of Markov chain switches, and the third as the expected squared distance between the vector of $Z$'s and the vector of $X$'s.
The main difficulty in estimating these three quantities is that the conditional distribution of the $Z$'s given the $X$'s is a reasonably complex probability measure. It has to account for the observed $X$ values as well as the conditional distribution of each $Z_j$ given the value of $Z_{j-1}$.
Fortunately, we've developed a technique for sampling from complex probability measures: Metropolis-Hastings.
As a reminder, Metropolis-Hastings proceeds by starting from some point in the space we're trying to sample from. In this case, that means starting with a length-$n$ binary string. Then we propose changes to the string and accept or reject them with a probability which is determined by the density values at the current and proposed strings.
To work out the acceptance ratio, we need to compare the values of the desired density function for two paths which differ in one position.
In terms of and , work out the acceptance ratio for the proposal to move from to if .
Solution. We get
where denotes the Gaussian density with mean 0 and variance . We see that factors not involving the second position (where the change was proposed) cancel. So we're left with
Let's write down the result of this exercise in general terms. Including only the factors that don't necessarily cancel, we obtain
as the Metropolis-Hastings acceptance ratio.
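In symbols, the general ratio can be sketched as follows, for a proposal that changes only position $j$ of the current path $z$ from $z_j$ to $z_j'$. The notation is assumed here rather than taken from the text: $t(a, b)$ is the transition probability from state $a$ to state $b$, with $q$ taken to be the probability of staying in the same state, and $f_\sigma$ is the Gaussian density with mean 0 and variance $\sigma^2$.

```latex
\frac{t(z_{j-1}, z_j')\; t(z_j', z_{j+1})\; f_\sigma(x_j - z_j')}
     {t(z_{j-1}, z_j)\; t(z_j, z_{j+1})\; f_\sigma(x_j - z_j)},
\qquad
t(a, b) =
\begin{cases}
  q     & \text{if } a = b, \\
  1 - q & \text{if } a \neq b.
\end{cases}
```

At the ends of the path, the factor for the missing neighbor is simply omitted ($j = 1$ has no $z_{j-1}$, and $j = n$ has no $z_{j+1}$). Since flipping a single position is a symmetric proposal, the move is accepted with probability equal to the minimum of 1 and this ratio.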
To get more efficient mixing, we'll use a variation of Metropolis-Hastings where we cycle through the positions in order to propose changes, rather than choosing them randomly. This is called Gibbs sampling.
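To make the position-by-position sweep concrete, here is a minimal sketch of one pass through the path. It is not the `gibbs_sampler` from the course's file: the names `trans`, `weight`, and `gibbs_sweep!` are hypothetical, $q$ is assumed to be the probability of staying in the same state, and each observation is assumed to be normal with mean equal to the hidden state and variance σ². This version draws each position directly from its conditional distribution given its neighbors and its observation, so no proposal is ever rejected.

```julia
using Random

# Assumed transition model: stay with probability q, switch with probability 1 - q.
trans(a, b, q) = a == b ? q : 1 - q

# Unnormalized conditional weight of Z[j] = s given its neighbors and x[j].
function weight(Z, x, j, s, q, σ²)
    w = exp(-(x[j] - s)^2 / (2 * σ²))     # Gaussian factor for observation j
    if j > 1
        w *= trans(Z[j-1], s, q)          # transition into position j
    end
    if j < length(Z)
        w *= trans(s, Z[j+1], q)          # transition out of position j
    end
    w
end

# One systematic sweep: resample every position in order, in place.
function gibbs_sweep!(Z, x, q, σ²; rng = Random.default_rng())
    for j in eachindex(Z)
        w₀ = weight(Z, x, j, 0, q, σ²)
        w₁ = weight(Z, x, j, 1, q, σ²)
        Z[j] = rand(rng) < w₁ / (w₀ + w₁) ? 1 : 0
    end
    Z
end

# Made-up observations, starting from the all-zeros path.
x = [0.1, -0.2, 0.9, 1.1, 0.0, 1.2]
Z = gibbs_sweep!(zeros(Int, length(x)), x, 0.8, 0.05)
```

Repeating the sweep many times and recording the path after each pass gives approximate draws from the conditional distribution of the hidden path. For observations far from both 0 and 1, the weights can underflow, so a more careful version would work with log weights.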
Let's take a look at several draws from the conditional distribution of $Z$ given $X$:
include("data-gymnasia/expectation-maximization.jl")
observations = [plot(gibbs_sampler(X, (q, σ²)), yticks = 0:1) for _ in 1:10]
plot(observations..., layout = (10, 1), size = (700, 700), legend = false)
We can use these draws from the conditional distribution of $Z$ given $X$ to estimate the three expected values we need, by counting the number of times each sample path stays in the same state or switches states and by accumulating the squared difference between $X$ and each path $Z$. This method of estimating an expected value using draws from the underlying distribution is called Monte Carlo estimation.
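As a small illustration of this counting, the sketch below averages the three statistics over a list of sampled paths and then applies a parameter update. The function names `path_statistics`, `monte_carlo_statistics`, and `m_step` are hypothetical, the two paths are made up by hand rather than drawn by a sampler, and the update formulas assume that $q$ is the probability of staying in the same state and that each observation is normal with mean equal to the hidden state.

```julia
# Stays, switches, and squared distance to x for a single path z.
function path_statistics(z, x)
    stays = count(z[j] == z[j-1] for j in 2:length(z))
    switches = length(z) - 1 - stays
    sqdist = sum((x .- z) .^ 2)
    (stays, switches, sqdist)
end

# Monte Carlo estimate: average the per-path statistics over all sampled paths.
function monte_carlo_statistics(paths, x)
    totals = reduce((a, b) -> a .+ b, path_statistics(z, x) for z in paths)
    totals ./ length(paths)
end

# Assumed update: the stay fraction for q, the mean squared distance for σ².
m_step(stays, switches, sqdist, n) = (stays / (stays + switches), sqdist / n)

# Made-up observations and two hand-written paths standing in for sampler output.
x = [0.1, 0.9, 1.2, -0.2]
paths = [[0, 1, 1, 0], [0, 0, 1, 0]]

stays, switches, sqdist = monte_carlo_statistics(paths, x)
q_new, σ²_new = m_step(stays, switches, sqdist, length(x))
```

With only two paths the estimates are of course very noisy; in practice the averages would be taken over many sampled paths.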
expectation-maximization.jl also contains a method for using these draws to estimate these three quantities:
We see that these estimates make sense: the Markov chain switches about 20% of the time, so we get an value of about 80% and a value of about 20%. Likewise, the variance used to generate these data was , so the accumulated squared difference across all values of is approximately .
Finally, we can actually perform the expectation-maximization algorithm using the update rules we derived:
q, σ² = em_algorithm(X)