cc.mallet.types
Class Dirichlet

java.lang.Object
  extended by cc.mallet.types.Dirichlet

public class Dirichlet
extends java.lang.Object

Various useful functions related to Dirichlet distributions.

Author:
Andrew McCallum and David Mimno

Nested Class Summary
static class Dirichlet.Estimator
           
static class Dirichlet.MethodOfMomentsEstimator
           
 
Field Summary
static double DIGAMMA_COEF_1
           
static double DIGAMMA_COEF_10
           
static double DIGAMMA_COEF_2
           
static double DIGAMMA_COEF_3
           
static double DIGAMMA_COEF_4
           
static double DIGAMMA_COEF_5
           
static double DIGAMMA_COEF_6
           
static double DIGAMMA_COEF_7
           
static double DIGAMMA_COEF_8
           
static double DIGAMMA_COEF_9
           
static double DIGAMMA_LARGE
           
static double DIGAMMA_SMALL
           
static double EULER_MASCHERONI
          Actually the negative Euler-Mascheroni constant
static double HALF_LOG_TWO_PI
           
static double PI_SQUARED_OVER_SIX
           
 
Constructor Summary
Dirichlet(Alphabet dict)
          A symmetric Dirichlet with alpha_i = 1.0 and the number of dimensions of the given alphabet.
Dirichlet(Alphabet dict, double alpha)
          A symmetric Dirichlet with alpha_i = alpha and the number of dimensions of the given alphabet.
Dirichlet(double[] p)
          A dirichlet parameterized with a single vector of positive reals
Dirichlet(double[] alphas, Alphabet dict)
          Constructor that takes an alphabet representing the meaning of each dimension
Dirichlet(double m, double[] p)
          A dirichlet parameterized by a distribution and a magnitude
Dirichlet(int size)
          A symmetric Dirichlet with alpha_i = 1.0 and size dimensions
Dirichlet(int size, double alpha)
          A symmetric dirichlet: E(X_i) = E(X_j) for all i, j
 
Method Summary
 double absoluteDifference(Dirichlet other)
          Compute the L1 residual between two dirichlets
 double alpha(int featureIndex)
           
 void checkBreakeven(double x)
           
static java.lang.String compare(double sum, int k, int n, int w)
           
static double digamma(double z)
          Calculate digamma using an asymptotic expansion involving Bernoulli numbers.
static double digammaDifference(double x, int n)
           
 double dirichletMultinomialLikelihoodRatio(int[] countsX, int[] countsY)
          This version uses a non-symmetric Dirichlet prior
static double dirichletMultinomialLikelihoodRatio(int[] countsX, int[] countsY, double alpha, double alphaSum)
          What is the probability that these two observations were drawn from the same multinomial with symmetric Dirichlet prior alpha, relative to the probability that they were drawn from different multinomials both drawn from this Dirichlet?
static double dirichletMultinomialLikelihoodRatio(gnu.trove.TIntIntHashMap countsX, gnu.trove.TIntIntHashMap countsY, double alpha, double alphaSum)
          What is the probability that these two observations were drawn from the same multinomial with symmetric Dirichlet prior alpha, relative to the probability that they were drawn from different multinomials both drawn from this Dirichlet?
static java.lang.String distributionToString(double magnitude, double[] distribution)
          Create a printable list of alpha_i parameters
 int[] drawObservation(int n)
          Dirichlet-multinomial: draw a distribution from the dirichlet, then draw n samples from that multinomial.
 int[] drawObservation(int n, double[] distribution)
          Draw a count vector from the probability distribution provided.
 java.lang.Object[] drawObservations(int d, int n)
          Create a set of d draws from a dirichlet-multinomial, each with an average of n observations.
static double ewensLikelihoodRatio(int[] countsX, int[] countsY, double lambda)
          Similar to the Dirichlet-multinomial tests, this is a likelihood ratio based on the Ewens Sampling Formula, which can be considered the distribution of partitions of integers generated by the Chinese restaurant process.
 Alphabet getAlphabet()
           
static double learnParameters(double[] parameters, int[][] observations, int[] observationLengths)
          Learn Dirichlet parameters using frequency histograms
static double learnParameters(double[] parameters, int[][] observations, int[] observationLengths, double shape, double scale, int numIterations)
          Learn Dirichlet parameters using frequency histograms
 long learnParametersWithDigamma(int[][] binCounts, int[] observationLengths)
           
 long learnParametersWithDigamma(java.lang.Object[] observations)
          Use the fixed point iteration described by Tom Minka.
 long learnParametersWithHistogram(int[][] binCountHistograms, int[] lengthHistogram)
           
 long learnParametersWithHistogram(java.lang.Object[] observations)
          Use the fixed point iteration described by Tom Minka.
 long learnParametersWithLeaveOneOut(int[][] binCounts, int[] observationLengths)
          Learn parameters using Minka's Leave-One-Out (LOO) likelihood
 long learnParametersWithLeaveOneOut(java.lang.Object[] observations)
           
 long learnParametersWithMoments(java.lang.Object[] observations)
          Estimate a dirichlet with the moment matching method described by Ronning.
static double learnSymmetricConcentration(int[] countHistogram, int[] observationLengths, int numDimensions, double currentValue)
          Learn the concentration parameter of a symmetric Dirichlet using frequency histograms.
static double logGamma(double z)
          Currently aliased to logGammaStirling
static double logGammaDefinition(double z)
          This calculates a log gamma function exactly.
static double logGammaDifference(double z, int n)
          This directly calculates the difference between two log gamma functions using a recursive formula.
static double logGammaNemes(double z)
          Gergo Nemes' approximation
static double logGammaStirling(double z)
          Use a fifth order Stirling's approximation.
static void main(java.lang.String[] args)
           
 double[] nextDistribution()
           
 void print()
           
 Dirichlet randomDirichlet(Randoms r, double averageAlpha)
           
 FeatureSequence randomFeatureSequence(Randoms r, int length)
           
 FeatureVector randomFeatureVector(Randoms r, int size)
           
 Multinomial randomMultinomial(Randoms r)
           
protected  double[] randomRawMultinomial(Randoms r)
           
 TokenSequence randomTokenSequence(Randoms r, int length)
           
 double[] randomVector(Randoms r)
           
static void runComparison()
           
 int size()
           
 double squaredDifference(Dirichlet other)
          Compute the L2 residual between two dirichlets
static void testSymmetricConcentration(int numDimensions, int numObservations, int observationMeanLength)
           
 void toFile(java.lang.String filename)
          Write the parameters alpha_i to the specified file, one per line
static double trigamma(double z)
           
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

EULER_MASCHERONI

public static final double EULER_MASCHERONI
Actually the negative Euler-Mascheroni constant

See Also:
Constant Field Values

PI_SQUARED_OVER_SIX

public static final double PI_SQUARED_OVER_SIX
See Also:
Constant Field Values

HALF_LOG_TWO_PI

public static final double HALF_LOG_TWO_PI

DIGAMMA_COEF_1

public static final double DIGAMMA_COEF_1
See Also:
Constant Field Values

DIGAMMA_COEF_2

public static final double DIGAMMA_COEF_2
See Also:
Constant Field Values

DIGAMMA_COEF_3

public static final double DIGAMMA_COEF_3
See Also:
Constant Field Values

DIGAMMA_COEF_4

public static final double DIGAMMA_COEF_4
See Also:
Constant Field Values

DIGAMMA_COEF_5

public static final double DIGAMMA_COEF_5
See Also:
Constant Field Values

DIGAMMA_COEF_6

public static final double DIGAMMA_COEF_6
See Also:
Constant Field Values

DIGAMMA_COEF_7

public static final double DIGAMMA_COEF_7
See Also:
Constant Field Values

DIGAMMA_COEF_8

public static final double DIGAMMA_COEF_8
See Also:
Constant Field Values

DIGAMMA_COEF_9

public static final double DIGAMMA_COEF_9
See Also:
Constant Field Values

DIGAMMA_COEF_10

public static final double DIGAMMA_COEF_10
See Also:
Constant Field Values

DIGAMMA_LARGE

public static final double DIGAMMA_LARGE
See Also:
Constant Field Values

DIGAMMA_SMALL

public static final double DIGAMMA_SMALL
See Also:
Constant Field Values
Constructor Detail

Dirichlet

public Dirichlet(double m,
                 double[] p)
A dirichlet parameterized by a distribution and a magnitude

Parameters:
m - The magnitude of the Dirichlet: sum_i alpha_i
p - A probability distribution: p_i = alpha_i / m
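
The relationship between the two parameterizations can be sketched as follows. This is an illustration only, not MALLET's code: the class MagnitudeSketch and its methods are hypothetical names, and the sketch assumes only the identity stated above, alpha_i = m * p_i.

```java
// Illustrative sketch only; not MALLET's implementation.
final class MagnitudeSketch {
    /** Recover alpha_i = m * p_i from a magnitude m and a distribution p. */
    static double[] toAlphas(double m, double[] p) {
        double[] alphas = new double[p.length];
        for (int i = 0; i < p.length; i++) {
            alphas[i] = m * p[i];
        }
        return alphas;
    }

    /** Inverse direction: the magnitude is the sum of the alphas. */
    static double magnitude(double[] alphas) {
        double sum = 0.0;
        for (double a : alphas) {
            sum += a;
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] alphas = toAlphas(4.0, new double[] {0.5, 0.25, 0.25});
        System.out.println(java.util.Arrays.toString(alphas));
        System.out.println(magnitude(alphas));
    }
}
```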

Dirichlet

public Dirichlet(double[] p)
A dirichlet parameterized with a single vector of positive reals


Dirichlet

public Dirichlet(double[] alphas,
                 Alphabet dict)
Constructor that takes an alphabet representing the meaning of each dimension


Dirichlet

public Dirichlet(Alphabet dict)
A symmetric Dirichlet with alpha_i = 1.0 and the number of dimensions of the given alphabet.


Dirichlet

public Dirichlet(Alphabet dict,
                 double alpha)
A symmetric Dirichlet with alpha_i = alpha and the number of dimensions of the given alphabet.


Dirichlet

public Dirichlet(int size)
A symmetric Dirichlet with alpha_i = 1.0 and size dimensions


Dirichlet

public Dirichlet(int size,
                 double alpha)
A symmetric dirichlet: E(X_i) = E(X_j) for all i, j

Parameters:
size - The number of dimensions
alpha - The parameter for each dimension
Method Detail

nextDistribution

public double[] nextDistribution()

distributionToString

public static java.lang.String distributionToString(double magnitude,
                                                    double[] distribution)
Create a printable list of alpha_i parameters


toFile

public void toFile(java.lang.String filename)
            throws java.io.IOException
Write the parameters alpha_i to the specified file, one per line

Throws:
java.io.IOException

drawObservation

public int[] drawObservation(int n)
Dirichlet-multinomial: draw a distribution from the dirichlet, then draw n samples from that multinomial.
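
The two-stage draw described above can be sketched in plain Java. This is not MALLET's implementation: DirichletMultinomialSketch and its static helpers are hypothetical, the Gamma sampler is the Marsaglia-Tsang method chosen here for brevity, and the sketch draws exactly n tokens.

```java
// Illustrative sketch only; not MALLET's implementation.
final class DirichletMultinomialSketch {
    /** Marsaglia-Tsang sampler for Gamma(shape, 1); shape must be positive. */
    static double nextGamma(java.util.Random rng, double shape) {
        if (shape < 1.0) {
            // Boost small shapes: Gamma(a) = Gamma(a + 1) * U^(1/a)
            double u = 1.0 - rng.nextDouble(); // in (0, 1]
            return nextGamma(rng, shape + 1.0) * Math.pow(u, 1.0 / shape);
        }
        double d = shape - 1.0 / 3.0;
        double c = 1.0 / Math.sqrt(9.0 * d);
        while (true) {
            double x = rng.nextGaussian();
            double v = 1.0 + c * x;
            if (v <= 0.0) {
                continue;
            }
            v = v * v * v;
            double u = 1.0 - rng.nextDouble();
            if (Math.log(u) < 0.5 * x * x + d - d * v + d * Math.log(v)) {
                return d * v;
            }
        }
    }

    /** Stage 1: a Dirichlet draw is a normalized vector of Gamma(alpha_i, 1) variates. */
    static double[] nextDistribution(java.util.Random rng, double[] alphas) {
        double[] p = new double[alphas.length];
        double sum = 0.0;
        for (int i = 0; i < alphas.length; i++) {
            p[i] = nextGamma(rng, alphas[i]);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    /** Stage 2: draw n samples from the multinomial p and count them per dimension. */
    static int[] drawCounts(java.util.Random rng, double[] p, int n) {
        int[] counts = new int[p.length];
        for (int s = 0; s < n; s++) {
            double u = rng.nextDouble();
            int k = 0;
            double cumulative = p[0];
            while (u >= cumulative && k < p.length - 1) {
                k++;
                cumulative += p[k];
            }
            counts[k]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        java.util.Random rng = new java.util.Random(42);
        double[] p = nextDistribution(rng, new double[] {0.5, 1.0, 2.0});
        System.out.println(java.util.Arrays.toString(drawCounts(rng, p, 20)));
    }
}
```

Very small alpha values can make every Gamma variate underflow to zero; a production sampler would need to guard against that case.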


drawObservation

public int[] drawObservation(int n,
                             double[] distribution)
Draw a count vector from the probability distribution provided.

Parameters:
n - The expected total number of counts in the returned vector. The actual number is ~ Poisson(n)

drawObservations

public java.lang.Object[] drawObservations(int d,
                                           int n)
Create a set of d draws from a dirichlet-multinomial, each with an average of n observations.


logGammaDefinition

public static double logGammaDefinition(double z)
This calculates a log gamma function exactly. It's extremely inefficient -- use this for comparison only.


logGammaDifference

public static double logGammaDifference(double z,
                                        int n)
This directly calculates the difference between two log gamma functions using a recursive formula. The break-even with the Stirling approximation is about n=2, so it's not necessarily worth using this.
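
A minimal sketch of the idea, assuming the difference meant here is log Gamma(z + n) - log Gamma(z). Under that assumption the recurrence Gamma(z + 1) = z * Gamma(z) reduces the difference to a sum of n logarithms. The class name is hypothetical and the loop is an iterative form of the recursion, not MALLET's code.

```java
// Illustrative sketch only; not MALLET's implementation.
final class LogGammaDifferenceSketch {
    /** log Gamma(z + n) - log Gamma(z) = sum_{i=0}^{n-1} log(z + i), for z > 0, n >= 0. */
    static double logGammaDifference(double z, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += Math.log(z + i);
        }
        return sum;
    }

    public static void main(String[] args) {
        // Gamma(5) / Gamma(1) = 4! = 24, so this prints a value close to log(24).
        System.out.println(logGammaDifference(1.0, 4));
    }
}
```

The cost grows linearly with n, which is why a closed-form approximation wins for all but very small n.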


logGamma

public static double logGamma(double z)
Currently aliased to logGammaStirling


logGammaStirling

public static double logGammaStirling(double z)
Use a fifth order Stirling's approximation.

Parameters:
z - Note that Stirling's approximation is increasingly unstable as z approaches 0. If z is less than 2, we shift it up, calculate the approximation, and then shift the answer back down.
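
The shift-then-approximate strategy can be sketched as below. This is an assumption-laden illustration, not MALLET's code: the class name is hypothetical, the series is the standard Stirling expansion truncated after the z^-5 term, and the shift uses log Gamma(z) = log Gamma(z + 1) - log(z).

```java
// Illustrative sketch only; not MALLET's implementation.
final class StirlingSketch {
    static final double HALF_LOG_TWO_PI = 0.5 * Math.log(2.0 * Math.PI);

    /** Approximate log Gamma(z) for z > 0. */
    static double logGammaStirling(double z) {
        double shift = 0.0;
        while (z < 2.0) {
            // Move z up by one and remember the correction to subtract at the end.
            shift += Math.log(z);
            z += 1.0;
        }
        double z2 = z * z;
        // Correction terms: 1/(12 z) - 1/(360 z^3) + 1/(1260 z^5)
        double series = 1.0 / (12.0 * z)
                - 1.0 / (360.0 * z * z2)
                + 1.0 / (1260.0 * z * z2 * z2);
        return (z - 0.5) * Math.log(z) - z + HALF_LOG_TWO_PI + series - shift;
    }

    public static void main(String[] args) {
        System.out.println(logGammaStirling(5.0) + " vs " + Math.log(24.0));
        System.out.println(logGammaStirling(0.5) + " vs " + 0.5 * Math.log(Math.PI));
    }
}
```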

logGammaNemes

public static double logGammaNemes(double z)
Gergo Nemes' approximation


digamma

public static double digamma(double z)
Calculate digamma using an asymptotic expansion involving Bernoulli numbers.
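
For orientation, a sketch of such an expansion follows. It is not MALLET's implementation and does not use the DIGAMMA_* constants above: the class name, the shift threshold of 6 and the number of terms are choices made for this example. The coefficients 1/12, 1/120, 1/252 and 1/240 come from the Bernoulli numbers B_2, B_4, B_6 and B_8.

```java
// Illustrative sketch only; not MALLET's implementation.
final class DigammaSketch {
    /** Approximate psi(z) for z > 0. */
    static double digamma(double z) {
        double result = 0.0;
        while (z < 6.0) {
            // Recurrence psi(z) = psi(z + 1) - 1/z moves z into the asymptotic range.
            result -= 1.0 / z;
            z += 1.0;
        }
        double inv = 1.0 / z;
        double inv2 = inv * inv;
        // 1/(12 z^2) - 1/(120 z^4) + 1/(252 z^6) - 1/(240 z^8), in nested form
        double series = inv2 * (1.0 / 12.0
                - inv2 * (1.0 / 120.0
                - inv2 * (1.0 / 252.0
                - inv2 * (1.0 / 240.0))));
        return result + Math.log(z) - 0.5 * inv - series;
    }

    public static void main(String[] args) {
        // psi(1) equals the negative Euler-Mascheroni constant.
        System.out.println(digamma(1.0));
    }
}
```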


digammaDifference

public static double digammaDifference(double x,
                                       int n)

trigamma

public static double trigamma(double z)

learnSymmetricConcentration

public static double learnSymmetricConcentration(int[] countHistogram,
                                                 int[] observationLengths,
                                                 int numDimensions,
                                                 double currentValue)
Learn the concentration parameter of a symmetric Dirichlet using frequency histograms. Since all parameters are the same, we only need to keep track of the number of observation/dimension pairs with count N

Parameters:
countHistogram - An array of frequencies. If the matrix X represents observations such that x_dt is how many times word t occurs in document d, countHistogram[3] is the total number of cells in any column that equal 3.
observationLengths - A histogram of sample lengths, for example observationLengths[20] could be the number of documents that are exactly 20 tokens long.
numDimensions - The total number of dimensions.
currentValue - An initial starting value.
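
The two histogram arguments can be built from a document-by-word count matrix roughly as follows. The class and method names are hypothetical, and whether cells equal to zero should be tallied in countHistogram[0] is an assumption of this sketch rather than something this page states.

```java
// Illustrative sketch only; not MALLET's implementation.
final class SymmetricHistogramSketch {
    /** histogram[c] = number of cells of counts[d][t] that equal c. */
    static int[] countHistogram(int[][] counts) {
        int max = 0;
        for (int[] doc : counts) {
            for (int c : doc) {
                max = Math.max(max, c);
            }
        }
        int[] histogram = new int[max + 1];
        for (int[] doc : counts) {
            for (int c : doc) {
                histogram[c]++;
            }
        }
        return histogram;
    }

    /** histogram[n] = number of documents that are exactly n tokens long. */
    static int[] lengthHistogram(int[][] counts) {
        int[] lengths = new int[counts.length];
        int max = 0;
        for (int d = 0; d < counts.length; d++) {
            for (int c : counts[d]) {
                lengths[d] += c;
            }
            max = Math.max(max, lengths[d]);
        }
        int[] histogram = new int[max + 1];
        for (int length : lengths) {
            histogram[length]++;
        }
        return histogram;
    }

    public static void main(String[] args) {
        int[][] counts = {{3, 0, 1}, {3, 2, 0}};
        System.out.println(java.util.Arrays.toString(countHistogram(counts)));
        System.out.println(java.util.Arrays.toString(lengthHistogram(counts)));
    }
}
```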

testSymmetricConcentration

public static void testSymmetricConcentration(int numDimensions,
                                              int numObservations,
                                              int observationMeanLength)

learnParameters

public static double learnParameters(double[] parameters,
                                     int[][] observations,
                                     int[] observationLengths)
Learn Dirichlet parameters using frequency histograms

Parameters:
parameters - A reference to the current values of the parameters, which will be updated in place
observations - An array of count histograms. observations[10][3] could be the number of documents that contain exactly 3 tokens of word type 10.
observationLengths - A histogram of sample lengths, for example observationLengths[20] could be the number of documents that are exactly 20 tokens long.
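
The per-type layout of observations might be assembled like this. It is a sketch under assumptions: TypeHistogramSketch is a hypothetical helper, rows are sized to the largest count seen for each type, and documents with zero tokens of a type are tallied at index 0. None of these details is confirmed by this page.

```java
// Illustrative sketch only; not MALLET's implementation.
final class TypeHistogramSketch {
    /**
     * histograms[t][c] = number of documents containing exactly c tokens of type t,
     * given counts[d][t] = tokens of type t in document d.
     */
    static int[][] typeHistograms(int[][] counts, int numTypes) {
        int[][] histograms = new int[numTypes][];
        for (int t = 0; t < numTypes; t++) {
            int max = 0;
            for (int[] doc : counts) {
                max = Math.max(max, doc[t]);
            }
            histograms[t] = new int[max + 1];
            for (int[] doc : counts) {
                histograms[t][doc[t]]++;
            }
        }
        return histograms;
    }

    public static void main(String[] args) {
        int[][] counts = {{3, 0, 1}, {3, 2, 0}, {1, 2, 0}};
        System.out.println(java.util.Arrays.deepToString(typeHistograms(counts, 3)));
    }
}
```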

learnParameters

public static double learnParameters(double[] parameters,
                                     int[][] observations,
                                     int[] observationLengths,
                                     double shape,
                                     double scale,
                                     int numIterations)
Learn Dirichlet parameters using frequency histograms

Parameters:
parameters - A reference to the current values of the parameters, which will be updated in place
observations - An array of count histograms. observations[10][3] could be the number of documents that contain exactly 3 tokens of word type 10.
observationLengths - A histogram of sample lengths, for example observationLengths[20] could be the number of documents that are exactly 20 tokens long.
shape - Shape parameter of the Gamma prior: E(X) = shape * scale, var(X) = shape * scale^2
scale - Scale parameter of the same Gamma prior
numIterations - 200 to 1000 generally ensures convergence, but 1-5 is often enough to step in the right direction

learnParametersWithHistogram

public long learnParametersWithHistogram(java.lang.Object[] observations)
Use the fixed point iteration described by Tom Minka.


learnParametersWithHistogram

public long learnParametersWithHistogram(int[][] binCountHistograms,
                                         int[] lengthHistogram)

learnParametersWithDigamma

public long learnParametersWithDigamma(java.lang.Object[] observations)
Use the fixed point iteration described by Tom Minka.
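
As background, Minka's fixed-point update for the Dirichlet-multinomial multiplies each alpha_k by sum_d [psi(n_dk + alpha_k) - psi(alpha_k)] divided by sum_d [psi(n_d + sum_j alpha_j) - psi(sum_j alpha_j)]. The sketch below illustrates that update on raw count vectors. It is not MALLET's code: the class, the example data and the fit helper are hypothetical, and because the counts are integers the digamma differences are computed as finite sums instead of calling digamma.

```java
// Illustrative sketch only; not MALLET's implementation.
final class FixedPointSketch {
    // Hypothetical example data: EXAMPLE_COUNTS[d][k] = tokens of type k in document d.
    static final int[][] EXAMPLE_COUNTS = {{5, 2, 1}, {1, 6, 1}, {2, 1, 7}, {4, 3, 1}};

    /** psi(x + n) - psi(x) = sum_{i=0}^{n-1} 1 / (x + i), for integer n >= 0. */
    static double digammaDifference(double x, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += 1.0 / (x + i);
        }
        return sum;
    }

    /** Apply the fixed-point update repeatedly; returns a new parameter array. */
    static double[] fit(double[] initial, int[][] counts, int iterations) {
        double[] alphas = initial.clone();
        for (int iter = 0; iter < iterations; iter++) {
            double alphaSum = 0.0;
            for (double a : alphas) {
                alphaSum += a;
            }
            double[] numerators = new double[alphas.length];
            double denominator = 0.0;
            for (int[] doc : counts) {
                int length = 0;
                for (int k = 0; k < alphas.length; k++) {
                    numerators[k] += digammaDifference(alphas[k], doc[k]);
                    length += doc[k];
                }
                denominator += digammaDifference(alphaSum, length);
            }
            for (int k = 0; k < alphas.length; k++) {
                alphas[k] *= numerators[k] / denominator;
            }
        }
        return alphas;
    }

    /** Log-likelihood of the counts, dropping terms that do not depend on alpha. */
    static double logLikelihood(double[] alphas, int[][] counts) {
        double alphaSum = 0.0;
        for (double a : alphas) {
            alphaSum += a;
        }
        double total = 0.0;
        for (int[] doc : counts) {
            int length = 0;
            for (int k = 0; k < alphas.length; k++) {
                for (int i = 0; i < doc[k]; i++) {
                    total += Math.log(alphas[k] + i);
                }
                length += doc[k];
            }
            for (int i = 0; i < length; i++) {
                total -= Math.log(alphaSum + i);
            }
        }
        return total;
    }

    public static void main(String[] args) {
        double[] start = {1.0, 1.0, 1.0};
        double[] fitted = fit(start, EXAMPLE_COUNTS, 100);
        System.out.println(java.util.Arrays.toString(fitted));
        System.out.println(logLikelihood(start, EXAMPLE_COUNTS)
                + " -> " + logLikelihood(fitted, EXAMPLE_COUNTS));
    }
}
```

Each update maximizes a lower bound on the likelihood, so the log-likelihood should not decrease from one iteration to the next.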


learnParametersWithDigamma

public long learnParametersWithDigamma(int[][] binCounts,
                                       int[] observationLengths)

learnParametersWithMoments

public long learnParametersWithMoments(java.lang.Object[] observations)
Estimate a dirichlet with the moment matching method described by Ronning.


learnParametersWithLeaveOneOut

public long learnParametersWithLeaveOneOut(java.lang.Object[] observations)

learnParametersWithLeaveOneOut

public long learnParametersWithLeaveOneOut(int[][] binCounts,
                                           int[] observationLengths)
Learn parameters using Minka's Leave-One-Out (LOO) likelihood


absoluteDifference

public double absoluteDifference(Dirichlet other)
Compute the L1 residual between two dirichlets


squaredDifference

public double squaredDifference(Dirichlet other)
Compute the L2 residual between two dirichlets
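
The two residuals of absoluteDifference and squaredDifference can be illustrated on plain parameter arrays. The sketch assumes that the comparison is made element by element on the alpha_i values and that the L2 residual is the unrooted sum of squares, as the method name suggests. Neither assumption is confirmed by this page, and ResidualSketch is a hypothetical name.

```java
// Illustrative sketch only; not MALLET's implementation.
final class ResidualSketch {
    /** L1 residual: sum of |a_i - b_i|; both arrays must have the same length. */
    static double absoluteDifference(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += Math.abs(a[i] - b[i]);
        }
        return sum;
    }

    /** Squared L2 residual: sum of (a_i - b_i)^2; both arrays must have the same length. */
    static double squaredDifference(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] a = {1.0, 2.0, 3.0};
        double[] b = {2.0, 2.0, 1.0};
        System.out.println(absoluteDifference(a, b));
        System.out.println(squaredDifference(a, b));
    }
}
```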


checkBreakeven

public void checkBreakeven(double x)

compare

public static java.lang.String compare(double sum,
                                       int k,
                                       int n,
                                       int w)

dirichletMultinomialLikelihoodRatio

public static double dirichletMultinomialLikelihoodRatio(gnu.trove.TIntIntHashMap countsX,
                                                         gnu.trove.TIntIntHashMap countsY,
                                                         double alpha,
                                                         double alphaSum)
What is the probability that these two observations were drawn from the same multinomial with symmetric Dirichlet prior alpha, relative to the probability that they were drawn from different multinomials both drawn from this Dirichlet?


dirichletMultinomialLikelihoodRatio

public static double dirichletMultinomialLikelihoodRatio(int[] countsX,
                                                         int[] countsY,
                                                         double alpha,
                                                         double alphaSum)
What is the probability that these two observations were drawn from the same multinomial with symmetric Dirichlet prior alpha, relative to the probability that they were drawn from different multinomials both drawn from this Dirichlet?
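
One way to write this ratio down is P(X + Y pooled) / (P(X) * P(Y)), where P is the Dirichlet-multinomial marginal likelihood of a count vector. The sketch below computes its logarithm for a symmetric prior. It is an illustration under assumptions, not MALLET's code: the class is hypothetical, returning the log of the ratio is a choice made here, and the Gamma-function ratios are expanded as finite sums because the counts are integers.

```java
// Illustrative sketch only; not MALLET's implementation.
final class LikelihoodRatioSketch {
    /** log of x * (x + 1) * ... * (x + n - 1), i.e. log Gamma(x + n) - log Gamma(x). */
    static double logRisingFactorial(double x, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += Math.log(x + i);
        }
        return sum;
    }

    /** Log marginal probability of a token sequence with these counts. */
    static double logMarginal(int[] counts, double alpha, double alphaSum) {
        double result = 0.0;
        int total = 0;
        for (int c : counts) {
            result += logRisingFactorial(alpha, c);
            total += c;
        }
        return result - logRisingFactorial(alphaSum, total);
    }

    /** log [ P(x and y from one multinomial) / (P(x) * P(y)) ]. */
    static double logLikelihoodRatio(int[] x, int[] y, double alpha, double alphaSum) {
        int[] pooled = new int[x.length];
        for (int i = 0; i < x.length; i++) {
            pooled[i] = x[i] + y[i];
        }
        return logMarginal(pooled, alpha, alphaSum)
                - logMarginal(x, alpha, alphaSum)
                - logMarginal(y, alpha, alphaSum);
    }

    public static void main(String[] args) {
        // Two dimensions, alpha = 1, alphaSum = 2.
        System.out.println(Math.exp(logLikelihoodRatio(new int[] {1, 0}, new int[] {1, 0}, 1.0, 2.0)));
        System.out.println(Math.exp(logLikelihoodRatio(new int[] {1, 0}, new int[] {0, 1}, 1.0, 2.0)));
    }
}
```

The multinomial coefficients of X and Y appear in both numerator and denominator and cancel, which is why the sketch works with sequence probabilities. A ratio above 1 favors a shared multinomial, and a ratio below 1 favors separate ones.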


dirichletMultinomialLikelihoodRatio

public double dirichletMultinomialLikelihoodRatio(int[] countsX,
                                                  int[] countsY)
This version uses a non-symmetric Dirichlet prior


ewensLikelihoodRatio

public static double ewensLikelihoodRatio(int[] countsX,
                                          int[] countsY,
                                          double lambda)
Similar to the Dirichlet-multinomial tests, this is a likelihood ratio based on the Ewens Sampling Formula, which can be considered the distribution of partitions of integers generated by the Chinese restaurant process.


runComparison

public static void runComparison()

main

public static void main(java.lang.String[] args)

getAlphabet

public Alphabet getAlphabet()

size

public int size()

alpha

public double alpha(int featureIndex)

print

public void print()

randomRawMultinomial

protected double[] randomRawMultinomial(Randoms r)

randomMultinomial

public Multinomial randomMultinomial(Randoms r)

randomDirichlet

public Dirichlet randomDirichlet(Randoms r,
                                 double averageAlpha)

randomFeatureSequence

public FeatureSequence randomFeatureSequence(Randoms r,
                                             int length)

randomFeatureVector

public FeatureVector randomFeatureVector(Randoms r,
                                         int size)

randomTokenSequence

public TokenSequence randomTokenSequence(Randoms r,
                                         int length)

randomVector

public double[] randomVector(Randoms r)