/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Labeling;
import cc.mallet.util.Randoms;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.TreeSet;
import java.util.zip.GZIPOutputStream;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class PolylingualTopicModel
implements Serializable {
    // Number of parallel languages; overwritten with training.length in addInstances.
    int numLanguages = 1;
    // One TopicAssignment per training document, appended by addInstances.
    protected ArrayList<TopicAssignment> data = new ArrayList();
    // Alphabet whose size defines numTopics.
    protected LabelAlphabet topicAlphabet;
    // Not referenced by any method visible in this part of the file.
    protected int numStopwords = 0;
    // Set to topicAlphabet.size() by the constructor.
    protected int numTopics;
    // Instance names loaded by loadTestingIDs; addInstances skips instances whose name is in this set.
    HashSet<String> testingIDs = null;
    // Packed type-topic counts store the topic id in the low topicBits bits (selected with
    // topicMask) and the count in the bits above them.
    protected int topicMask;
    protected int topicBits;
    // Per-language data alphabets and their sizes, captured in addInstances.
    protected Alphabet[] alphabets;
    protected int[] vocabularySizes;
    // Per-topic document-topic hyperparameters and their sum.
    protected double[] alpha;
    protected double alphaSum;
    // Per-language topic-word hyperparameter and beta * vocabularySize.
    protected double[] betas;
    protected double[] betaSums;
    // Per language: the largest number of occurrences of any single type (sizes the histogram in optimizeBetas).
    protected int[] languageMaxTypeCounts;
    // Same value as the literal addInstances assigns to every betas[language].
    public static final double DEFAULT_BETA = 0.01;
    // Per-language caches filled by cacheValues and updated incrementally while sampling.
    protected double[] languageSmoothingOnlyMasses;
    protected double[][] languageCachedCoefficients;
    // Reset to 0 at every iteration of estimate; not incremented in the visible code.
    int topicTermCount = 0;
    int betaTopicCount = 0;
    int smoothingOnlyCount = 0;
    // Not referenced by any method visible in this part of the file.
    protected int[] oneDocTopicCounts;
    // [language][type] -> packed (count << topicBits) + topic entries, kept in descending order.
    protected int[][][] languageTypeTopicCounts;
    // [language][topic] -> number of tokens currently assigned to the topic.
    protected int[][] languageTokensPerTopic;
    // Histograms filled by sampleTopicsForOneDoc when state is saved and consumed by
    // Dirichlet.learnParameters in estimate: docLengthCounts[length] and topicDocCounts[topic][count].
    protected int[] docLengthCounts;
    protected int[][] topicDocCounts;
    // Iteration counter; note that it starts at 1.
    protected int iterationsSoFar = 1;
    public int numIterations = 1000;
    public int burninPeriod = 5;
    public int saveSampleInterval = 5;
    public int optimizeInterval = 10;
    public int showTopicsInterval = 10;
    public int wordsPerTopic = 7;
    // Set by setModelOutput; not read by the methods visible in this part of the file.
    protected int outputModelInterval = 0;
    protected String outputModelFilename;
    // When non-zero, estimate writes the state to stateFilename + '.' + iteration at this interval.
    protected int saveStateInterval = 0;
    protected String stateFilename = null;
    protected Randoms random;
    // Number formatter limited to 5 fraction digits (configured in the constructor).
    protected NumberFormat formatter;
    // Not read by the methods visible in this part of the file.
    protected boolean printLogLikelihood = false;

    /**
     * Creates a model whose alphaSum equals the number of topics.
     *
     * @param numberOfTopics number of topics; also used (widened to double) as alphaSum
     */
    public PolylingualTopicModel(int numberOfTopics) {
        this(numberOfTopics, (double) numberOfTopics);
    }

    /**
     * Creates a model with a freshly constructed {@code Randoms} generator.
     *
     * @param numberOfTopics number of topics
     * @param alphaSum       sum of the per-topic alpha values
     */
    public PolylingualTopicModel(int numberOfTopics, double alphaSum) {
        this(numberOfTopics, alphaSum, new Randoms());
    }

    /**
     * Builds a label alphabet by looking up the entries "topic0", "topic1", ... in order.
     *
     * @param numTopics number of topic labels to look up
     * @return the populated alphabet
     */
    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        LabelAlphabet labels = new LabelAlphabet();
        int topic = 0;
        while (topic < numTopics) {
            labels.lookupIndex("topic" + topic);
            topic++;
        }
        return labels;
    }

    /**
     * Creates a model with generated topic labels ("topic0", "topic1", ...).
     *
     * @param numberOfTopics number of topics
     * @param alphaSum       sum of the per-topic alpha values
     * @param random         random number generator used for sampling
     */
    public PolylingualTopicModel(int numberOfTopics, double alphaSum, Randoms random) {
        this(newLabelAlphabet(numberOfTopics), alphaSum, random);
    }

    public PolylingualTopicModel(LabelAlphabet topicAlphabet, double alphaSum, Randoms random) {
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();
        if (Integer.bitCount(this.numTopics) == 1) {
            this.topicMask = this.numTopics - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        } else {
            this.topicMask = Integer.highestOneBit(this.numTopics) * 2 - 1;
            this.topicBits = Integer.bitCount(this.topicMask);
        }
        this.alphaSum = alphaSum;
        this.alpha = new double[this.numTopics];
        Arrays.fill(this.alpha, alphaSum / (double)this.numTopics);
        this.random = random;
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        System.err.println("Polylingual LDA: " + this.numTopics + " topics, " + this.topicBits + " topic bits, " + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    /**
     * Loads the names of held-out instances, one per line, from a text file.
     *
     * <p>The reader is closed even when reading fails, and {@code testingIDs} is replaced
     * only after the whole file has been read.
     *
     * @param testingIDFile file with one instance name per line
     * @throws IOException if the file cannot be opened or read
     */
    public void loadTestingIDs(File testingIDFile) throws IOException {
        HashSet<String> ids = new HashSet<String>();
        BufferedReader in = new BufferedReader(new FileReader(testingIDFile));
        try {
            String id;
            while ((id = in.readLine()) != null) {
                ids.add(id);
            }
        } finally {
            // Release the file handle even if readLine throws.
            in.close();
        }
        this.testingIDs = ids;
    }

    /** @return the alphabet of topic labels this model was constructed with */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /** @return the number of topics */
    public int getNumTopics() {
        return numTopics;
    }

    /** @return the live (not copied) list of per-document topic assignments */
    public ArrayList<TopicAssignment> getData() {
        return data;
    }

    /** Sets the iteration count used by the no-argument {@code estimate()}. */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /**
     * Sets the burn-in period: hyperparameter optimization runs only on iterations after it,
     * and histogram samples are collected only from it onwards (see {@code estimate(int)}).
     */
    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * Configures the periodic top-word display done during estimation.
     *
     * @param interval iterations between displays; 0 disables them
     * @param n        word-count argument passed to {@code printTopWords}
     */
    public void setTopicDisplay(int interval, int n) {
        this.wordsPerTopic = n;
        this.showTopicsInterval = interval;
    }

    /** Replaces the random number generator with a new one built from {@code seed}. */
    public void setRandomSeed(int seed) {
        Randoms seededGenerator = new Randoms(seed);
        this.random = seededGenerator;
    }

    /** Sets the number of iterations between hyperparameter optimizations; 0 disables them. */
    public void setOptimizeInterval(int interval) {
        this.optimizeInterval = interval;
    }

    /**
     * Stores the model-output interval and file name.
     *
     * @param interval value stored in {@code outputModelInterval}
     * @param filename value stored in {@code outputModelFilename}
     */
    public void setModelOutput(int interval, String filename) {
        this.outputModelFilename = filename;
        this.outputModelInterval = interval;
    }

    /**
     * Configures periodic state dumps during estimation.
     *
     * @param interval iterations between dumps; 0 disables them
     * @param filename prefix of the dump file; '.' and the iteration number are appended
     */
    public void setSaveState(int interval, String filename) {
        this.stateFilename = filename;
        this.saveStateInterval = interval;
    }

    /**
     * Loads the training data (one InstanceList per language) and randomly initializes
     * a topic for every token.
     *
     * <p>Pass 1 sets up the per-language alphabets, hyperparameters and count arrays.
     * Pass 2 walks the documents by index, assigns each token a uniformly random topic and
     * records it in the packed, descending-sorted type-topic count arrays.
     *
     * @param training one instance list per language; the document at index {@code doc} in
     *                 each list is treated as the same document in a different language
     */
    public void addInstances(InstanceList[] training) {
        this.numLanguages = training.length;
        this.languageTokensPerTopic = new int[this.numLanguages][this.numTopics];
        this.alphabets = new Alphabet[this.numLanguages];
        this.vocabularySizes = new int[this.numLanguages];
        this.betas = new double[this.numLanguages];
        this.betaSums = new double[this.numLanguages];
        this.languageMaxTypeCounts = new int[this.numLanguages];
        this.languageTypeTopicCounts = new int[this.numLanguages][][];
        // The document count is taken from language 0 (throws if training is empty).
        int numInstances = training[0].size();
        // NOTE(review): stoplists is allocated but never used in this method.
        HashSet[] stoplists = new HashSet[this.numLanguages];
        for (int language = 0; language < this.numLanguages; ++language) {
            // A size mismatch is only reported; processing continues.
            if (training[language].size() != numInstances) {
                System.err.println("Warning: language " + language + " has " + training[language].size() + " instances, lang 0 has " + numInstances);
            }
            this.alphabets[language] = training[language].getDataAlphabet();
            this.vocabularySizes[language] = this.alphabets[language].size();
            this.betas[language] = 0.01;
            this.betaSums[language] = this.betas[language] * (double)this.vocabularySizes[language];
            this.languageTypeTopicCounts[language] = new int[this.vocabularySizes[language]][];
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            // Count how many tokens of each type occur in the non-held-out instances.
            int[] typeTotals = new int[this.vocabularySizes[language]];
            for (Instance instance : training[language]) {
                if (this.testingIDs != null && this.testingIDs.contains(instance.getName())) continue;
                FeatureSequence tokens = (FeatureSequence)instance.getData();
                for (int position = 0; position < tokens.getLength(); ++position) {
                    int type;
                    int n = type = tokens.getIndexAtPosition(position);
                    typeTotals[n] = typeTotals[n] + 1;
                }
            }
            for (int type = 0; type < this.vocabularySizes[language]; ++type) {
                if (typeTotals[type] > this.languageMaxTypeCounts[language]) {
                    this.languageMaxTypeCounts[language] = typeTotals[type];
                }
                // A type can appear in at most min(numTopics, its token count) distinct topics.
                typeTopicCounts[type] = new int[Math.min(this.numTopics, typeTotals[type])];
            }
        }
        for (int doc = 0; doc < numInstances; ++doc) {
            // NOTE(review): the held-out test here uses only language 0's instance name, whereas
            // the counting pass above tests each language's own instance name — confirm that
            // names agree across languages, otherwise the arrays sized above may be too small.
            if (this.testingIDs != null && this.testingIDs.contains(((Instance)training[0].get(doc)).getName())) continue;
            Instance[] instances = new Instance[this.numLanguages];
            LabelSequence[] topicSequences = new LabelSequence[this.numLanguages];
            for (int language = 0; language < this.numLanguages; ++language) {
                int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
                int[] tokensPerTopic = this.languageTokensPerTopic[language];
                // NOTE(review): presumably fails if this language has fewer instances than
                // language 0 (only a warning is printed above) — confirm.
                instances[language] = (Instance)training[language].get(doc);
                FeatureSequence tokens = (FeatureSequence)instances[language].getData();
                topicSequences[language] = new LabelSequence(this.topicAlphabet, new int[tokens.size()]);
                int[] topics = topicSequences[language].getFeatures();
                for (int position = 0; position < tokens.size(); ++position) {
                    int topic;
                    int type = tokens.getIndexAtPosition(position);
                    int[] currentTypeTopicCounts = typeTopicCounts[type];
                    // Uniformly random initial topic for this token.
                    topics[position] = topic = this.random.nextInt(this.numTopics);
                    int n = topic;
                    tokensPerTopic[n] = tokensPerTopic[n] + 1;
                    // Find this topic's packed entry, or the first empty (zero) slot.
                    int index = 0;
                    int currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                    while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
                        currentTopic = currentTypeTopicCounts[++index] & this.topicMask;
                    }
                    int currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                    if (currentValue == 0) {
                        // Empty slot: first token of this type in this topic.
                        currentTypeTopicCounts[index] = (1 << this.topicBits) + topic;
                        continue;
                    }
                    currentTypeTopicCounts[index] = (currentValue + 1 << this.topicBits) + topic;
                    // Bubble the incremented entry forward to keep the array in descending order.
                    while (index > 0 && currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                        int temp = currentTypeTopicCounts[index];
                        currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                        currentTypeTopicCounts[index - 1] = temp;
                        --index;
                    }
                }
            }
            TopicAssignment t = new TopicAssignment(instances, topicSequences);
            this.data.add(t);
        }
        this.initializeHistograms();
        this.languageSmoothingOnlyMasses = new double[this.numLanguages];
        this.languageCachedCoefficients = new double[this.numLanguages][this.numTopics];
        this.cacheValues();
    }

    /**
     * Allocates the document-length and topic-count histograms, sized by the longest
     * document (token counts summed over all languages), and reports the maximum and
     * total token counts on stderr.
     */
    private void initializeHistograms() {
        int longestDocument = 0;
        int tokenTotal = 0;
        for (int doc = 0; doc < this.data.size(); ++doc) {
            LabelSequence[] sequences = this.data.get(doc).topicSequences;
            int documentTokens = 0;
            for (int i = 0; i < sequences.length; ++i) {
                documentTokens += sequences[i].getLength();
            }
            longestDocument = Math.max(longestDocument, documentTokens);
            tokenTotal += documentTokens;
        }
        System.err.println("max tokens: " + longestDocument);
        System.err.println("total tokens: " + tokenTotal);

        int histogramSize = longestDocument + 1;
        this.docLengthCounts = new int[histogramSize];
        this.topicDocCounts = new int[this.numTopics][histogramSize];
    }

    /**
     * Recomputes, for every language, the smoothing-only mass
     * (sum over topics of alpha * beta / (tokensPerTopic + betaSum)) and the per-topic
     * coefficients alpha / (tokensPerTopic + betaSum).
     */
    private void cacheValues() {
        for (int lang = 0; lang < this.numLanguages; ++lang) {
            double beta = this.betas[lang];
            double betaSum = this.betaSums[lang];
            int[] topicTotals = this.languageTokensPerTopic[lang];
            double[] coefficients = this.languageCachedCoefficients[lang];

            double smoothingMass = 0.0;
            for (int t = 0; t < this.numTopics; ++t) {
                double denominator = (double) topicTotals[t] + betaSum;
                smoothingMass = smoothingMass + this.alpha[t] * beta / denominator;
                coefficients[t] = this.alpha[t] / denominator;
            }
            this.languageSmoothingOnlyMasses[lang] = smoothingMass;
        }
    }

    /** Zeroes the document-length histogram and every per-topic count histogram. */
    private void clearHistograms() {
        Arrays.fill(this.docLengthCounts, 0);
        for (int[] perTopicHistogram : this.topicDocCounts) {
            Arrays.fill(perTopicHistogram, 0);
        }
    }

    /**
     * Runs estimation for {@code numIterations} iterations.
     *
     * @throws IOException if a periodic state dump cannot be written
     */
    public void estimate() throws IOException {
        estimate(numIterations);
    }

    /**
     * Runs Gibbs sampling sweeps over all documents.
     *
     * <p>Each iteration optionally prints the top words, dumps the state and re-estimates the
     * hyperparameters, then resamples every document. Timing is printed to stdout, with the
     * model log likelihood added on every tenth line.
     *
     * <p>A {@code saveSampleInterval} of 0 now disables histogram collection, matching the
     * other interval settings, instead of failing with a division by zero.
     *
     * @param iterationsThisRound number of iterations to add to the current iteration count
     * @throws IOException if a periodic state dump cannot be written
     */
    public void estimate(int iterationsThisRound) throws IOException {
        // NOTE(review): the bound is inclusive, so one call performs iterationsThisRound + 1
        // sweeps; kept unchanged for compatibility.
        int maxIteration = this.iterationsSoFar + iterationsThisRound;
        long totalTime = 0L;
        while (this.iterationsSoFar <= maxIteration) {
            long iterationStart = System.currentTimeMillis();
            if (this.showTopicsInterval != 0 && this.iterationsSoFar != 0 && this.iterationsSoFar % this.showTopicsInterval == 0) {
                System.out.println();
                this.printTopWords(System.out, this.wordsPerTopic, false);
            }
            if (this.saveStateInterval != 0 && this.iterationsSoFar % this.saveStateInterval == 0) {
                this.printState(new File(this.stateFilename + '.' + this.iterationsSoFar));
            }
            // Hyperparameters are re-estimated only strictly after burn-in.
            if (this.iterationsSoFar > this.burninPeriod && this.optimizeInterval != 0 && this.iterationsSoFar % this.optimizeInterval == 0) {
                this.alphaSum = Dirichlet.learnParameters(this.alpha, this.topicDocCounts, this.docLengthCounts);
                this.optimizeBetas();
                this.clearHistograms();
                this.cacheValues();
            }
            this.smoothingOnlyCount = 0;
            this.betaTopicCount = 0;
            this.topicTermCount = 0;
            // The same flag applies to every document in this sweep; the zero guard avoids
            // an ArithmeticException from the modulo.
            boolean shouldSaveState = this.iterationsSoFar >= this.burninPeriod
                    && this.saveSampleInterval != 0
                    && this.iterationsSoFar % this.saveSampleInterval == 0;
            for (int doc = 0; doc < this.data.size(); ++doc) {
                this.sampleTopicsForOneDoc(this.data.get(doc), shouldSaveState);
            }
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            totalTime += elapsedMillis;
            if ((this.iterationsSoFar + 1) % 10 == 0) {
                double ll = this.modelLogLikelihood();
                System.out.println(elapsedMillis + "\t" + totalTime + "\t" + ll);
            } else {
                System.out.print(elapsedMillis + " ");
            }
            ++this.iterationsSoFar;
        }
    }

    /**
     * Re-estimates each language's symmetric beta.
     *
     * <p>For every language this builds a histogram of the non-zero type-topic counts and a
     * histogram of topic sizes, passes both to {@code Dirichlet.learnSymmetricConcentration},
     * and stores the returned betaSum plus the derived beta (betaSum / vocabulary size).
     */
    public void optimizeBetas() {
        for (int lang = 0; lang < this.numLanguages; ++lang) {
            int[] countHistogram = new int[this.languageMaxTypeCounts[lang] + 1];
            int[][] packedCountsByType = this.languageTypeTopicCounts[lang];
            int[] topicTotals = this.languageTokensPerTopic[lang];
            int vocabularySize = this.vocabularySizes[lang];

            // Histogram of per-(type, topic) counts; entries are packed and zero-terminated.
            for (int type = 0; type < vocabularySize; ++type) {
                for (int packed : packedCountsByType[type]) {
                    if (packed <= 0) {
                        break;
                    }
                    countHistogram[packed >> this.topicBits]++;
                }
            }

            // Histogram of topic sizes, indexed by tokens assigned to the topic.
            int largestTopic = 0;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                largestTopic = Math.max(largestTopic, topicTotals[topic]);
            }
            int[] topicSizeHistogram = new int[largestTopic + 1];
            for (int topic = 0; topic < this.numTopics; ++topic) {
                topicSizeHistogram[topicTotals[topic]]++;
            }

            double learnedBetaSum = Dirichlet.learnSymmetricConcentration(countHistogram, topicSizeHistogram, vocabularySize, this.betaSums[lang]);
            this.betaSums[lang] = learnedBetaSum;
            this.betas[lang] = learnedBetaSum / (double) vocabularySize;
        }
    }

    /**
     * Resamples the topic of every token of one document, across all of its languages.
     *
     * <p>The document-topic counts are shared by all languages of the document, while the
     * type-topic counts, tokens-per-topic counts, beta and cached values are per language.
     * For each token the sampling distribution is split into three masses: a smoothing-only
     * mass (alpha * beta terms), a topic-beta mass (document count * beta terms, over the
     * topics present in the document) and a topic-term mass (over the topics present in the
     * token type's packed count array).
     *
     * @param topicAssignment the document's instances and topic sequences, one per language
     * @param shouldSaveState when true, the document's final topic counts and total length
     *                        are added to topicDocCounts / docLengthCounts
     */
    protected void sampleTopicsForOneDoc(TopicAssignment topicAssignment, boolean shouldSaveState) {
        int[] localTopicCounts = new int[this.numTopics];
        int[] localTopicIndex = new int[this.numTopics];
        // Count the current topic assignments over all languages of this document.
        for (int language = 0; language < this.numLanguages; ++language) {
            int[] oneDocTopics = topicAssignment.topicSequences[language].getFeatures();
            int docLength = topicAssignment.topicSequences[language].getLength();
            for (int position = 0; position < docLength; ++position) {
                int n = oneDocTopics[position];
                localTopicCounts[n] = localTopicCounts[n] + 1;
            }
        }
        // Dense, ascending list of the topics that occur in this document.
        int denseIndex = 0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            if (localTopicCounts[topic] == 0) continue;
            localTopicIndex[denseIndex] = topic;
            ++denseIndex;
        }
        int nonZeroTopics = denseIndex;
        for (int language = 0; language < this.numLanguages; ++language) {
            int[] oneDocTopics = topicAssignment.topicSequences[language].getFeatures();
            int docLength = topicAssignment.topicSequences[language].getLength();
            FeatureSequence tokenSequence = (FeatureSequence)topicAssignment.instances[language].getData();
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            int[] tokensPerTopic = this.languageTokensPerTopic[language];
            double beta = this.betas[language];
            double betaSum = this.betaSums[language];
            double smoothingOnlyMass = this.languageSmoothingOnlyMasses[language];
            double[] cachedCoefficients = this.languageCachedCoefficients[language];
            // Initialize the topic-beta mass and fold the document counts into the cached
            // coefficients for the topics present in this document.
            // NOTE(review): nothing in this method restores cachedCoefficients to their
            // alpha-only values afterwards; in the visible code cacheValues() runs only from
            // addInstances and after hyperparameter optimization — confirm this is intended.
            double topicBetaMass = 0.0;
            for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                int topic = localTopicIndex[denseIndex];
                int n = localTopicCounts[topic];
                topicBetaMass += beta * (double)n / ((double)tokensPerTopic[topic] + betaSum);
                cachedCoefficients[topic] = (this.alpha[topic] + (double)n) / ((double)tokensPerTopic[topic] + betaSum);
            }
            double topicTermMass = 0.0;
            double[] topicTermScores = new double[this.numTopics];
            for (int position = 0; position < docLength; ++position) {
                int temp;
                double sample;
                int currentValue;
                int type = tokenSequence.getIndexAtPosition(position);
                int oldTopic = oneDocTopics[position];
                // NOTE(review): positions holding -1 are skipped here, but the counting loop at
                // the top of the method would already have failed on a -1 — confirm whether
                // -1 can actually occur.
                if (oldTopic == -1) continue;
                int[] currentTypeTopicCounts = typeTopicCounts[type];
                // Remove this token: subtract the old topic's terms, update the counts, then
                // add the terms back with the decremented counts.
                smoothingOnlyMass -= this.alpha[oldTopic] * beta / ((double)tokensPerTopic[oldTopic] + betaSum);
                topicBetaMass -= beta * (double)localTopicCounts[oldTopic] / ((double)tokensPerTopic[oldTopic] + betaSum);
                int n = oldTopic;
                localTopicCounts[n] = localTopicCounts[n] - 1;
                if (localTopicCounts[oldTopic] == 0) {
                    // The topic no longer occurs in the document: remove it from the dense
                    // list by shifting the later entries down one slot.
                    denseIndex = 0;
                    while (localTopicIndex[denseIndex] != oldTopic) {
                        ++denseIndex;
                    }
                    while (denseIndex < nonZeroTopics) {
                        // Guard against reading one past the end of localTopicIndex.
                        if (denseIndex < localTopicIndex.length - 1) {
                            localTopicIndex[denseIndex] = localTopicIndex[denseIndex + 1];
                        }
                        ++denseIndex;
                    }
                    --nonZeroTopics;
                }
                int n2 = oldTopic;
                tokensPerTopic[n2] = tokensPerTopic[n2] - 1;
                smoothingOnlyMass += this.alpha[oldTopic] * beta / ((double)tokensPerTopic[oldTopic] + betaSum);
                topicBetaMass += beta * (double)localTopicCounts[oldTopic] / ((double)tokensPerTopic[oldTopic] + betaSum);
                cachedCoefficients[oldTopic] = (this.alpha[oldTopic] + (double)localTopicCounts[oldTopic]) / ((double)tokensPerTopic[oldTopic] + betaSum);
                // Walk the type's packed counts: decrement the old topic's entry once, and
                // score every entry for the topic-term mass.
                int index = 0;
                boolean alreadyDecremented = false;
                topicTermMass = 0.0;
                while (index < currentTypeTopicCounts.length && currentTypeTopicCounts[index] > 0) {
                    int currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                    currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                    if (!alreadyDecremented && currentTopic == oldTopic) {
                        currentTypeTopicCounts[index] = --currentValue == 0 ? 0 : (currentValue << this.topicBits) + oldTopic;
                        // Bubble the decremented entry toward the tail to keep descending order.
                        for (int subIndex = index; subIndex < currentTypeTopicCounts.length - 1 && currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]; ++subIndex) {
                            int temp2 = currentTypeTopicCounts[subIndex];
                            currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
                            currentTypeTopicCounts[subIndex + 1] = temp2;
                        }
                        alreadyDecremented = true;
                        // Re-examine the same index: it may now hold a different entry.
                        continue;
                    }
                    double score = cachedCoefficients[currentTopic] * (double)currentValue;
                    topicTermMass += score;
                    topicTermScores[index] = score;
                    ++index;
                }
                double origSample = sample = this.random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
                int newTopic = -1;
                if (sample < topicTermMass) {
                    // Topic-term mass: pick among the entries of the type's packed array.
                    int i = -1;
                    while (sample > 0.0) {
                        sample -= topicTermScores[++i];
                    }
                    newTopic = currentTypeTopicCounts[i] & this.topicMask;
                    currentValue = currentTypeTopicCounts[i] >> this.topicBits;
                    currentTypeTopicCounts[i] = (currentValue + 1 << this.topicBits) + newTopic;
                    // Restore descending order after the increment.
                    while (i > 0 && currentTypeTopicCounts[i] > currentTypeTopicCounts[i - 1]) {
                        temp = currentTypeTopicCounts[i];
                        currentTypeTopicCounts[i] = currentTypeTopicCounts[i - 1];
                        currentTypeTopicCounts[i - 1] = temp;
                        --i;
                    }
                } else {
                    if ((sample -= topicTermMass) < topicBetaMass) {
                        // Topic-beta mass: pick among the topics present in the document.
                        sample /= beta;
                        for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                            int topic = localTopicIndex[denseIndex];
                            if (!((sample -= (double)localTopicCounts[topic] / ((double)tokensPerTopic[topic] + betaSum)) <= 0.0)) continue;
                            newTopic = topic;
                            break;
                        }
                        // NOTE(review): if the loop above ends without a match, newTopic is
                        // still -1 while the count update below runs; the -1 fallback further
                        // down is applied only afterwards — confirm this cannot happen.
                    } else {
                        // Smoothing-only mass: pick among all topics.
                        sample -= topicBetaMass;
                        sample /= beta;
                        newTopic = 0;
                        sample -= this.alpha[newTopic] / ((double)tokensPerTopic[newTopic] + betaSum);
                        while (sample > 0.0) {
                            sample -= this.alpha[++newTopic] / ((double)tokensPerTopic[newTopic] + betaSum);
                        }
                    }
                    // Find the new topic's packed entry (or the first empty slot) and add one.
                    index = 0;
                    while (currentTypeTopicCounts[index] > 0 && (currentTypeTopicCounts[index] & this.topicMask) != newTopic) {
                        ++index;
                    }
                    if (currentTypeTopicCounts[index] == 0) {
                        currentTypeTopicCounts[index] = (1 << this.topicBits) + newTopic;
                    } else {
                        currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                        currentTypeTopicCounts[index] = (currentValue + 1 << this.topicBits) + newTopic;
                        while (index > 0 && currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                            temp = currentTypeTopicCounts[index];
                            currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                            currentTypeTopicCounts[index - 1] = temp;
                            --index;
                        }
                    }
                }
                // Fallback when no topic was selected: report and use the last topic.
                if (newTopic == -1) {
                    System.err.println("PolylingualTopicModel sampling error: " + origSample + " " + sample + " " + smoothingOnlyMass + " " + topicBetaMass + " " + topicTermMass);
                    newTopic = this.numTopics - 1;
                }
                // Add the token back under its new topic, mirroring the removal above.
                oneDocTopics[position] = newTopic;
                smoothingOnlyMass -= this.alpha[newTopic] * beta / ((double)tokensPerTopic[newTopic] + betaSum);
                topicBetaMass -= beta * (double)localTopicCounts[newTopic] / ((double)tokensPerTopic[newTopic] + betaSum);
                int n3 = newTopic;
                localTopicCounts[n3] = localTopicCounts[n3] + 1;
                if (localTopicCounts[newTopic] == 1) {
                    // The topic is new to the document: insert it into the dense list,
                    // keeping the list in ascending topic order.
                    for (denseIndex = nonZeroTopics; denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic; --denseIndex) {
                        localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
                    }
                    localTopicIndex[denseIndex] = newTopic;
                    ++nonZeroTopics;
                }
                int n4 = newTopic;
                tokensPerTopic[n4] = tokensPerTopic[n4] + 1;
                cachedCoefficients[newTopic] = (this.alpha[newTopic] + (double)localTopicCounts[newTopic]) / ((double)tokensPerTopic[newTopic] + betaSum);
                topicBetaMass += beta * (double)localTopicCounts[newTopic] / ((double)tokensPerTopic[newTopic] + betaSum);
                // The updated smoothing-only mass is written back to the per-language cache after every token.
                this.languageSmoothingOnlyMasses[language] = smoothingOnlyMass += this.alpha[newTopic] * beta / ((double)tokensPerTopic[newTopic] + betaSum);
            }
        }
        if (shouldSaveState) {
            // Record this document in the histograms consumed by Dirichlet.learnParameters.
            int totalLength = 0;
            for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                int topic = localTopicIndex[denseIndex];
                int[] nArray = this.topicDocCounts[topic];
                int n = localTopicCounts[topic];
                nArray[n] = nArray[n] + 1;
                totalLength += localTopicCounts[topic];
            }
            int n = totalLength;
            this.docLengthCounts[n] = this.docLengthCounts[n] + 1;
        }
    }

    /**
     * Writes the top words of every topic to a file.
     *
     * <p>The stream is closed even when printing fails.
     *
     * @param file        destination file
     * @param numWords    word-count argument passed to the stream overload
     * @param useNewLines passed through to the stream overload
     * @throws IOException if the file cannot be opened for writing
     */
    public void printTopWords(File file, int numWords, boolean useNewLines) throws IOException {
        PrintStream out = new PrintStream(file);
        try {
            this.printTopWords(out, numWords, useNewLines);
        } finally {
            // Release the file handle even if printing throws.
            out.close();
        }
    }

    /**
     * Prints, for every topic, its index and alpha, followed by one line per language with
     * the language index, the topic's token count, the language's beta and the topic's top words.
     *
     * <p>Up to {@code numWords} words are printed per topic and language. The previous
     * implementation started its counter at 1 and therefore printed at most
     * {@code numWords - 1} words.
     *
     * @param out           destination stream
     * @param numWords      maximum number of words to print per topic and language
     * @param usingNewLines currently ignored
     */
    public void printTopWords(PrintStream out, int numWords, boolean usingNewLines) {
        // Collect, per language and topic, the (type, count) pairs ordered by IDSorter.
        TreeSet[][] languageTopicSortedWords = new TreeSet[this.numLanguages][this.numTopics];
        for (int language = 0; language < this.numLanguages; ++language) {
            TreeSet[] topicSortedWords = languageTopicSortedWords[language];
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            for (int topic = 0; topic < this.numTopics; ++topic) {
                topicSortedWords[topic] = new TreeSet();
            }
            for (int type = 0; type < this.vocabularySizes[language]; ++type) {
                int[] topicCounts = typeTopicCounts[type];
                // Packed entries are zero-terminated: low bits hold the topic, high bits the count.
                for (int index = 0; index < topicCounts.length && topicCounts[index] > 0; ++index) {
                    int topic = topicCounts[index] & this.topicMask;
                    int count = topicCounts[index] >> this.topicBits;
                    topicSortedWords[topic].add(new IDSorter(type, count));
                }
            }
        }
        for (int topic = 0; topic < this.numTopics; ++topic) {
            out.println(topic + "\t" + this.formatter.format(this.alpha[topic]));
            for (int language = 0; language < this.numLanguages; ++language) {
                out.print(" " + language + "\t" + this.languageTokensPerTopic[language][topic] + "\t" + this.betas[language] + "\t");
                TreeSet sortedWords = languageTopicSortedWords[language][topic];
                Alphabet alphabet = this.alphabets[language];
                Iterator iterator = sortedWords.iterator();
                // Count from 0 so that exactly numWords entries (if available) are printed.
                for (int printed = 0; iterator.hasNext() && printed < numWords; ++printed) {
                    IDSorter info = (IDSorter)iterator.next();
                    out.print(alphabet.lookupObject(info.getID()) + " ");
                }
                out.println();
            }
        }
    }

    /**
     * Writes the per-document topic proportions to a UTF-8 file.
     *
     * <p>The writer is closed when done, so buffered output is flushed to the file and the
     * handle is released. Previously the writer was never flushed or closed.
     *
     * @param f destination file
     * @throws IOException if the file cannot be opened for writing
     */
    public void printDocumentTopics(File f) throws IOException {
        PrintWriter pw = new PrintWriter(f, "UTF-8");
        try {
            this.printDocumentTopics(pw);
        } finally {
            // close() also flushes whatever is still buffered.
            pw.close();
        }
    }

    /**
     * Prints the topic proportions of every document with a threshold of 0.0 and no limit
     * on the number of topics. The writer is neither flushed nor closed here.
     *
     * @param pw destination writer
     */
    public void printDocumentTopics(PrintWriter pw) {
        printDocumentTopics(pw, 0.0, -1);
    }

    /**
     * Prints one line per document: the document index followed by "topic proportion"
     * pairs in the order produced by sorting the IDSorter entries. Proportions are
     * computed over the tokens of all languages of the document.
     *
     * @param pw        destination writer (not flushed or closed here)
     * @param threshold output for a document stops at the first topic whose proportion is
     *                  below this value
     * @param max       maximum number of topics per document; a negative value or one above
     *                  the topic count means all topics
     */
    public void printDocumentTopics(PrintWriter pw, double threshold, int max) {
        pw.print("#doc source topic proportion ...\n");

        int[] countsByTopic = new int[this.numTopics];
        IDSorter[] ranked = new IDSorter[this.numTopics];
        for (int t = 0; t < this.numTopics; ++t) {
            ranked[t] = new IDSorter(t, t);
        }

        int limit = max;
        if (limit < 0 || limit > this.numTopics) {
            limit = this.numTopics;
        }

        for (int docIndex = 0; docIndex < this.data.size(); ++docIndex) {
            pw.print(docIndex);
            pw.print(' ');

            // Tally topic assignments across every language of this document.
            int tokensInDocument = 0;
            for (int lang = 0; lang < this.numLanguages; ++lang) {
                LabelSequence assignments = this.data.get(docIndex).topicSequences[lang];
                int[] tokenTopics = assignments.getFeatures();
                int sequenceLength = assignments.getLength();
                tokensInDocument += sequenceLength;
                for (int pos = 0; pos < sequenceLength; ++pos) {
                    countsByTopic[tokenTopics[pos]]++;
                }
            }

            for (int t = 0; t < this.numTopics; ++t) {
                ranked[t].set(t, (float) countsByTopic[t] / (float) tokensInDocument);
            }
            Arrays.sort(ranked);

            for (int rank = 0; rank < limit; ++rank) {
                IDSorter entry = ranked[rank];
                if (entry.getWeight() < threshold) {
                    break;
                }
                pw.print(entry.getID() + " " + entry.getWeight() + " ");
            }
            pw.print(" \n");

            Arrays.fill(countsByTopic, 0);
        }
    }

    /**
     * Writes the sampler state to a gzip-compressed, UTF-8 encoded file.
     *
     * <p>The stream is closed even when printing fails.
     *
     * @param f destination file
     * @throws IOException if the file cannot be opened or the gzip header cannot be written
     */
    public void printState(File f) throws IOException {
        PrintStream out = new PrintStream((OutputStream)new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))), false, "UTF-8");
        try {
            this.printState(out);
        } finally {
            // Closing finishes the gzip stream and releases the file handle.
            out.close();
        }
    }

    /**
     * Prints a header line, then one line per token with the space-separated fields
     * document index, language index, position, type index, type and assigned topic.
     *
     * @param out destination stream (not closed here)
     */
    public void printState(PrintStream out) {
        out.println("#doc lang pos typeindex type topic");
        for (int docIndex = 0; docIndex < this.data.size(); ++docIndex) {
            TopicAssignment document = this.data.get(docIndex);
            for (int lang = 0; lang < this.numLanguages; ++lang) {
                FeatureSequence tokens = (FeatureSequence) document.instances[lang].getData();
                LabelSequence assignments = document.topicSequences[lang];
                Alphabet vocabulary = this.alphabets[lang];
                for (int pos = 0; pos < assignments.getLength(); ++pos) {
                    int typeIndex = tokens.getIndexAtPosition(pos);
                    int assignedTopic = assignments.getIndexAtPosition(pos);
                    out.print(docIndex);
                    out.print(' ');
                    out.print(lang);
                    out.print(' ');
                    out.print(pos);
                    out.print(' ');
                    out.print(typeIndex);
                    out.print(' ');
                    out.print(vocabulary.lookupObject(typeIndex));
                    out.print(' ');
                    out.print(assignedTopic);
                    out.println();
                }
            }
        }
    }

    /**
     * Computes the log likelihood of the current topic assignments under the
     * model's Dirichlet priors.
     *
     * <p>The result is the sum of two parts:
     * <ul>
     *   <li>a document–topic part, accumulated per document from the topic
     *       counts pooled over all of that document's languages, using
     *       {@code alpha} and {@code alphaSum};</li>
     *   <li>a topic–word part, accumulated per language from
     *       {@code languageTypeTopicCounts} and {@code languageTokensPerTopic},
     *       using that language's {@code beta}.</li>
     * </ul>
     *
     * <p>For the topic–word part the Dirichlet-multinomial normaliser of each
     * topic is {@code logGamma(beta * V) - logGamma(beta * V + n_topic)} where
     * {@code V} is the language's vocabulary size, so the prior mass is
     * {@code beta * vocabularySizes[language]} and the constant term is counted
     * once per topic.
     *
     * <p>If the running total becomes NaN, a short diagnostic is printed to
     * standard output and the JVM is terminated with status 1.
     *
     * @return the model log likelihood (log-gamma values come from
     *         {@code Dirichlet.logGammaStirling})
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;
        int[] topicCounts = new int[this.numTopics];

        // logGamma(alpha_t) is needed once per (document, non-empty topic); cache it.
        double[] topicLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            topicLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }

        // ---- Document–topic part ----
        for (int doc = 0; doc < this.data.size(); ++doc) {
            int totalLength = 0;
            // Pool topic counts over every language of this document.
            for (int language = 0; language < this.numLanguages; ++language) {
                LabelSequence topicSequence = this.data.get((int)doc).topicSequences[language];
                int[] currentDocTopics = topicSequence.getFeatures();
                totalLength += topicSequence.getLength();
                for (int token = 0; token < topicSequence.getLength(); ++token) {
                    int n = currentDocTopics[token];
                    topicCounts[n] = topicCounts[n] + 1;
                }
            }
            // Topics with a zero count contribute logGamma(alpha) - logGamma(alpha) = 0.
            for (int topic = 0; topic < this.numTopics; ++topic) {
                if (topicCounts[topic] <= 0) continue;
                logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + (double)topicCounts[topic]) - topicLogGammas[topic];
            }
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + (double)totalLength);
            Arrays.fill(topicCounts, 0);
        }
        // logGamma(alphaSum) appears once per document.
        logLikelihood += (double)this.data.size() * Dirichlet.logGammaStirling(this.alphaSum);

        // ---- Topic–word part, one pass per language ----
        for (int language = 0; language < this.numLanguages; ++language) {
            int[][] typeTopicCounts = this.languageTypeTopicCounts[language];
            int[] tokensPerTopic = this.languageTokensPerTopic[language];
            double beta = this.betas[language];
            // Total prior mass of the symmetric word prior: beta times the vocabulary size.
            double betaSum = beta * (double)this.vocabularySizes[language];
            int nonZeroTypeTopics = 0;

            for (int type = 0; type < this.vocabularySizes[language]; ++type) {
                // Each entry packs a count in the bits above topicBits; the scan stops
                // at the first non-positive entry.
                int[] packedCounts = typeTopicCounts[type];
                for (int index = 0; index < packedCounts.length && packedCounts[index] > 0; ++index) {
                    int count = packedCounts[index] >> this.topicBits;
                    ++nonZeroTypeTopics;
                    logLikelihood += Dirichlet.logGammaStirling(beta + (double)count);
                    if (!Double.isNaN(logLikelihood)) continue;
                    System.out.println(count);
                    System.exit(1);
                }
            }

            for (int topic = 0; topic < this.numTopics; ++topic) {
                logLikelihood -= Dirichlet.logGammaStirling(betaSum + (double)tokensPerTopic[topic]);
                if (!Double.isNaN(logLikelihood)) continue;
                System.out.println("after topic " + topic + " " + tokensPerTopic[topic]);
                System.exit(1);
            }

            // logGamma(beta * V) once per topic; logGamma(beta) once per non-zero (type, topic) pair.
            logLikelihood += Dirichlet.logGammaStirling(betaSum) * (double)this.numTopics - Dirichlet.logGammaStirling(beta) * (double)nonZeroTypeTopics;
        }

        if (Double.isNaN(logLikelihood)) {
            System.out.println("at the end");
            System.exit(1);
        }
        return logLikelihood;
    }

    /**
     * Command-line entry point: trains a model on one instance list per
     * language and writes the final state to a file.
     *
     * <p>Arguments, in order: number of topics, state file name, testing IDs
     * file, then one serialized instance list file per language (at least one).
     *
     * @param args command-line arguments as described above
     * @throws IOException declared for the I/O performed while training and saving state
     */
    public static void main(String[] args) throws IOException {
        final int firstInstanceArg = 3;
        if (args.length <= firstInstanceArg) {
            System.err.println("Usage: PolylingualTopicModel [num topics] [file to save state] [testing IDs file] [language 0 instances] ...");
            System.exit(1);
        }

        int topicCount = Integer.parseInt(args[0]);
        String statePath = args[1];
        File testingIDs = new File(args[2]);

        // Every remaining argument names the instance list of one language.
        InstanceList[] perLanguage = new InstanceList[args.length - firstInstanceArg];
        for (int argIndex = firstInstanceArg; argIndex < args.length; ++argIndex) {
            String instancePath = args[argIndex];
            perLanguage[argIndex - firstInstanceArg] = InstanceList.load(new File(instancePath));
            System.err.println("loaded " + instancePath);
        }

        PolylingualTopicModel model = new PolylingualTopicModel(topicCount, 2.0);
        model.printLogLikelihood = true;
        model.setTopicDisplay(50, 7);
        model.loadTestingIDs(testingIDs);
        model.addInstances(perLanguage);
        model.setSaveState(200, statePath);
        model.estimate();
        model.printState(new File(statePath));
    }

    /**
     * Holds one document of the model: its instances and the topic sequences
     * assigned to them. Elsewhere in this file both arrays are indexed by
     * language number.
     *
     * <p>This is a non-static inner class, so each object is tied to an
     * enclosing {@code PolylingualTopicModel} instance.
     */
    public class TopicAssignment
    implements Serializable {
        /** The document's instances; indexed by language where this file reads it. */
        public Instance[] instances;
        /** The topic assignments for the tokens of the corresponding instance. */
        public LabelSequence[] topicSequences;
        /** Not set by the constructor. NOTE(review): presumably a cached per-document topic distribution — confirm against callers. */
        public Labeling topicDistribution;

        /**
         * Stores the given arrays by reference (no copy is made).
         *
         * @param instances the document's instances
         * @param topicSequences the topic sequences for those instances
         */
        public TopicAssignment(Instance[] instances, LabelSequence[] topicSequences) {
            this.instances = instances;
            this.topicSequences = topicSequences;
        }
    }
}

