
# Graphical Models in GRMM

This document describes how to construct and perform inference on graphical models in GRMM. It will walk you through creating a model object, adding variables, calling an inference algorithm, and querying the resulting marginals.

If you want to train the parameters of the model, i.e., you want to learn a CRF, then you do not want to use this interface to GRMM. CRF training is handled in a different manner entirely. (This is because CRFs almost always have tied parameters, so the trainer needs to understand which parameters correspond to which factors in the model.) However, the CRF trainer does use this interface under the hood, so this document could be useful if you're trying to understand the CRF code.

An example of creating and performing inference is given in the class SimpleGraphExample, located in edu/umass/cs/mallet/grmm/examples/SimpleGraphExample.java. This document will go through that class a line at a time, explaining the API as we go.

There are three basic steps:

1. Set up a factor graph.
2. Call an inferencer.
3. Collect the results.

## STEP 1: Set up a factor graph

The main object for representing a graphical model is called FactorGraph. A factor graph knows about a set of Variable objects, each of which represents one random variable, and a set of Factor objects, which define the model.

We can create an array of Variables easily:

```
Variable[] allVars = {
  new Variable (2),
  new Variable (2),
  new Variable (2),
  new Variable (2)
};
```

The 2 in the constructor specifies the number of outcomes of each variable. Here we have chosen to use binary variables. If you wish to specify the names of the outcomes of the random variables (this is useful for data input and printing), you can do this by passing in an Alphabet object into the Variable constructor. Note that random variables with continuous outcomes are not allowed in GRMM.
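For instance, here is a hedged sketch of naming the outcomes. The LabelAlphabet class, its lookupIndex method, and the Variable(LabelAlphabet) constructor shown here are assumptions to verify against your GRMM version:

```java
// Hypothetical sketch: name a binary variable's outcomes via an alphabet.
// Check LabelAlphabet and this Variable constructor in your GRMM version.
LabelAlphabet outcomes = new LabelAlphabet ();
outcomes.lookupIndex ("false");
outcomes.lookupIndex ("true");
Variable v = new Variable (outcomes);  // a 2-outcome variable with named outcomes
```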

Now we can create the FactorGraph:

```
FactorGraph mdl = new FactorGraph (allVars);
```

and populate it with Factors. Factor is an abstract class, but the most useful subclasses are those that allow arbitrary tables: TableFactor and LogTableFactor. Creating the TableFactors is probably the most complicated part of specifying a model, because a table over k random variables is essentially a multidimensional array with k dimensions, and you have to make sure to get the order of the array indices right.
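To make the index-ordering issue concrete, here is a minimal, GRMM-free sketch of how a table over two variables flattens into a double[]. The row-major convention shown (first variable most significant) is an assumption to verify against your GRMM version:

```java
// Minimal, GRMM-free illustration of a flat indexing convention
// (assumed row-major: first variable most significant -- verify
// against your GRMM version).
class FlatIndexDemo {
  static int flatIndex (int outcome1, int outcome2, int numOutcomes2) {
    return outcome1 * numOutcomes2 + outcome2;
  }

  public static void main (String[] args) {
    // For two binary variables, entries land at:
    //   (0,0) -> 0,  (0,1) -> 1,  (1,0) -> 2,  (1,1) -> 3
    System.out.println (flatIndex (1, 0, 2));  // prints 2
  }
}
```

Getting this order wrong silently transposes your potentials, which is why the Assignment-based API described below is safer for larger models.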

Here is an example of creating four pairwise factors with random values.

```
// Create a diamond graph, with random potentials
Random r = new Random (42);
for (int i = 0; i < allVars.length; i++) {
  double[] ptlarr = new double [4];
  for (int j = 0; j < ptlarr.length; j++)
    ptlarr[j] = Math.abs (r.nextDouble ());

  Variable v1 = allVars[i];
  Variable v2 = allVars[(i + 1) % allVars.length];
  mdl.addFactor (v1, v2, ptlarr);
}
```

The key step in the above is the call to mdl.addFactor, which takes two Variables and a double[], creates a TableFactor behind the scenes, and adds it to the model.

If you have a large number of variables, getting the index order right by hand is error-prone. Rather than creating a double[] directly, you probably want to create a TableFactor and then call its assignmentIterator method. This gives you a series of Assignment objects, one for each element in the factor's table. An Assignment is a mapping from variables to outcomes, so you don't have to worry about maintaining a 1-D representation of a multidimensional array.
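Here is a hedged sketch of that pattern. The TableFactor constructor, the assignment() and setValue calls, and Assignment.get(Variable) are assumptions about the GRMM API; check them against the version you are using. The potential values (2.0 / 1.0) are arbitrary illustration, favoring configurations where the two variables agree:

```java
// Sketch: fill a pairwise factor entry-by-entry using Assignments,
// instead of hand-computing positions in a flat double[].
TableFactor ptl = new TableFactor (new Variable[] { v1, v2 });
for (AssignmentIterator it = ptl.assignmentIterator (); it.hasNext (); it.advance ()) {
  Assignment assn = it.assignment ();
  // Hypothetical agreement potential: 2.0 if v1 and v2 match, else 1.0
  double val = (assn.get (v1) == assn.get (v2)) ? 2.0 : 1.0;
  ptl.setValue (it, val);
}
mdl.addFactor (ptl);
```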

## STEP 2: Call an inferencer

GRMM contains support for several different inference algorithms, both exact and approximate. They all implement the Inferencer interface, so switching between them is straightforward. The most important inferencers are JunctionTreeInferencer (exact junction tree algorithm), LoopyBP (approximate "loopy" belief propagation), and GibbsSampler (approximate Gibbs sampling). All of these are used in roughly the same way:

```
Inferencer inf = new JunctionTreeInferencer ();
inf.computeMarginals (mdl);
```

This method runs the inferencer to compute the marginal distribution for each variable, and stores the results in the inferencer.

If your model is large, computeMarginals can take a long time. Exact inference in graphical models is exponential in the treewidth of the model, so on a large model JunctionTreeInferencer is likely to run out of memory. In that case, you need to consider approximate inference.

Also, be aware that even with pairwise factors only, inference time and space are quadratic in the number of outcomes a variable can have. So if your variables can have many outcomes (in the thousands, say), inference can also be slow. In that case, you may want to consider beam search methods, which are not implemented in GRMM.
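A GRMM-free back-of-the-envelope check of where the quadratic cost comes from: a pairwise table over two variables with k outcomes each stores k*k entries, so every message computed over that factor touches k*k numbers.

```java
// GRMM-free arithmetic sketch: entries in one pairwise factor's table.
class PairwiseCostDemo {
  static long pairwiseTableEntries (int k) {
    return (long) k * k;
  }

  public static void main (String[] args) {
    System.out.println (pairwiseTableEntries (2));     // prints 4
    System.out.println (pairwiseTableEntries (1000));  // prints 1000000
  }
}
```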

## STEP 3: Collect the marginals

Once you've run computeMarginals on an Inferencer, you can collect the marginals by calling the inferencer's lookupMarginal method. This returns a Factor object that you can then query to find out what the marginals were. The easiest way to query a Factor object is using its assignmentIterator method. Here's how you can do that:

```
for (int varnum = 0; varnum < allVars.length; varnum++) {
  Variable var = allVars[varnum];
  Factor ptl = inf.lookupMarginal (var);
  for (AssignmentIterator it = ptl.assignmentIterator (); it.hasNext (); it.advance ()) {
    int outcome = it.indexOfCurrentAssn ();
    System.out.println (var + "  " + outcome + "   " + ptl.value (it));
  }
  System.out.println ();
}
```

Note that this only looks at single-variable marginals. Most inferencers compute marginals over pairs of variables as well. To get these, pick one of the Factor objects that lives in your factor graph, and pass it to the inferencer's lookupMarginal(Factor) method.
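For example, a hedged sketch; the factorsIterator accessor and the lookupMarginal(Factor) overload used here are assumptions to verify against your GRMM version:

```java
// Sketch: query the pairwise marginal for one factor of the model.
// Assumes inf.computeMarginals (mdl) has already been run.
Factor edgeFactor = (Factor) mdl.factorsIterator ().next ();
Factor pairwiseMarginal = inf.lookupMarginal (edgeFactor);
```

The returned Factor can then be queried with assignmentIterator exactly as in the single-variable loop above.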