Variational inference: Ising models
Welcome to my first blog post! I’ll be populating this blog with (hopefully) short, bite-sized puzzles that I run into.
Today’s post is going to be about something near and dear to me: variational inference. I’ve always had a soft spot for probabilistic inference, even if it’s sometimes hard to justify due to the tension between the expressivity of the model and the computational cost of inference¹.
In this post, I will be going over the application of simple variational inference in a classical graphical model, an Ising model. I’m going to assume familiarity with variational inference and move through things fairly quickly.
Introduction
Before neural networks became the de facto method for everything, Ising models were used for tasks like semantic segmentation in images. In its simplest form, the goal of semantic segmentation is to classify each individual pixel of an image as foreground or background.
Naturally, if a pixel is in the foreground of an image, then its neighboring pixels are more likely to be foreground as well. Ising models can formalize this by assigning an affinity score to neighboring pixels: the affinity score is high if the pixels take the same value (e.g. both foreground) and low otherwise.
While Ising models are generally bad generative models² and have almost completely fallen out of favor for all forms of image modeling, we will be tackling the inference problem for fun.
Problem setup
Consider a variant of an Ising model that only models the interaction between foreground pixels. Depending on the location of the pixel, its neighboring pixels may either be encouraged or discouraged from also being in the foreground.

We can formalize this as a model over binary vectors, e.g. flattened image pixel labels $x \in \{0, 1\}^d$, where a value of $x_i = 1$ indicates that pixel $i$ is in the foreground. The joint distribution over $x$ is given by:

$$p(x) = \frac{\exp(x^\top W x)}{Z(W)},$$

where the partition function is $Z(W) = \sum_{x \in \{0,1\}^d} \exp(x^\top W x)$.³ The affinity matrix $W \in \mathbb{R}^{d \times d}$ determines the influence of pixel $i$ on pixel $j$ via the entry $W_{ij}$.

Our goal is to approximate the log partition (cumulant) function, $A(W) = \log Z(W)$. For notational convenience, we denote the potential function $\theta(x) = x^\top W x$, yielding

$$p(x) = \exp(\theta(x) - A(W)).$$
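To make the affinity matrix concrete, here’s a minimal sketch of one way to build $W$ for a small image grid, where each pair of adjacent pixels gets a constant positive affinity. The grid_affinity helper and the coupling value are my own illustrative choices, not part of the model definition above.

import torch

def grid_affinity(height, width, coupling=1.0):
    """Affinity matrix for a height x width pixel grid.

    Each pair of 4-neighbors gets a symmetric affinity of `coupling`;
    all other entries, including the diagonal, are left at zero.
    """
    dim = height * width
    W = torch.zeros(dim, dim)

    def idx(r, c):
        # Flatten (row, col) into an index of the flattened pixel vector.
        return r * width + c

    for r in range(height):
        for c in range(width):
            if c + 1 < width:   # right neighbor
                W[idx(r, c), idx(r, c + 1)] = coupling
                W[idx(r, c + 1), idx(r, c)] = coupling
            if r + 1 < height:  # down neighbor
                W[idx(r, c), idx(r + 1, c)] = coupling
                W[idx(r + 1, c), idx(r, c)] = coupling
    return W

# 3x3 image, positive affinity between adjacent pixels.
W = grid_affinity(3, 3, coupling=0.5)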
Variational lower bound
We can lower bound the cumulant function $A(W)$ by starting with the KL divergence between a variational distribution $q(x)$ and $p(x)$:

$$\mathrm{KL}(q \,\|\, p) = \mathbb{E}_{q(x)}\left[\log q(x) - \log p(x)\right] \geq 0.$$

Rearranging, we get

$$A(W) \geq \mathbb{E}_{q(x)}[\theta(x)] + H(q)$$

by Gibbs’ inequality, where $H(q) = -\mathbb{E}_{q(x)}[\log q(x)]$ is the entropy of $q$.
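Spelling out the intermediate algebra, using $\log p(x) = \theta(x) - A(W)$ from the problem setup:

$$0 \leq \mathrm{KL}(q \,\|\, p) = \mathbb{E}_{q(x)}\left[\log q(x) - \theta(x) + A(W)\right] = A(W) - H(q) - \mathbb{E}_{q(x)}[\theta(x)],$$

and moving $H(q)$ and $\mathbb{E}_{q(x)}[\theta(x)]$ to the other side of the inequality gives the bound above.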
Mean parameterization
We assume the variational distribution is fully factored:

$$q(x) = \prod_{i=1}^d q_i(x_i),$$

with each $q_i(x_i) = \mathrm{Bernoulli}(x_i; \mu_i)$. Our goal in this section is to rewrite the lower bound in terms of the variational mean parameters $\mu_i = \mathbb{E}_{q_i}[x_i]$. Writing down the bound in terms of the mean parameters will allow us to easily implement things in code.

The lower bound is the sum of the entropy $H(q)$ and the expected potential $\mathbb{E}_{q(x)}[\theta(x)]$.
The entropy can be expressed as

$$H(q) = -\sum_{i=1}^d \left[ \mu_i \log \mu_i + (1 - \mu_i) \log (1 - \mu_i) \right].$$

The expected potential can be expressed as

$$\mathbb{E}_{q(x)}[\theta(x)] = \mathbb{E}_{q(x)}\left[ \sum_{i,j} W_{ij}\, x_i x_j \right] = \sum_{i,j} W_{ij}\, \mathbb{E}_{q(x)}[x_i x_j].$$

Note that $q$ factorizes, so $\mathbb{E}_{q(x)}[x_i x_j]$ splits into a product of marginal expectations, which is $\mu_i \mu_j$ if $i \neq j$ and $\mu_i$ if $i = j$ (since $x_i^2 = x_i$ for binary $x_i$). Therefore, we have

$$\mathbb{E}_{q(x)}[\theta(x)] = \sum_{i \neq j} W_{ij}\, \mu_i \mu_j + \sum_i W_{ii}\, \mu_i.$$
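As a quick numerical sanity check of this expression, here is a sketch that compares the closed form against a brute-force expectation over all $2^d$ configurations. The helper names, the random $W$, and the random $\mu$ are purely illustrative, and the brute-force sum is only feasible for tiny $d$.

import itertools
import torch

def expected_potential_closed_form(mu, W):
    # sum_{i != j} W_ij mu_i mu_j + sum_i W_ii mu_i
    quadratic = torch.einsum("i,j,ij->", mu, mu, W)
    diag = W.diag()
    return quadratic - (mu**2 * diag).sum() + (mu * diag).sum()

def expected_potential_brute_force(mu, W):
    # E_q[x^T W x] by enumerating all binary x, weighted by q(x) = prod_i Bernoulli(x_i; mu_i).
    d = mu.shape[0]
    total = torch.tensor(0.0)
    for bits in itertools.product([0.0, 1.0], repeat=d):
        x = torch.tensor(bits)
        q_x = torch.prod(mu**x * (1 - mu) ** (1 - x))
        total = total + q_x * (x @ W @ x)
    return total

torch.manual_seed(0)
d = 5
W = torch.randn(d, d)
mu = torch.rand(d)
print(expected_potential_closed_form(mu, W))  # the two should match up to float error
print(expected_potential_brute_force(mu, W))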
Implementation
The implementation of these two terms is relatively straightforward in torch. One thing we have to be careful of is that the mean parameters of $q$, the $\mu_i$, are constrained to lie in $[0, 1]$. We can achieve that by keeping unconstrained parameters and projecting them to the correct space by applying a sigmoid function every time they are needed.
import torch

class InferenceNetwork(torch.nn.Module):
    """Fully factored inference network for the Ising model.

    Parameterizes a vector of Bernoulli means, mu, independently.
    """

    def __init__(self, W):
        super().__init__()
        dim = W.shape[0]
        self.W = W
        # Unconstrained parameters; mapped to means in (0, 1) via sigmoid.
        self.means = torch.nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def entropy(self):
        # H(q) = -sum_i [mu_i log mu_i + (1 - mu_i) log(1 - mu_i)]
        mu = self.means.sigmoid()
        complement = 1 - mu
        return -(
            (mu * mu.log()).sum()
            + (complement * complement.log()).sum()
        )

    def expected_potential(self):
        # E_q[x^T W x] = sum_{i != j} W_ij mu_i mu_j + sum_i W_ii mu_i
        mu = self.means.sigmoid()
        quadratic = torch.einsum("i,j,ij->", mu, mu, self.W)
        mean = torch.einsum("i,i->", mu, self.W.diag())
        bias = torch.einsum("i,i->", mu**2, self.W.diag())
        return quadratic + mean - bias

    def lowerbound(self):
        # ELBO on the log partition function: H(q) + E_q[theta(x)] <= A(W)
        return self.entropy() + self.expected_potential()
Training
We can try to get the best $q$, i.e. the closest approximation of $p$ in our hypothesis class, by directly optimizing the lower bound:
def fit(model, num_steps=100, lr=1e-2):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for step in range(num_steps):
        optimizer.zero_grad()
        # Maximize the lower bound by minimizing its negation.
        loss = -model.lowerbound()
        loss.backward()
        optimizer.step()
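To see everything end to end, here’s one possible usage sketch on a tiny model, small enough that the exact log partition function can be computed by enumerating all $2^d$ configurations. The exact_log_partition helper, the random $W$, and the hyperparameters are my own illustrative choices.

import itertools
import torch

def exact_log_partition(W):
    # A(W) = log sum_x exp(x^T W x), by brute force over all binary x.
    d = W.shape[0]
    scores = [
        torch.tensor(bits) @ W @ torch.tensor(bits)
        for bits in itertools.product([0.0, 1.0], repeat=d)
    ]
    return torch.logsumexp(torch.stack(scores), dim=0)

torch.manual_seed(0)
d = 8
W = torch.randn(d, d)

model = InferenceNetwork(W)
fit(model, num_steps=500, lr=1e-1)

print("variational lower bound:", model.lowerbound().item())
print("exact log partition:   ", exact_log_partition(W).item())

If everything is working, the printed lower bound should come out at or below the exact value.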
Conclusion (for now)
And that’s most of it! You can check out the full code here.
Footnotes
1. In the past, I’ve tried to scale probabilistic models like hidden Markov models, with limited success. The bigger the model, the more expensive it is to train! This is kind of obvious, but the tradeoff seems to be worse for models that maintain explicit representations of uncertainty than for those that do not, e.g. neural networks.

2. I posit that this is due to the inability to model long-range dependencies or low-frequency features. However, Ising models could potentially model local dependencies / high-frequency features well. It’s possible that combining Ising models with diffusion models may lead to a nice mix of capabilities.

3. We try to mostly use the notation from the “Monster” paper by Wainwright and Jordan.