cc.mallet.topics
Class ParallelTopicModel

java.lang.Object
  extended by cc.mallet.topics.ParallelTopicModel
All Implemented Interfaces:
java.io.Serializable

public class ParallelTopicModel
extends java.lang.Object
implements java.io.Serializable

Simple parallel threaded implementation of LDA, following Newman, Asuncion, Smyth and Welling, "Distributed Algorithms for Topic Models", JMLR (2009), with the SparseLDA sampling scheme and data structure from Yao, Mimno and McCallum, "Efficient Methods for Topic Model Inference on Streaming Document Collections", KDD (2009).

Author:
David Mimno, Andrew McCallum
See Also:
Serialized Form

Field Summary
 double[] alpha
           
 Alphabet alphabet
           
 double alphaSum
           
 double beta
           
 double betaSum
           
 int burninPeriod
           
 java.util.ArrayList<TopicAssignment> data
           
static double DEFAULT_BETA
           
 int[] docLengthCounts
           
 java.text.NumberFormat formatter
           
static java.util.logging.Logger logger
           
 java.lang.String modelFilename
           
 int numIterations
           
 int numTopics
           
 int numTypes
           
 int optimizeInterval
           
 boolean printLogLikelihood
           
 int randomSeed
           
 int saveModelInterval
           
 int saveSampleInterval
           
 int saveStateInterval
           
 int showTopicsInterval
           
 java.lang.String stateFilename
           
 int temperingInterval
           
 int[] tokensPerTopic
           
 LabelAlphabet topicAlphabet
           
 int topicBits
           
 int[][] topicDocCounts
           
 int topicMask
           
 int totalTokens
           
 int[][] typeTopicCounts
           
static int UNASSIGNED_TOPIC
           
 boolean usingSymmetricAlpha
           
 int wordsPerTopic
           
 
Constructor Summary
ParallelTopicModel(int numberOfTopics)
           
ParallelTopicModel(int numberOfTopics, double alphaSum, double beta)
           
ParallelTopicModel(LabelAlphabet topicAlphabet, double alphaSum, double beta)
           
 
Method Summary
 void addInstances(InstanceList training)
           
 void buildInitialTypeTopicCounts()
           
 java.lang.String displayTopWords(int numWords, boolean usingNewLines)
           
 void estimate()
           
 Alphabet getAlphabet()
           
 java.util.ArrayList<TopicAssignment> getData()
           
 TopicInferencer getInferencer()
          Return a tool for estimating topic distributions for new documents.
 int getNumTopics()
           
 MarginalProbEstimator getProbEstimator()
          Return a tool for evaluating the marginal probability of new documents under this model.
 java.util.ArrayList<java.util.TreeSet<IDSorter>> getSortedWords()
          Return a list of sorted sets (one set per topic).
 LabelAlphabet getTopicAlphabet()
           
 double[] getTopicProbabilities(int instanceID)
          Get the smoothed distribution over topics for a training instance.
 double[] getTopicProbabilities(LabelSequence topics)
          Get the smoothed distribution over topics for a topic sequence, which may be from the training set or from a new instance with topics assigned by an inferencer.
 java.lang.Object[][] getTopWords(int numWords)
          Return an array (one element for each topic) of arrays of words, which are the most probable words for that topic in descending order.
 void initializeFromState(java.io.File stateFile)
           
static void main(java.lang.String[] args)
           
 double modelLogLikelihood()
           
 void optimizeAlpha(WorkerRunnable[] runnables)
           
 void optimizeBeta(WorkerRunnable[] runnables)
           
 void printDocumentTopics(java.io.File file)
           
 void printDocumentTopics(java.io.PrintWriter out)
           
 void printDocumentTopics(java.io.PrintWriter out, double threshold, int max)
           
 void printState(java.io.File f)
           
 void printState(java.io.PrintStream out)
           
 void printTopicWordWeights(java.io.File file)
           
 void printTopicWordWeights(java.io.PrintWriter out)
          Print an unnormalized weight for every word in every topic.
 void printTopWords(java.io.File file, int numWords, boolean useNewLines)
           
 void printTopWords(java.io.PrintStream out, int numWords, boolean usingNewLines)
           
 void printTypeTopicCounts(java.io.File file)
          Write the internal representation of type-topic counts (count/topic pairs in descending order by count) to a file.
