/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.TopicAssignment;
import cc.mallet.topics.WorkerRunnable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureCounter;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.Randoms;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPOutputStream;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class ParallelTopicModel {
    /** Documents added so far, each paired with its per-token topic assignments. */
    protected ArrayList<TopicAssignment> data = new ArrayList<TopicAssignment>();
    /** Data alphabet of the instance list most recently passed to {@code addInstances}. */
    protected Alphabet alphabet;
    /** Alphabet whose entries label the topics; its size defines {@link #numTopics}. */
    protected LabelAlphabet topicAlphabet;
    /** Number of topics. */
    protected int numTopics;
    /** Mask selecting the low-order bits of a packed type-topic entry that hold the topic id. */
    protected int topicMask;
    /** Number of bits set in {@link #topicMask}; packed counts are shifted left by this amount. */
    protected int topicBits;
    /** Size of {@link #alphabet} at the time instances were added. */
    protected int numTypes;
    /** Per-topic Dirichlet parameters for the document-topic distributions. */
    protected double[] alpha;
    /** Sum of the entries of {@link #alpha}. */
    protected double alphaSum;
    /** Dirichlet parameter for the topic-word distributions. */
    protected double beta;
    /** {@code beta * numTypes}, computed when instances are added. */
    protected double betaSum;
    /** Beta value used by the single-argument constructor. */
    public static final double DEFAULT_BETA = 0.01;
    /**
     * One array per type holding packed entries {@code (count << topicBits) + topic}.
     * Entries are kept in descending order and unused trailing slots are zero.
     */
    protected int[][] typeTopicCounts;
    /** Number of tokens currently assigned to each topic. */
    protected int[] tokensPerTopic;
    /** Histogram buffer with one slot per document length, 0 through the longest document. */
    protected int[] docLengthCounts;
    /** One histogram buffer per topic, each with the same slot layout as {@link #docLengthCounts}. */
    protected int[][] topicDocCounts;
    /** Number of sampling iterations performed by {@code estimate}. */
    public int numIterations = 1000;
    /** Iterations that must pass before alpha statistics are collected or alpha is optimized. */
    public int burninPeriod = 200;
    /** After burn-in, alpha statistics are collected on iterations divisible by this value. */
    public int saveSampleInterval = 10;
    /** After burn-in, alpha is re-estimated on iterations divisible by this value; 0 disables. */
    public int optimizeInterval = 50;
    /** Top words are printed to stdout on iterations divisible by this value; 0 disables. */
    public int showTopicsInterval = 50;
    /** Number of words shown per topic when top words are printed during estimation. */
    public int wordsPerTopic = 7;
    /** The sampling state is written on iterations divisible by this value; 0 disables. */
    protected int saveStateInterval = 0;
    /** Filename prefix for saved states; {@code '.'} and the iteration number are appended. */
    protected String stateFilename = null;
    /** Random source used to draw the initial topic assignments. */
    protected Randoms random;
    /** Formats alpha values for display, with at most five fraction digits. */
    protected NumberFormat formatter;
    /** When true, {@code estimate} prints the model log-likelihood every tenth iteration. */
    protected boolean printLogLikelihood = false;
    /** Number of worker threads (and document partitions) used by {@code estimate}. */
    int numThreads = 1;

    /**
     * Creates a model whose alpha sum equals the number of topics and whose beta is
     * {@link #DEFAULT_BETA}.
     *
     * @param numberOfTopics number of topics to model
     */
    public ParallelTopicModel(int numberOfTopics) {
        this(numberOfTopics, (double) numberOfTopics, DEFAULT_BETA);
    }

    /**
     * Creates a model that draws its initial assignments from a newly constructed
     * {@link Randoms}.
     *
     * @param numberOfTopics number of topics to model
     * @param alphaSum       total alpha mass, split evenly across topics
     * @param beta           Dirichlet parameter for the topic-word distributions
     */
    public ParallelTopicModel(int numberOfTopics, double alphaSum, double beta) {
        this(numberOfTopics, alphaSum, beta, new Randoms());
    }

    /**
     * Builds a label alphabet by looking up the entries {@code "topic0"},
     * {@code "topic1"}, ... in order.
     *
     * @param numTopics number of topic labels to register
     * @return the populated alphabet
     */
    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        LabelAlphabet labels = new LabelAlphabet();
        int topic = 0;
        while (topic < numTopics) {
            labels.lookupIndex("topic" + topic);
            topic++;
        }
        return labels;
    }

    /**
     * Creates a model with generated topic labels ({@code "topic0"}, {@code "topic1"}, ...).
     *
     * @param numberOfTopics number of topics to model
     * @param alphaSum       total alpha mass, split evenly across topics
     * @param beta           Dirichlet parameter for the topic-word distributions
     * @param random         random source for the initial topic assignments
     */
    public ParallelTopicModel(int numberOfTopics, double alphaSum, double beta, Randoms random) {
        this(newLabelAlphabet(numberOfTopics), alphaSum, beta, random);
    }

    public ParallelTopicModel(LabelAlphabet topicAlphabet, double alphaSum, double beta, Randoms random) {
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();
        if (Integer.bitCount(this.numTopics) == 1) {
            this.topicMask = this.numTopics - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        } else {
            this.topicMask = Integer.highestOneBit(this.numTopics) * 2 - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        }
        this.alphaSum = alphaSum;
        this.alpha = new double[this.numTopics];
        Arrays.fill(this.alpha, alphaSum / (double)this.numTopics);
        this.beta = beta;
        this.random = random;
        this.tokensPerTopic = new int[this.numTopics];
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        System.err.println("Coded LDA: " + this.numTopics + " topics, " + this.topicBits + " topic bits, " + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    /**
     * @return the data alphabet recorded when instances were added, or {@code null}
     *         if none have been added yet
     */
    public Alphabet getAlphabet() {
        return alphabet;
    }

    /** @return the alphabet whose entries label the topics */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /** @return the number of topics in this model */
    public int getNumTopics() {
        return numTopics;
    }

    /**
     * @return the live list of documents and their topic assignments (not a copy)
     */
    public ArrayList<TopicAssignment> getData() {
        return data;
    }

    /**
     * Sets how many sampling iterations {@code estimate} performs.
     *
     * @param numIterations number of iterations
     */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /**
     * Sets the number of iterations that must pass before alpha statistics are
     * collected or alpha is optimized.
     *
     * @param burninPeriod number of burn-in iterations
     */
    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * Configures the periodic top-word report printed during estimation.
     *
     * @param interval print on iterations divisible by this value; 0 disables
     * @param n        number of words to show per topic
     */
    public void setTopicDisplay(int interval, int n) {
        wordsPerTopic = n;
        showTopicsInterval = interval;
    }

    /**
     * Replaces this model's random source with a new one created from the given seed.
     *
     * @param seed seed for the new random source
     */
    public void setRandomSeed(int seed) {
        Randoms seeded = new Randoms(seed);
        random = seeded;
    }

    /**
     * Sets how often alpha is re-estimated after burn-in.
     *
     * @param interval optimize on iterations divisible by this value; 0 disables
     */
    public void setOptimizeInterval(int interval) {
        optimizeInterval = interval;
    }

    /**
     * Sets the number of worker threads (and document partitions) used by
     * {@code estimate}.
     *
     * @param threads number of threads
     */
    public void setNumThreads(int threads) {
        numThreads = threads;
    }

    /**
     * Configures periodic saving of the sampling state during estimation.
     *
     * @param interval save on iterations divisible by this value; 0 disables
     * @param filename filename prefix; {@code '.'} and the iteration number are appended
     */
    public void setSaveState(int interval, String filename) {
        stateFilename = filename;
        saveStateInterval = interval;
    }

    /**
     * Loads training documents, assigns every token a uniformly random topic, and
     * builds the count structures from those assignments.
     *
     * <p>NOTE(review): the documents are appended to {@link #data}, but the alphabet
     * and the type-topic arrays are re-created from {@code training} alone, so the
     * method appears to be intended for a single call per model — confirm before
     * calling it repeatedly.
     *
     * @param training instances whose data are {@link FeatureSequence}s
     */
    public void addInstances(InstanceList training) {
        alphabet = training.getDataAlphabet();
        numTypes = alphabet.size();
        betaSum = beta * (double) numTypes;
        typeTopicCounts = new int[numTypes][];

        // Pass 1: count how often each type occurs in the corpus.
        int[] occurrences = new int[numTypes];
        for (Instance instance : training) {
            FeatureSequence words = (FeatureSequence) instance.getData();
            for (int pos = 0; pos < words.getLength(); pos++) {
                occurrences[words.getIndexAtPosition(pos)]++;
            }
        }

        // A type can be spread over at most min(#topics, #occurrences) topics,
        // so that is all the room its packed array needs.
        for (int type = 0; type < numTypes; type++) {
            typeTopicCounts[type] = new int[Math.min(numTopics, occurrences[type])];
        }

        // Pass 2: attach a random topic to every token of every document.
        for (Instance instance : training) {
            FeatureSequence words = (FeatureSequence) instance.getData();
            LabelSequence assignment = new LabelSequence(topicAlphabet, new int[words.size()]);
            int[] topics = assignment.getFeatures();
            for (int pos = 0; pos < topics.length; pos++) {
                topics[pos] = random.nextInt(numTopics);
            }
            data.add(new TopicAssignment(instance, assignment));
        }

        buildInitialTypeTopicCounts();
        initializeHistograms();
    }

    /**
     * Rebuilds {@link #tokensPerTopic} and {@link #typeTopicCounts} from the topic
     * assignments currently stored in {@link #data}.
     *
     * <p>Each type's array holds packed entries {@code (count << topicBits) + topic},
     * kept in descending order so the most frequent topics come first; a zero entry
     * marks the end of the used slots.
     */
    public void buildInitialTypeTopicCounts() {
        Arrays.fill(tokensPerTopic, 0);

        // Clear the used prefix of every type's packed array.
        for (int type = 0; type < numTypes; type++) {
            int[] packed = typeTopicCounts[type];
            int slot = 0;
            while (slot < packed.length && packed[slot] > 0) {
                packed[slot] = 0;
                slot++;
            }
        }

        for (TopicAssignment document : data) {
            FeatureSequence words = (FeatureSequence) document.instance.getData();
            int[] topics = document.topicSequence.getFeatures();

            for (int pos = 0; pos < words.size(); pos++) {
                int topic = topics[pos];
                tokensPerTopic[topic]++;

                int type = words.getIndexAtPosition(pos);
                int[] packed = typeTopicCounts[type];

                // Find this topic's slot, or the first empty slot if it has none yet.
                int slot = 0;
                while (packed[slot] > 0 && (packed[slot] & topicMask) != topic) {
                    slot++;
                    if (slot == packed.length) {
                        System.out.println("overflow on type " + type);
                    }
                }

                int previousCount = packed[slot] >> topicBits;
                packed[slot] = ((previousCount + 1) << topicBits) + topic;

                // A brand-new entry has the smallest possible count and already sits
                // last; only an incremented entry may need to move towards the front.
                if (previousCount != 0) {
                    for (; slot > 0 && packed[slot] > packed[slot - 1]; slot--) {
                        int swapped = packed[slot];
                        packed[slot] = packed[slot - 1];
                        packed[slot - 1] = swapped;
                    }
                }
            }
        }
    }

    /**
     * Replaces the global counts with the sum of the counts held by the workers.
     *
     * <p>Both the global and the per-worker type-topic arrays use packed entries
     * {@code (count << topicBits) + topic} in descending order, terminated by a zero.
     *
     * @param runnables the workers; the first {@link #numThreads} entries are read
     */
    public void sumTypeTopicCounts(WorkerRunnable[] runnables) {
        Arrays.fill(tokensPerTopic, 0);

        // Clear the used prefix of every global packed array.
        for (int type = 0; type < numTypes; type++) {
            int[] merged = typeTopicCounts[type];
            int slot = 0;
            while (slot < merged.length && merged[slot] > 0) {
                merged[slot] = 0;
                slot++;
            }
        }

        for (int thread = 0; thread < numThreads; thread++) {
            WorkerRunnable worker = runnables[thread];

            int[] workerTotals = worker.getTokensPerTopic();
            for (int topic = 0; topic < numTopics; topic++) {
                tokensPerTopic[topic] += workerTotals[topic];
            }

            int[][] workerTypeTopicCounts = worker.getTypeTopicCounts();
            for (int type = 0; type < numTypes; type++) {
                int[] local = workerTypeTopicCounts[type];
                int[] merged = typeTopicCounts[type];

                for (int from = 0; from < local.length && local[from] > 0; from++) {
                    int topic = local[from] & topicMask;
                    int count = local[from] >> topicBits;

                    // Find this topic's slot in the merged array, or the first empty one.
                    int slot = 0;
                    while (merged[slot] > 0 && (merged[slot] & topicMask) != topic) {
                        slot++;
                        if (slot == merged.length) {
                            System.out.println("overflow in merging on type " + type);
                        }
                    }

                    int mergedCount = merged[slot] >> topicBits;
                    merged[slot] = ((mergedCount + count) << topicBits) + topic;

                    // Restore descending order by moving the updated entry forward.
                    for (; slot > 0 && merged[slot] > merged[slot - 1]; slot--) {
                        int swapped = merged[slot];
                        merged[slot] = merged[slot - 1];
                        merged[slot - 1] = swapped;
                    }
                }
            }
        }
    }

    /**
     * Allocates the histogram buffers used for alpha optimization, sized by the
     * longest document in {@link #data}, and reports corpus size on stderr.
     */
    private void initializeHistograms() {
        int longest = 0;
        int total = 0;
        for (TopicAssignment document : data) {
            int length = ((FeatureSequence) document.instance.getData()).getLength();
            longest = Math.max(longest, length);
            total += length;
        }

        System.err.println("max tokens: " + longest);
        System.err.println("total tokens: " + total);

        // One slot for every possible length or count from 0 through the longest document.
        int slots = longest + 1;
        docLengthCounts = new int[slots];
        topicDocCounts = new int[numTopics][slots];
    }

    /**
     * Re-estimates {@link #alpha} and {@link #alphaSum} from the histograms gathered
     * by the workers.
     *
     * <p>Positive worker histogram entries are added into the global histograms and
     * then reset to zero in the worker's arrays.
     *
     * @param runnables the workers; the first {@link #numThreads} entries are read
     */
    public void optimizeAlpha(WorkerRunnable[] runnables) {
        Arrays.fill(docLengthCounts, 0);
        for (int[] histogram : topicDocCounts) {
            Arrays.fill(histogram, 0);
        }

        for (int thread = 0; thread < numThreads; thread++) {
            WorkerRunnable worker = runnables[thread];
            int[] workerLengths = worker.getDocLengthCounts();
            int[][] workerTopicCounts = worker.getTopicDocCounts();

            for (int length = 0; length < workerLengths.length; length++) {
                if (workerLengths[length] > 0) {
                    docLengthCounts[length] += workerLengths[length];
                    workerLengths[length] = 0;
                }
            }

            for (int topic = 0; topic < numTopics; topic++) {
                int[] from = workerTopicCounts[topic];
                for (int count = 0; count < from.length; count++) {
                    if (from[count] > 0) {
                        topicDocCounts[topic][count] += from[count];
                        from[count] = 0;
                    }
                }
            }
        }

        alphaSum = Dirichlet.learnParameters(alpha, topicDocCounts, docLengthCounts);
    }

    /**
     * Runs {@link #numIterations} sampling iterations using {@link #numThreads}
     * workers, each responsible for a contiguous slice of {@link #data}.
     *
     * <p>Each iteration optionally prints the top words and saves the state, then
     * submits every worker and waits for all of them to complete. After that it
     * optionally optimizes alpha, merges the workers' counts into the global counts,
     * and copies the merged counts back to every worker. Progress and timing
     * information is written to stdout.
     *
     * <p>The thread pool is shut down when the method returns or throws. If the
     * calling thread is interrupted while waiting, estimation continues and the
     * interrupt status is restored on exit.
     *
     * @throws IOException           if a periodic state file cannot be written
     * @throws IllegalStateException if a worker task terminates with an exception
     */
    public void estimate() throws IOException {
        long startTime = System.currentTimeMillis();

        // Give every worker private copies of the global counts and a contiguous
        // slice of the documents; the last worker takes whatever remains.
        WorkerRunnable[] runnables = new WorkerRunnable[this.numThreads];
        int docsPerThread = this.data.size() / this.numThreads;
        int offset = 0;
        for (int thread = 0; thread < this.numThreads; ++thread) {
            int[] runnableTotals = new int[this.numTopics];
            System.arraycopy(this.tokensPerTopic, 0, runnableTotals, 0, this.numTopics);
            int[][] runnableCounts = new int[this.numTypes][];
            for (int type = 0; type < this.numTypes; ++type) {
                int[] counts = new int[this.typeTopicCounts[type].length];
                System.arraycopy(this.typeTopicCounts[type], 0, counts, 0, counts.length);
                runnableCounts[type] = counts;
            }
            if (thread == this.numThreads - 1) {
                docsPerThread = this.data.size() - offset;
            }
            runnables[thread] = new WorkerRunnable(this.numTopics, this.alpha, this.alphaSum, this.beta, new Randoms(), this.data, runnableCounts, runnableTotals, offset, docsPerThread);
            runnables[thread].initializeAlphaStatistics(this.docLengthCounts.length);
            offset += docsPerThread;
        }

        ExecutorService executor = Executors.newFixedThreadPool(this.numThreads);
        Future<?>[] pending = new Future<?>[this.numThreads];
        boolean interrupted = false;
        try {
            for (int iteration = 1; iteration <= this.numIterations; ++iteration) {
                long iterationStart = System.currentTimeMillis();

                if (this.showTopicsInterval != 0 && iteration % this.showTopicsInterval == 0) {
                    System.out.println();
                    this.printTopWords(System.out, this.wordsPerTopic, false);
                }
                if (this.saveStateInterval != 0 && iteration % this.saveStateInterval == 0) {
                    this.printState(new File(this.stateFilename + '.' + iteration));
                }

                // Statistics are only worth collecting when alpha optimization is
                // enabled; a zero sample interval also disables collection.
                boolean collectStatistics = iteration > this.burninPeriod
                        && this.optimizeInterval != 0
                        && this.saveSampleInterval != 0
                        && iteration % this.saveSampleInterval == 0;

                for (int thread = 0; thread < this.numThreads; ++thread) {
                    if (collectStatistics) {
                        runnables[thread].collectAlphaStatistics();
                    }
                    pending[thread] = executor.submit(runnables[thread]);
                }

                // Future.get() returns only after the task has really run, and makes
                // the worker's writes visible to this thread before the merge below.
                for (int thread = 0; thread < this.numThreads; ++thread) {
                    interrupted |= ParallelTopicModel.awaitUninterruptibly(pending[thread]);
                }

                if (iteration > this.burninPeriod && this.optimizeInterval != 0 && iteration % this.optimizeInterval == 0) {
                    this.optimizeAlpha(runnables);
                    System.out.print("[O " + (System.currentTimeMillis() - iterationStart) + "] ");
                }
                System.out.print("[" + (System.currentTimeMillis() - iterationStart) + "] ");
                this.sumTypeTopicCounts(runnables);
                System.out.print("[" + (System.currentTimeMillis() - iterationStart) + "] ");

                // Push the merged global counts back into every worker.
                for (int thread = 0; thread < this.numThreads; ++thread) {
                    int[] runnableTotals = runnables[thread].getTokensPerTopic();
                    System.arraycopy(this.tokensPerTopic, 0, runnableTotals, 0, this.numTopics);
                    int[][] runnableCounts = runnables[thread].getTypeTopicCounts();
                    for (int type = 0; type < this.numTypes; ++type) {
                        int[] targetCounts = runnableCounts[type];
                        int[] sourceCounts = this.typeTopicCounts[type];
                        for (int index = 0; index < sourceCounts.length; ++index) {
                            if (sourceCounts[index] != 0) {
                                targetCounts[index] = sourceCounts[index];
                            } else if (targetCounts[index] == 0) {
                                // Both arrays are zero from here on; nothing left to copy.
                                break;
                            } else {
                                targetCounts[index] = 0;
                            }
                        }
                    }
                }

                long elapsedMillis = System.currentTimeMillis() - iterationStart;
                if (elapsedMillis < 1000L) {
                    System.out.print(elapsedMillis + "ms ");
                } else {
                    System.out.print(elapsedMillis / 1000L + "s ");
                }
                if (iteration % 10 == 0) {
                    System.out.println("<" + iteration + "> ");
                    if (this.printLogLikelihood) {
                        System.out.println(this.modelLogLikelihood());
                    }
                }
                System.out.flush();
            }
        } finally {
            // Always release the pool threads, even when an iteration fails.
            executor.shutdownNow();
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }

        long seconds = Math.round((double)(System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;
        System.out.print("\nTotal time: ");
        if (days != 0L) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0L) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0L) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /**
     * Blocks until the given task has completed, retrying if the wait is interrupted.
     *
     * @param task the submitted worker task
     * @return {@code true} if the calling thread was interrupted while waiting; the
     *         caller is responsible for restoring the interrupt status
     * @throws IllegalStateException if the task terminated with an exception
     */
    private static boolean awaitUninterruptibly(Future<?> task) {
        boolean interrupted = false;
        while (true) {
            try {
                task.get();
                return interrupted;
            } catch (InterruptedException e) {
                // Remember the interrupt and keep waiting for the worker to finish.
                interrupted = true;
            } catch (ExecutionException e) {
                throw new IllegalStateException("Sampling worker failed", e);
            }
        }
    }

    /**
     * Writes the top words of every topic to the given file.
     *
     * @param file        destination file
     * @param numWords    maximum number of words to print per topic
     * @param useNewLines whether to print one word per line
     * @throws IOException if the file cannot be opened for writing
     */
    public void printTopWords(File file, int numWords, boolean useNewLines) throws IOException {
        PrintStream out = new PrintStream(file);
        try {
            this.printTopWords(out, numWords, useNewLines);
        } finally {
            // Close the stream even when printing fails part-way.
            out.close();
        }
    }

    /**
     * Prints the highest-count words of every topic.
     *
     * <p>With {@code usingNewLines} the output is a {@code "Topic N"} header followed
     * by one {@code word<TAB>count} line per word. Otherwise each topic is a single
     * line: topic index, formatted alpha, then the words separated by spaces. In both
     * layouts at most {@code numWords} words are printed, and fewer when the topic's
     * ranked vector has fewer locations.
     *
     * @param out           destination stream
     * @param numWords      maximum number of words to print per topic
     * @param usingNewLines whether to print one word per line
     */
    public void printTopWords(PrintStream out, int numWords, boolean usingNewLines) {
        FeatureCounter[] wordCountsPerTopic = new FeatureCounter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            wordCountsPerTopic[topic] = new FeatureCounter(this.alphabet);
        }

        // Unpack every (topic, count) entry and credit the count to that topic.
        for (int type = 0; type < this.numTypes; ++type) {
            int[] topicCounts = this.typeTopicCounts[type];
            for (int index = 0; index < topicCounts.length && topicCounts[index] > 0; ++index) {
                int topic = topicCounts[index] & this.topicMask;
                int count = topicCounts[index] >> this.topicBits;
                wordCountsPerTopic[topic].increment(type, count);
            }
        }

        for (int topic = 0; topic < this.numTopics; ++topic) {
            RankedFeatureVector rfv = wordCountsPerTopic[topic].toRankedFeatureVector();
            // Never ask for more ranks than the vector holds, whichever layout is used.
            int limit = Math.min(numWords, rfv.numLocations());

            if (usingNewLines) {
                out.println("Topic " + topic);
                for (int ri = 0; ri < limit; ++ri) {
                    int type = rfv.getIndexAtRank(ri);
                    out.println(this.alphabet.lookupObject(type).toString() + "\t" + (int)rfv.getValueAtRank(ri));
                }
            } else {
                out.print(topic + "\t" + this.formatter.format(this.alpha[topic]) + "\t");
                for (int ri = 0; ri < limit; ++ri) {
                    out.print(this.alphabet.lookupObject(rfv.getIndexAtRank(ri)).toString() + " ");
                }
                out.print("\n");
            }
        }
    }

    /**
     * Writes the topic proportions of every document to the given file.
     *
     * @param f destination file
     * @throws IOException if the file cannot be opened for writing
     */
    public void printDocumentTopics(File f) throws IOException {
        PrintWriter writer = new PrintWriter(new FileWriter(f));
        try {
            this.printDocumentTopics(writer);
        } finally {
            // Closing flushes the buffered output; without it the file can end up
            // empty or truncated.
            writer.close();
        }
    }

    /**
     * Prints the proportions of all topics for every document, with no weight
     * threshold.
     *
     * @param pw destination writer; it is neither flushed nor closed here
     */
    public void printDocumentTopics(PrintWriter pw) {
        printDocumentTopics(pw, 0.0, -1);
    }

    /**
     * Prints, for every document, its index, its source, and its topics with their
     * proportions in sorted order.
     *
     * <p>Output for a document stops after {@code max} topics or at the first topic
     * whose proportion is below {@code threshold}, whichever comes first.
     *
     * @param pw        destination writer; it is neither flushed nor closed here
     * @param threshold proportions below this value end a document's output
     * @param max       maximum topics per document; a negative value or one above
     *                  the number of topics means all topics
     */
    public void printDocumentTopics(PrintWriter pw, double threshold, int max) {
        pw.print("#doc source topic proportion ...\n");

        int[] topicCounts = new int[numTopics];
        IDSorter[] sortedTopics = new IDSorter[numTopics];
        for (int topic = 0; topic < numTopics; topic++) {
            sortedTopics[topic] = new IDSorter(topic, topic);
        }

        int limit = (max < 0 || max > numTopics) ? numTopics : max;

        for (int di = 0; di < data.size(); di++) {
            TopicAssignment document = data.get(di);
            int[] currentDocTopics = document.topicSequence.getFeatures();

            pw.print(di);
            pw.print(' ');
            Object source = document.instance.getSource();
            if (source == null) {
                pw.print("null-source");
            } else {
                pw.print(source);
            }
            pw.print(' ');

            // Tally this document's tokens per topic.
            int docLen = currentDocTopics.length;
            for (int assigned : currentDocTopics) {
                topicCounts[assigned]++;
            }

            // Reuse the sorter objects: reset each to its topic and proportion.
            for (int topic = 0; topic < numTopics; topic++) {
                sortedTopics[topic].set(topic, (float) topicCounts[topic] / (float) docLen);
            }
            Arrays.sort(sortedTopics);

            for (int i = 0; i < limit; i++) {
                IDSorter entry = sortedTopics[i];
                if (entry.getWeight() < threshold) {
                    break;
                }
                pw.print(entry.getID() + " " + entry.getWeight() + " ");
            }
            pw.print(" \n");

            Arrays.fill(topicCounts, 0);
        }
    }

    /**
     * Writes the current sampling state to the given file, gzip-compressed.
     *
     * @param f destination file
     * @throws IOException if the file cannot be opened or the gzip stream cannot be created
     */
    public void printState(File f) throws IOException {
        PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
        try {
            this.printState(out);
        } finally {
            // Close even when printing fails, so the file handle is released and the
            // gzip trailer is written.
            out.close();
        }
    }

    /**
     * Prints the current sampling state: a column header, the alpha values, then one
     * line per token with document index, source, position, type index, type, and
     * assigned topic, separated by single spaces.
     *
     * @param out destination stream; it is not closed here
     */
    public void printState(PrintStream out) {
        out.println("#doc source pos typeindex type topic");

        out.print("#alpha : ");
        for (int topic = 0; topic < numTopics; topic++) {
            out.print(alpha[topic] + " ");
        }
        out.println();

        for (int di = 0; di < data.size(); di++) {
            TopicAssignment document = data.get(di);
            FeatureSequence tokenSequence = (FeatureSequence) document.instance.getData();
            LabelSequence topicSequence = document.topicSequence;

            // Documents without a source are reported as "NA".
            Object rawSource = document.instance.getSource();
            String source = rawSource == null ? "NA" : rawSource.toString();

            for (int pi = 0; pi < topicSequence.getLength(); pi++) {
                int type = tokenSequence.getIndexAtPosition(pi);
                int topic = topicSequence.getIndexAtPosition(pi);

                out.print(di);
                out.print(' ');
                out.print(source);
                out.print(' ');
                out.print(pi);
                out.print(' ');
                out.print(type);
                out.print(' ');
                out.print(alphabet.lookupObject(type));
                out.print(' ');
                out.print(topic);
                out.println();
            }
        }
    }

    /**
     * Computes the log-likelihood of the current topic assignments and words under
     * the model, with the document-topic and topic-word distributions integrated out.
     *
     * <p>The result is the sum of two Dirichlet-multinomial terms:
     * <ul>
     *   <li>documents: for each document, {@code logGamma(alphaSum)
     *       - logGamma(alphaSum + docLength)
     *       + sum over topics of [logGamma(alpha_k + n_dk) - logGamma(alpha_k)]};</li>
     *   <li>topics: for each topic, {@code logGamma(V * beta)
     *       - logGamma(V * beta + n_k)
     *       + sum over types of [logGamma(beta + n_wk) - logGamma(beta)]},
     *       where {@code V} is the number of types.</li>
     * </ul>
     * Pairs with a zero count contribute nothing to the bracketed sums, so only
     * non-zero counts are visited.
     *
     * <p>If the running total becomes NaN, a diagnostic is printed to stdout and the
     * JVM exits with status 1.
     *
     * @return the model log-likelihood
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;

        // ---- Document-topic part ----
        int[] topicCounts = new int[this.numTopics];
        double[] topicLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            topicLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }
        for (int doc = 0; doc < this.data.size(); ++doc) {
            int[] docTopics = this.data.get(doc).topicSequence.getFeatures();
            for (int token = 0; token < docTopics.length; ++token) {
                topicCounts[docTopics[token]]++;
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                if (topicCounts[topic] <= 0) continue;
                logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + (double)topicCounts[topic]) - topicLogGammas[topic];
            }
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + (double)docTopics.length);
            Arrays.fill(topicCounts, 0);
        }
        logLikelihood += (double)this.data.size() * Dirichlet.logGammaStirling(this.alphaSum);

        // ---- Topic-word part ----
        int nonZeroTypeTopics = 0;
        for (int type = 0; type < this.numTypes; ++type) {
            int[] packedCounts = this.typeTopicCounts[type];
            for (int index = 0; index < packedCounts.length && packedCounts[index] > 0; ++index) {
                int count = packedCounts[index] >> this.topicBits;
                ++nonZeroTypeTopics;
                logLikelihood += Dirichlet.logGammaStirling(this.beta + (double)count);
                if (Double.isNaN(logLikelihood)) {
                    System.out.println(count);
                    System.exit(1);
                }
            }
        }

        // The topic-word Dirichlet has one beta per type, so its parameter sum is
        // beta times the vocabulary size, not beta times the number of topics.
        double vocabularyBeta = this.beta * (double)this.numTypes;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            logLikelihood -= Dirichlet.logGammaStirling(vocabularyBeta + (double)this.tokensPerTopic[topic]);
            if (Double.isNaN(logLikelihood)) {
                System.out.println("after topic " + topic + " " + this.tokensPerTopic[topic]);
                System.exit(1);
            }
        }

        // logGamma(V * beta) appears once per topic; logGamma(beta) once per
        // non-zero type/topic pair.
        logLikelihood += (double)this.numTopics * Dirichlet.logGammaStirling(vocabularyBeta);
        logLikelihood -= (double)nonZeroTypeTopics * Dirichlet.logGammaStirling(this.beta);
        if (Double.isNaN(logLikelihood)) {
            System.out.println("at the end");
            System.exit(1);
        }
        return logLikelihood;
    }

    /**
     * Command-line entry point: trains a model on a serialized instance list and
     * writes the final state to {@code state.gz}.
     *
     * <p>Usage: {@code <instance-list-file> [num-topics] [num-threads]}. The number
     * of topics defaults to 200 and the number of threads to 1.
     *
     * @param args command-line arguments as described above
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.err.println("Usage: ParallelTopicModel <instance-list-file> [num-topics] [num-threads]");
            return;
        }
        try {
            InstanceList training = InstanceList.load(new File(args[0]));
            int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
            // The thread count is optional, just like the topic count.
            int numThreads = args.length > 2 ? Integer.parseInt(args[2]) : 1;
            ParallelTopicModel lda = new ParallelTopicModel(numTopics, 50.0, 0.01);
            lda.printLogLikelihood = true;
            lda.setTopicDisplay(50, 7);
            lda.addInstances(training);
            lda.setNumThreads(numThreads);
            lda.estimate();
            System.out.println("printing state");
            lda.printState(new File("state.gz"));
            System.out.println("finished printing");
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }
}

