Variational inference: Ising models


Welcome to my first blog post! I’ll be populating this blog with (hopefully) short, bite-sized puzzles that I run into.

Today’s post is going to be about something near and dear to me: variational inference. I’ve always had a soft spot for probabilistic inference, even if it’s sometimes hard to justify due to the tension between expressivity of the model and the computational cost of inference1.

In this post, I will be going over the application of simple variational inference in a classical graphical model, an Ising model. I’m going to assume familiarity with variational inference and move through things fairly quickly.

Introduction

Before neural networks became the de facto method for everything, Ising models were used for tasks like semantic segmentation in images. In its simplest form, the goal of semantic segmentation is to classify each individual pixel of an image as foreground or background.

Naturally, if a pixel is in the foreground of an image, then its neighboring pixels are more likely to be foreground as well. Ising models can formalize this by assigning an affinity score to neighboring pixels: the affinity score is high if the pixels take the same value (e.g. foreground) and low otherwise.

While Ising models are generally bad generative models2 and have almost completely fallen out of favor for all forms of image modeling, we will be tackling the inference problem for fun.

Problem setup

Consider a variant of an Ising model that only models the interaction between foreground pixels. Depending on the location of a pixel, its neighboring pixels may either be encouraged or discouraged from also being in the foreground.

We can formalize this as a model over binary vectors, e.g. flattened image pixel labels $x \in \{0,1\}^n$, where a value of 1 indicates that $x_i$ is in the foreground. The joint distribution over $x$ is given by:

\begin{equation} p(x) = \frac{\exp(-x^T W x)}{Z}, \end{equation}

where the partition function is $Z = \sum_x \exp(-x^T W x)$3. The affinity matrix $W \in \mathbb{R}^{n \times n}$ determines the influence of pixel $x_i$ on pixel $x_j$.

Our goal is to approximate the log partition (cumulant) function, $\log Z$. For notational convenience, we denote the potential function $\phi(x) = -x^T W x$, yielding

\begin{equation} \log p(x) = \phi(x) - \log Z. \end{equation}
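
Since the state space has only $2^n$ configurations, $\log Z$ can be computed exactly by brute force for tiny models, which gives us a ground truth to compare against later. Here is a minimal sketch (the helper name brute_force_log_z is mine, not from the post's code):

import itertools

import torch

def brute_force_log_z(W):
    """ Exact log partition function by enumerating all 2^n binary states.
        Only feasible for small n.
    """
    n = W.shape[0]
    states = torch.tensor(list(itertools.product([0, 1], repeat=n)), dtype=W.dtype)
    # phi(x) = -x^T W x for every configuration x.
    potentials = -torch.einsum("bi,ij,bj->b", states, W, states)
    # log Z = logsumexp of phi over all configurations.
    return torch.logsumexp(potentials, dim=0)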

Variational lower bound

We can lower bound the cumulant function by starting with the KL divergence between a variational distribution $q$ and $p$:

\begin{align} KL[q||p] &= E_{q(x)}[\log q(x) - \log p(x)]\\ &= E_{q(x)}[\log q(x)] - E_{q(x)}[\phi(x) - \log Z]\\ &= -H[q] - E_{q(x)}[\phi(x)] + \log Z \end{align}

Rearranging, we get

\begin{align} \log Z &= H[q] + E_{q(x)}[\phi(x)] + KL[q||p]\\ &\ge H[q] + E_{q(x)}[\phi(x)] = \mathcal{L} \end{align}

by Gibbs' inequality, since the KL divergence is non-negative. Maximizing $\mathcal{L}$ over $q$ therefore tightens the bound by driving $KL[q||p]$ toward zero.

Mean parameterization

We assume the variational distribution is fully factored:

\begin{equation} \log q(x) = \sum_i \log q_i(x_i), \end{equation}

with each $q_i(x_i) = \text{Bernoulli}(\mu_i)$. Our goal in this section is to rewrite the lower bound in terms of the variational mean parameters $\mu = (\mu_1,\ldots,\mu_n)$. Writing down the bound in terms of the mean parameters will allow us to easily implement things in code.

The lower bound $\mathcal{L}$ is the sum of the entropy $H[q]$ and the expected potentials $E_{q}[\phi(x)]$.

The entropy can be expressed as

\begin{equation} H[q] = -\sum_i \mu_i\log\mu_i - \sum_i(1-\mu_i)\log(1-\mu_i). \end{equation}
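
As a quick sanity check (a throwaway sketch with an arbitrary $\mu$), this matches the Bernoulli entropy from torch.distributions:

import torch

mu = torch.tensor([0.1, 0.5, 0.9])
manual = -(mu * mu.log() + (1 - mu) * (1 - mu).log()).sum()
builtin = torch.distributions.Bernoulli(probs=mu).entropy().sum()
assert torch.allclose(manual, builtin)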

The expected potentials can be expressed as

\begin{align} E_{q(x)}[\phi(x)] &= E_{q}[-\sum_{i,j} x_i x_j W_{ij}]\\ &= -\sum_{i,j} W_{ij} E_{q(x_i,x_j)}[x_i x_j]. \end{align}

Note that $E_{q(x_i,x_j)}[x_i x_j] = E_{q(x_i)}[x_i\, E_{q(x_j|x_i)}[x_j]]$, which is $\mu_i\mu_j$ if $i \ne j$ (since $q$ factorizes) and $\mu_i$ if $i=j$ (since $x_i^2 = x_i$ for binary $x_i$). Therefore, we have

\begin{align} E_{q(x)}[\phi(x)] &= -\sum_{i\ne j}\mu_i\mu_j W_{ij} - \sum_{i}\mu_i W_{ii}\\ &= -\sum_{i,j} \mu_i\mu_j W_{ij} - \sum_{i} \mu_i W_{ii} + \sum_{i} \mu_i^2 W_{ii}. \end{align}
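
If you want to double-check the algebra, a Monte Carlo estimate under $q$ should agree with this closed form (another throwaway sketch, with made-up values for $W$ and $\mu$):

import torch

torch.manual_seed(0)
n = 4
W = torch.randn(n, n)
mu = torch.rand(n)

# Closed form from above.
closed_form = (
    -torch.einsum("i,j,ij->", mu, mu, W)
    - (mu * W.diag()).sum()
    + (mu**2 * W.diag()).sum()
)

# Monte Carlo estimate of E_q[phi(x)] with x ~ q.
x = torch.distributions.Bernoulli(probs=mu).sample((100_000,))
mc = -torch.einsum("bi,ij,bj->b", x, W, x).mean()
# closed_form and mc should agree up to Monte Carlo error.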

Implementation

The implementation of these two terms is relatively straightforward in torch. One thing we have to be careful of is that the mean parameters of $q$, the $\mu_i$, are constrained to lie in $[0,1]$. We can achieve this by keeping unconstrained parameters and projecting them into the correct space with a sigmoid every time they are needed.

import torch

class InferenceNetwork(torch.nn.Module):
    """ Fully factored inference network for Ising model.
        Parameterizes vector of Bernoulli means, mu, independently.
    """
    def __init__(self, W):
        super().__init__()
        dim = W.shape[0]
        self.W = W
        self.means = torch.nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def entropy(self):
        mu = self.means.sigmoid()
        complement = 1 - mu
        return -(
            (mu * mu.log()).sum()
            + (complement * complement.log()).sum()
        )

    def expected_potential(self):
        mu = self.means.sigmoid()
        # sum_{i,j} mu_i mu_j W_ij (negated on return).
        quadratic = torch.einsum("i,j,ij->", mu, mu, self.W)
        # Diagonal correction: E[x_i^2] = mu_i rather than mu_i^2,
        # so swap mu_i^2 W_ii out for mu_i W_ii.
        mean = torch.einsum("i,i->", mu, self.W.diag())
        bias = torch.einsum("i,i->", mu**2, self.W.diag())
        return -quadratic - (mean - bias)

    def lowerbound(self):
        return self.entropy() + self.expected_potential()

Training

We can find the best $q^*$ in our hypothesis class, i.e. the one whose bound most closely approximates $\log Z$, by directly optimizing the lower bound:

def fit(model, num_steps=100, lr=1e-2):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for step in range(num_steps):
        optimizer.zero_grad()
        loss = -model.lowerbound()  # maximize L by minimizing -L
        loss.backward()
        optimizer.step()
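
Putting it together on a tiny random model, we can compare the optimized bound against the exact value from the hypothetical brute_force_log_z helper sketched earlier (the exact numbers will depend on $W$ and the seed):

torch.manual_seed(0)
W = torch.randn(8, 8)

model = InferenceNetwork(W)
fit(model, num_steps=1000)

# The optimized bound should sit at or below the exact log Z.
print("lower bound:", model.lowerbound().item())
print("exact log Z:", brute_force_log_z(W).item())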

Conclusion (for now)

And that’s most of it! You can check out the full code here.


Footnotes

  1. In the past, I’ve tried to scale probabilistic models like hidden Markov models with limited success. The bigger the model, the more expensive it is to train! This is kind of obvious, but the tradeoff seems to be worse for models that maintain explicit representations of uncertainty than for those that do not, e.g. neural networks.

  2. I posit that this is due to the inability to model long-range dependencies or low-frequency features. However, Ising models could potentially model local dependencies / high-frequency features well. It’s possible that combining Ising models with diffusion models may lead to a nice mix of capabilities.

  3. We mostly try to follow the notation from the “Monster” paper by Wainwright and Jordan.