static ParallelTopicModel read(java.io.File f)
           
 void setBurninPeriod(int burninPeriod)
           
 void setNumIterations(int numIterations)
           
 void setNumThreads(int threads)
           
 void setOptimizeInterval(int interval)
          Interval for optimizing Dirichlet hyperparameters.
 void setRandomSeed(int seed)
           
 void setSaveSerializedModel(int interval, java.lang.String filename)
          Define how often and where to save a serialized model.
 void setSaveState(int interval, java.lang.String filename)
          Define how often and where to save a text representation of the current state.
 void setSymmetricAlpha(boolean b)
           
 void setTemperingInterval(int interval)
           
 void setTopicDisplay(int interval, int n)
           
 void sumTypeTopicCounts(WorkerRunnable[] runnables)
           
 void temperAlpha(WorkerRunnable[] runnables)
           
 void topicPhraseXMLReport(java.io.PrintWriter out, int numWords)
           
 void topicXMLReport(java.io.PrintWriter out, int numWords)
           
 void write(java.io.File serializedModelFile)
           
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

UNASSIGNED_TOPIC

public static final int UNASSIGNED_TOPIC
See Also:
Constant Field Values

logger

public static java.util.logging.Logger logger

data

public java.util.ArrayList<TopicAssignment> data

alphabet

public Alphabet alphabet

topicAlphabet

public LabelAlphabet topicAlphabet

numTopics

public int numTopics

topicMask

public int topicMask

topicBits

public int topicBits

numTypes

public int numTypes

totalTokens

public int totalTokens

alpha

public double[] alpha

alphaSum

public double alphaSum

beta

public double beta

betaSum

public double betaSum

usingSymmetricAlpha

public boolean usingSymmetricAlpha

DEFAULT_BETA

public static final double DEFAULT_BETA
See Also:
Constant Field Values

typeTopicCounts

public int[][] typeTopicCounts

tokensPerTopic

public int[] tokensPerTopic

docLengthCounts

public int[] docLengthCounts

topicDocCounts

public int[][] topicDocCounts

numIterations

public int numIterations

burninPeriod

public int burninPeriod

saveSampleInterval

public int saveSampleInterval

optimizeInterval

public int optimizeInterval

temperingInterval

public int temperingInterval

showTopicsInterval

public int showTopicsInterval

wordsPerTopic

public int wordsPerTopic

saveStateInterval

public int saveStateInterval

stateFilename

public java.lang.String stateFilename

saveModelInterval

public int saveModelInterval

modelFilename

public java.lang.String modelFilename

randomSeed

public int randomSeed

formatter

public java.text.NumberFormat formatter

printLogLikelihood

public boolean printLogLikelihood
Constructor Detail

ParallelTopicModel

public ParallelTopicModel(int numberOfTopics)

ParallelTopicModel

public ParallelTopicModel(int numberOfTopics,
                          double alphaSum,
                          double beta)

ParallelTopicModel

public ParallelTopicModel(LabelAlphabet topicAlphabet,
                          double alphaSum,
                          double beta)
Method Detail

getAlphabet

public Alphabet getAlphabet()

getTopicAlphabet

public LabelAlphabet getTopicAlphabet()

getNumTopics

public int getNumTopics()

getData

public java.util.ArrayList<TopicAssignment> getData()

setNumIterations

public void setNumIterations(int numIterations)

setBurninPeriod

public void setBurninPeriod(int burninPeriod)

setTopicDisplay

public void setTopicDisplay(int interval,
                            int n)

setRandomSeed

public void setRandomSeed(int seed)

setOptimizeInterval

public void setOptimizeInterval(int interval)
Interval for optimizing Dirichlet hyperparameters.


setSymmetricAlpha

public void setSymmetricAlpha(boolean b)

setTemperingInterval

public void setTemperingInterval(int interval)

setNumThreads

public void setNumThreads(int threads)

setSaveState

public void setSaveState(int interval,
                         java.lang.String filename)
Define how often and where to save a text representation of the current state. Files are GZipped.

Parameters:
interval - Save a copy of the state every interval iterations.
filename - Save the state to this file, with the iteration number as a suffix

setSaveSerializedModel

public void setSaveSerializedModel(int interval,
                                   java.lang.String filename)
Define how often and where to save a serialized model.

Parameters:
interval - Save a serialized model every interval iterations.
filename - Save to this file, with the iteration number as a suffix

addInstances

public void addInstances(InstanceList training)

initializeFromState

public void initializeFromState(java.io.File stateFile)
                         throws java.io.IOException
Throws:
java.io.IOException

buildInitialTypeTopicCounts

public void buildInitialTypeTopicCounts()

sumTypeTopicCounts

public void sumTypeTopicCounts(WorkerRunnable[] runnables)

optimizeAlpha

public void optimizeAlpha(WorkerRunnable[] runnables)

temperAlpha

public void temperAlpha(WorkerRunnable[] runnables)

optimizeBeta

public void optimizeBeta(WorkerRunnable[] runnables)

estimate

public void estimate()
              throws java.io.IOException
Throws:
java.io.IOException

printTopWords

public void printTopWords(java.io.File file,
                          int numWords,
                          boolean useNewLines)
                   throws java.io.IOException
Throws:
java.io.IOException

getSortedWords

public java.util.ArrayList<java.util.TreeSet<IDSorter>> getSortedWords()
Return a list of sorted sets (one set per topic). Each set contains IDSorter objects with integer keys into the alphabet. To get direct access to the Strings, use getTopWords(int).


getTopWords

public java.lang.Object[][] getTopWords(int numWords)
Return an array (one element for each topic) of arrays of words, which are the most probable words for that topic in descending order. These are returned as Objects, but will probably be Strings.

Parameters:
numWords - The maximum length of each topic's array of words (may be less).

printTopWords

public void printTopWords(java.io.PrintStream out,
                          int numWords,
                          boolean usingNewLines)

displayTopWords

public java.lang.String displayTopWords(int numWords,
                                        boolean usingNewLines)

topicXMLReport

public void topicXMLReport(java.io.PrintWriter out,
                           int numWords)

topicPhraseXMLReport

public void topicPhraseXMLReport(java.io.PrintWriter out,
                                 int numWords)

printTypeTopicCounts

public void printTypeTopicCounts(java.io.File file)
                          throws java.io.IOException
Write the internal representation of type-topic counts (count/topic pairs in descending order by count) to a file.

Throws:
java.io.IOException

printTopicWordWeights

public void printTopicWordWeights(java.io.File file)
                           throws java.io.IOException
Throws:
java.io.IOException

printTopicWordWeights

public void printTopicWordWeights(java.io.PrintWriter out)
                           throws java.io.IOException
Print an unnormalized weight for every word in every topic. Most of these will be equal to the smoothing parameter beta.

Throws:
java.io.IOException

getTopicProbabilities

public double[] getTopicProbabilities(int instanceID)
Get the smoothed distribution over topics for a training instance.


getTopicProbabilities

public double[] getTopicProbabilities(LabelSequence topics)
Get the smoothed distribution over topics for a topic sequence, which may be from the training set or from a new instance with topics assigned by an inferencer.


printDocumentTopics

public void printDocumentTopics(java.io.File file)
                         throws java.io.IOException
Throws:
java.io.IOException

printDocumentTopics

public void printDocumentTopics(java.io.PrintWriter out)

printDocumentTopics

public void printDocumentTopics(java.io.PrintWriter out,
                                double threshold,
                                int max)
Parameters:
out - A print writer
threshold - Only print topics with proportion greater than this number
max - Print no more than this many topics

printState

public void printState(java.io.File f)
                throws java.io.IOException
Throws:
java.io.IOException

printState

public void printState(java.io.PrintStream out)

modelLogLikelihood

public double modelLogLikelihood()

getInferencer

public TopicInferencer getInferencer()
Return a tool for estimating topic distributions for new documents.


getProbEstimator

public MarginalProbEstimator getProbEstimator()
Return a tool for evaluating the marginal probability of new documents under this model.


write

public void write(java.io.File serializedModelFile)

read

public static ParallelTopicModel read(java.io.File f)
                               throws java.lang.Exception
Throws:
java.lang.Exception

main

public static void main(java.lang.String[] args)