/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicAssignment;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;
import gnu.trove.TIntHashSet;
import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Formatter;
import java.util.Iterator;
import java.util.Locale;
import java.util.TreeSet;

/**
 * Computes a set of per-topic (and, for some measures, per-top-word) diagnostic
 * scores for a trained {@link ParallelTopicModel}.
 *
 * <p>The constructor copies the top words of every topic, scans the model's
 * documents once via {@link #collectDocumentStatistics()}, and then evaluates each
 * diagnostic and stores the result in {@link #diagnostics}. The results can be
 * rendered with {@link #toString()} or {@link #toXML()}.
 */
public class TopicModelDiagnostics {
    /** Number of topics, taken from {@code model.getNumTopics()} in the constructor. */
    int numTopics;
    /** Maximum number of top-ranked words examined per topic. */
    int numTopWords;
    /** Per-topic sorted word entries, as returned by {@code model.getSortedWords()}. */
    ArrayList<TreeSet<IDSorter>> topicSortedWords;
    /**
     * Top words per topic, indexed {@code [topic][rank]}. Entries past the number of
     * words available for a topic are left {@code null}.
     */
    String[][] topicTopWords;
    /** One {@code TopicScores} per diagnostic, in the order the constructor computed them. */
    ArrayList<TopicScores> diagnostics;
    /** The model being diagnosed. */
    ParallelTopicModel model;
    /** The model's alphabet, used to map between word strings and type indices. */
    Alphabet alphabet;
    /**
     * Indexed {@code [topic][i][j]} over top-word ranks. The diagonal counts the
     * documents in which top word {@code i} has at least one token assigned to the
     * topic; off-diagonal cells count documents in which both words do.
     */
    int[][][] topicCodocumentMatrices;
    /** Indexed {@code [doc][topic]}: fraction of the document's tokens assigned to the topic. */
    double[][] docTopicProportions;
    /** Number of tokens of each word type across all documents, indexed by type. */
    int[] wordTypeCounts;
    /** Total number of tokens seen by {@code collectDocumentStatistics()}. */
    int numTokens = 0;

    /**
     * Builds all diagnostics for the given model.
     *
     * <p>The top {@code numTopWords} words of each topic are copied into
     * {@code topicTopWords} (fewer if the topic has fewer sorted words), document
     * statistics are collected, and every diagnostic is computed and appended to
     * {@code diagnostics} in a fixed order.
     *
     * @param model       the trained topic model to inspect
     * @param numTopWords the number of top-ranked words to examine per topic
     */
    public TopicModelDiagnostics(ParallelTopicModel model, int numTopWords) {
        this.numTopics = model.getNumTopics();
        this.model = model;
        this.numTopWords = numTopWords;
        this.alphabet = model.getAlphabet();
        this.topicSortedWords = model.getSortedWords();
        this.topicTopWords = new String[this.numTopics][numTopWords];
        this.diagnostics = new ArrayList<TopicScores>();

        for (int t = 0; t < this.numTopics; ++t) {
            TreeSet<IDSorter> ranked = this.topicSortedWords.get(t);
            // A topic may hold fewer distinct words than requested.
            int wordsToCopy = Math.min(numTopWords, ranked.size());
            Iterator<IDSorter> cursor = ranked.iterator();
            int rank = 0;
            while (rank < wordsToCopy) {
                IDSorter entry = cursor.next();
                this.topicTopWords[t][rank] = (String) this.alphabet.lookupObject(entry.getID());
                ++rank;
            }
        }

        // The document pass must run first: several diagnostics read its output.
        this.collectDocumentStatistics();

        this.diagnostics.add(this.getTokensPerTopic(model.tokensPerTopic));
        this.diagnostics.add(this.getWordLengthScores());
        this.diagnostics.add(this.getCoherence());
        this.diagnostics.add(this.getDistanceFromUniform());
        this.diagnostics.add(this.getDistanceFromCorpus());
        this.diagnostics.add(this.getEffectiveNumberOfWords());
        this.diagnostics.add(this.getTokenDocumentDiscrepancies());
        this.diagnostics.add(this.getRank1Percent());
    }

    /**
     * Scans every document of the model once and (re)builds the statistics the
     * diagnostics depend on:
     * <ul>
     *   <li>{@code wordTypeCounts} and {@code numTokens}: corpus-wide token counts;</li>
     *   <li>{@code docTopicProportions}: per-document fraction of tokens in each topic;</li>
     *   <li>{@code topicCodocumentMatrices}: for each topic, how many documents contain
     *       each top word (diagonal) and each pair of top words (off-diagonal) among
     *       the tokens assigned to that topic.</li>
     * </ul>
     *
     * <p>Rank positions for which a topic has no top word are marked with a negative
     * sentinel so they can never match a real word type. Left at the array default
     * of 0, they would alias word type 0 and receive spurious co-document counts.
     */
    public void collectDocumentStatistics() {
        // Sentinel for "no word at this rank"; real type indices are never negative.
        final int noWord = -1;

        this.topicCodocumentMatrices = new int[this.numTopics][this.numTopWords][this.numTopWords];
        this.wordTypeCounts = new int[this.alphabet.size()];
        this.numTokens = 0;

        // Per topic: the set of top-word types, and the same types in rank order.
        TIntHashSet[] topicTopWordIndices = new TIntHashSet[this.numTopics];
        int[][] topicWordIndicesInOrder = new int[this.numTopics][this.numTopWords];
        // Per topic: top-word types seen in the document currently being scanned.
        TIntHashSet[] docTopicWordIndices = new TIntHashSet[this.numTopics];

        int numDocs = this.model.getData().size();
        this.docTopicProportions = new double[numDocs][this.numTopics];
        // Per topic: tokens assigned to it in the current document.
        int[] topicCounts = new int[this.numTopics];

        for (int topic = 0; topic < this.numTopics; ++topic) {
            TIntHashSet wordIndices = new TIntHashSet();
            for (int i = 0; i < this.numTopWords; ++i) {
                String word = this.topicTopWords[topic][i];
                if (word == null) {
                    topicWordIndicesInOrder[topic][i] = noWord;
                    continue;
                }
                int type = this.alphabet.lookupIndex(word);
                topicWordIndicesInOrder[topic][i] = type;
                wordIndices.add(type);
            }
            topicTopWordIndices[topic] = wordIndices;
            docTopicWordIndices[topic] = new TIntHashSet();
        }

        int doc = 0;
        for (TopicAssignment document : this.model.getData()) {
            FeatureSequence tokens = (FeatureSequence) document.instance.getData();
            LabelSequence topics = document.topicSequence;

            for (int position = 0; position < tokens.size(); ++position) {
                int type = tokens.getIndexAtPosition(position);
                int topic = topics.getIndexAtPosition(position);

                ++this.numTokens;
                ++this.wordTypeCounts[type];
                ++topicCounts[topic];
                this.docTopicProportions[doc][topic] += 1.0;

                if (topicTopWordIndices[topic].contains(type)) {
                    docTopicWordIndices[topic].add(type);
                }
            }

            int docLength = tokens.size();
            if (docLength > 0) {
                for (int topic = 0; topic < this.numTopics; ++topic) {
                    // Turn the raw token count into a proportion of the document.
                    this.docTopicProportions[doc][topic] /= (double) docLength;

                    // No token of this topic in the document: nothing was recorded.
                    if (topicCounts[topic] <= 0) {
                        continue;
                    }

                    TIntHashSet supportedWords = docTopicWordIndices[topic];
                    int[] indices = topicWordIndicesInOrder[topic];
                    int[][] matrix = this.topicCodocumentMatrices[topic];

                    for (int i = 0; i < this.numTopWords; ++i) {
                        if (indices[i] == noWord || !supportedWords.contains(indices[i])) {
                            continue;
                        }
                        // Diagonal: this document contains top word i.
                        ++matrix[i][i];
                        // Off-diagonal, kept symmetric: the document contains both words.
                        for (int j = i + 1; j < this.numTopWords; ++j) {
                            if (indices[j] == noWord || !supportedWords.contains(indices[j])) {
                                continue;
                            }
                            ++matrix[i][j];
                            ++matrix[j][i];
                        }
                    }

                    // Reset the per-document accumulators for the next document.
                    supportedWords.clear();
                    topicCounts[topic] = 0;
                }
            }
            ++doc;
        }
    }

    /**
     * Returns the co-document count matrix of one topic.
     *
     * @param topic the topic index
     * @return the internal matrix for that topic, indexed by top-word rank (not a copy)
     */
    public int[][] getCodocumentMatrix(int topic) {
        int[][] matrixForTopic = this.topicCodocumentMatrices[topic];
        return matrixForTopic;
    }

    /**
     * Wraps the given per-topic token counts as a "tokens" diagnostic.
     *
     * @param tokensPerTopic token count for each topic, indexed by topic
     * @return topic-level scores holding those counts as doubles
     */
    public TopicScores getTokensPerTopic(int[] tokensPerTopic) {
        TopicScores result = new TopicScores("tokens", this.numTopics, this.numTopWords);
        int t = 0;
        while (t < this.numTopics) {
            double tokenCount = tokensPerTopic[t];
            result.setTopicScore(t, tokenCount);
            ++t;
        }
        return result;
    }

    /**
     * Computes the "uniform_dist" diagnostic. For each word with weight {@code c} in
     * a topic with {@code T} tokens and a vocabulary of {@code V} types, the term is
     * {@code (c / T) * log(c * V / T)}. The topic score is the sum of that term over
     * all of the topic's sorted words; the first {@code numTopWords} terms are also
     * stored as word scores.
     *
     * @return the topic and word scores
     */
    public TopicScores getDistanceFromUniform() {
        final int[] topicTotals = this.model.tokensPerTopic;
        final TopicScores result = new TopicScores("uniform_dist", this.numTopics, this.numTopWords);
        result.wordScoresDefined = true;
        final int vocabularySize = this.alphabet.size();

        for (int t = 0; t < this.numTopics; ++t) {
            double total = 0.0;
            int rank = 0;
            Iterator<IDSorter> cursor = this.topicSortedWords.get(t).iterator();
            while (cursor.hasNext()) {
                double count = cursor.next().getWeight();
                double probability = count / (double) topicTotals[t];
                double ratioToUniform = count * (double) vocabularySize / (double) topicTotals[t];
                double term = probability * Math.log(ratioToUniform);
                // Only the top-ranked words have a per-word slot.
                if (rank < this.numTopWords) {
                    result.setTopicWordScore(t, rank, term);
                }
                total += term;
                ++rank;
            }
            result.setTopicScore(t, total);
        }
        return result;
    }

    /**
     * Computes the "eff_num_words" diagnostic: the reciprocal of the sum of squared
     * word probabilities of each topic, where a word's probability is its weight
     * divided by the topic's token count.
     *
     * <p>A topic with no words has a zero sum of squares. It is scored 0.0 rather
     * than dividing by zero, which would put {@code Infinity} into the reports.
     *
     * @return the topic-level scores (no word-level scores are defined)
     */
    public TopicScores getEffectiveNumberOfWords() {
        int[] tokensPerTopic = this.model.tokensPerTopic;
        TopicScores scores = new TopicScores("eff_num_words", this.numTopics, this.numTopWords);
        for (int topic = 0; topic < this.numTopics; ++topic) {
            double sumSquaredProbabilities = 0.0;
            for (IDSorter info : this.topicSortedWords.get(topic)) {
                double probability = info.getWeight() / (double) tokensPerTopic[topic];
                sumSquaredProbabilities += probability * probability;
            }
            // Empty topic: no words, so report zero instead of 1/0.
            double effectiveWords = sumSquaredProbabilities > 0.0 ? 1.0 / sumSquaredProbabilities : 0.0;
            scores.setTopicScore(topic, effectiveWords);
        }
        return scores;
    }

    /**
     * Computes the "corpus_dist" diagnostic. For each word with weight {@code c} in a
     * topic with {@code T} tokens, the term is
     * {@code (c / T) * log((numTokens / T) * c / wordTypeCounts[type])}, i.e. the
     * word's in-topic probability times the log of its ratio to the word's
     * corpus-wide frequency. The topic score sums the term over all sorted words;
     * the first {@code numTopWords} terms are also stored as word scores.
     *
     * @return the topic and word scores
     */
    public TopicScores getDistanceFromCorpus() {
        final int[] topicTotals = this.model.tokensPerTopic;
        final TopicScores result = new TopicScores("corpus_dist", this.numTopics, this.numTopWords);
        result.wordScoresDefined = true;

        for (int t = 0; t < this.numTopics; ++t) {
            // Corpus size relative to topic size; constant for the whole topic.
            double coefficient = (double) this.numTokens / (double) topicTotals[t];
            double total = 0.0;
            int rank = 0;
            Iterator<IDSorter> cursor = this.topicSortedWords.get(t).iterator();
            while (cursor.hasNext()) {
                IDSorter entry = cursor.next();
                int type = entry.getID();
                double count = entry.getWeight();
                double probability = count / (double) topicTotals[t];
                double ratioToCorpus = coefficient * count / (double) this.wordTypeCounts[type];
                double term = probability * Math.log(ratioToCorpus);
                if (rank < this.numTopWords) {
                    result.setTopicWordScore(t, rank, term);
                }
                total += term;
                ++rank;
            }
            result.setTopicScore(t, total);
        }
        return result;
    }

    /**
     * Computes the "token-doc-diff" diagnostic. For the top words of each topic it
     * compares two normalized distributions: the words' weights in the topic
     * ({@code p}) and the number of documents containing each word, taken from the
     * diagonal of the co-document matrix ({@code q}). Each word contributes
     * {@code 0.5 * p * log(p / m) + 0.5 * q * log(q / m)} with {@code m = (p + q) / 2},
     * skipping a half-term whose probability is not positive. The topic score is the
     * sum over all {@code numTopWords} positions.
     *
     * @return the topic and word scores
     */
    public TopicScores getTokenDocumentDiscrepancies() {
        final TopicScores result = new TopicScores("token-doc-diff", this.numTopics, this.numTopWords);
        result.wordScoresDefined = true;

        for (int t = 0; t < this.numTopics; ++t) {
            int[][] codoc = this.topicCodocumentMatrices[t];
            Iterator<IDSorter> cursor = this.topicSortedWords.get(t).iterator();

            double[] tokenWeights = new double[this.numTopWords];
            double[] docCounts = new double[this.numTopWords];
            double tokenTotal = 0.0;
            double docTotal = 0.0;

            // Gather the raw values for as many top words as the topic has.
            int rank = 0;
            while (cursor.hasNext() && rank < this.numTopWords) {
                IDSorter entry = cursor.next();
                tokenWeights[rank] = entry.getWeight();
                docCounts[rank] = codoc[rank][rank];
                tokenTotal += tokenWeights[rank];
                docTotal += docCounts[rank];
                ++rank;
            }

            // Normalize and accumulate the divergence term of every position.
            double topicTotal = 0.0;
            for (rank = 0; rank < this.numTopWords; ++rank) {
                double p = tokenWeights[rank] / tokenTotal;
                double q = docCounts[rank] / docTotal;
                double midpoint = 0.5 * (p + q);
                double term = 0.0;
                if (p > 0.0) {
                    term += 0.5 * p * Math.log(p / midpoint);
                }
                if (q > 0.0) {
                    term += 0.5 * q * Math.log(q / midpoint);
                }
                result.setTopicWordScore(t, rank, term);
                topicTotal += term;
            }
            result.setTopicScore(t, topicTotal);
        }
        return result;
    }

    /**
     * Computes the "word-length" diagnostic. The length of every top word is
     * standardized against the mean and sample standard deviation (divisor
     * {@code n - 1}) of the lengths of all top words across all topics. Each word's
     * standardized length is stored as its word score, and a topic's score is the
     * sum of its words' standardized lengths.
     *
     * @return the topic and word scores
     */
    public TopicScores getWordLengthScores() {
        final TopicScores result = new TopicScores("word-length", this.numTopics, this.numTopWords);
        result.wordScoresDefined = true;

        // Pass 1: mean length over all top words (each row ends at its first null).
        double mean = 0.0;
        int wordCount = 0;
        for (int t = 0; t < this.numTopics; ++t) {
            for (String word : this.topicTopWords[t]) {
                if (word == null) {
                    break;
                }
                mean += (double) word.length();
                ++wordCount;
            }
        }
        mean /= (double) wordCount;

        // Pass 2: sample variance around that mean.
        double sumSquaredDeviation = 0.0;
        for (int t = 0; t < this.numTopics; ++t) {
            for (String word : this.topicTopWords[t]) {
                if (word == null) {
                    break;
                }
                double deviation = (double) word.length() - mean;
                sumSquaredDeviation += deviation * deviation;
            }
        }
        double variance = sumSquaredDeviation / (double) (wordCount - 1);
        double standardDeviation = Math.sqrt(variance);

        // Pass 3: standardized length per word, summed per topic.
        for (int t = 0; t < this.numTopics; ++t) {
            String[] words = this.topicTopWords[t];
            for (int rank = 0; rank < words.length; ++rank) {
                if (words[rank] == null) {
                    break;
                }
                double standardized = ((double) words[rank].length() - mean) / standardDeviation;
                result.addToTopicScore(t, standardized);
                result.setTopicWordScore(t, rank, standardized);
            }
        }
        return result;
    }

    /**
     * Computes the "coherence" diagnostic from the co-document matrices. For every
     * pair of top-word ranks {@code col < row}, the pair term is
     * {@code log((codoc[row][col] + 1) / (codoc[col][col] + 1))}. The topic score is
     * the sum of all pair terms. The word score of {@code row} is the smallest pair
     * term in its row, capped at 0.0 from above.
     *
     * @return the topic and word scores
     */
    public TopicScores getCoherence() {
        final TopicScores result = new TopicScores("coherence", this.numTopics, this.numTopWords);
        result.wordScoresDefined = true;

        for (int t = 0; t < this.numTopics; ++t) {
            int[][] codoc = this.topicCodocumentMatrices[t];
            double topicTotal = 0.0;

            int row = 0;
            while (row < this.numTopWords) {
                double rowTotal = 0.0;
                double lowest = 0.0;
                for (int col = 0; col < row; ++col) {
                    double jointDocs = (double) codoc[row][col] + 1.0;
                    double higherRankedDocs = (double) codoc[col][col] + 1.0;
                    double logRatio = Math.log(jointDocs / higherRankedDocs);
                    rowTotal += logRatio;
                    if (logRatio < lowest) {
                        lowest = logRatio;
                    }
                }
                topicTotal += rowTotal;
                result.setTopicWordScore(t, row, lowest);
                ++row;
            }
            result.setTopicScore(t, topicTotal);
        }
        return result;
    }

    /**
     * Computes the "rank_1_docs" diagnostic: of the documents in which a topic has a
     * non-zero proportion, the fraction in which it is the single largest topic
     * (ties go to the lowest topic index).
     *
     * <p>A topic that appears in no document is scored 0.0. Dividing 0 by 0 would
     * put {@code NaN} into the reports.
     *
     * @return the topic-level scores (no word-level scores are defined)
     */
    public TopicScores getRank1Percent() {
        TopicScores scores = new TopicScores("rank_1_docs", this.numTopics, this.numTopWords);
        int[] numRank1Documents = new int[this.numTopics];
        int[] numNonZeroDocuments = new int[this.numTopics];

        for (int doc = 0; doc < this.docTopicProportions.length; ++doc) {
            double[] proportions = this.docTopicProportions[doc];
            int maxTopic = -1;
            double maxTopicProb = 0.0;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                if (proportions[topic] > 0.0) {
                    ++numNonZeroDocuments[topic];
                }
                // Strict comparison: the first topic reaching the maximum wins ties.
                if (proportions[topic] > maxTopicProb) {
                    maxTopic = topic;
                    maxTopicProb = proportions[topic];
                }
            }
            // Empty documents have no rank-1 topic.
            if (maxTopic != -1) {
                ++numRank1Documents[maxTopic];
            }
        }

        for (int topic = 0; topic < this.numTopics; ++topic) {
            double fraction = 0.0;
            if (numNonZeroDocuments[topic] > 0) {
                fraction = (double) numRank1Documents[topic] / (double) numNonZeroDocuments[topic];
            }
            scores.setTopicScore(topic, fraction);
        }
        return scores;
    }

    /**
     * Renders the diagnostics as plain text. Each topic gets a header line with
     * every topic-level score, followed by one indented line per top word carrying
     * the word-level scores of those diagnostics that define them.
     *
     * @return the multi-line text report
     */
    @Override
    public String toString() {
        StringBuilder report = new StringBuilder();
        Formatter fmt = new Formatter(report, Locale.US);

        for (int t = 0; t < this.numTopics; ++t) {
            fmt.format("Topic %d", t);
            for (TopicScores diagnostic : this.diagnostics) {
                fmt.format("\t%s=%.4f", diagnostic.name, diagnostic.scores[t]);
            }
            fmt.format("\n");

            String[] words = this.topicTopWords[t];
            for (int rank = 0; rank < words.length; ++rank) {
                // The row is null-padded when the topic has fewer words.
                if (words[rank] == null) {
                    break;
                }
                fmt.format("  %s", words[rank]);
                for (TopicScores diagnostic : this.diagnostics) {
                    if (diagnostic.wordScoresDefined) {
                        fmt.format("\t%s=%.4f", diagnostic.name, diagnostic.topicWordScores[t][rank]);
                    }
                }
                report.append("\n");
            }
        }
        return report.toString();
    }

    /**
     * Renders the diagnostics as an XML document. The root {@code <model>} element
     * holds one {@code <topic>} per topic, whose attributes are the topic-level
     * scores. Each topic holds one {@code <word>} per top word with its rank, count,
     * probability, cumulative probability, document count and word-level scores.
     *
     * <p>Word text is escaped so that words containing {@code &}, {@code <} or
     * {@code >} do not produce malformed XML.
     *
     * @return the XML document as a string
     */
    public String toXML() {
        int[] tokensPerTopic = this.model.tokensPerTopic;
        StringBuilder out = new StringBuilder();
        Formatter formatter = new Formatter(out, Locale.US);

        out.append("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n");
        out.append("<model>\n");
        for (int topic = 0; topic < this.numTopics; ++topic) {
            int[][] matrix = this.topicCodocumentMatrices[topic];

            formatter.format("<topic id='%d'", topic);
            for (TopicScores scores : this.diagnostics) {
                formatter.format(" %s='%.4f'", scores.name, scores.scores[topic]);
            }
            out.append(">\n");

            TreeSet<IDSorter> sortedWords = this.topicSortedWords.get(topic);
            // A topic may hold fewer words than numTopWords.
            int limit = Math.min(this.numTopWords, sortedWords.size());

            double cumulativeProbability = 0.0;
            Iterator<IDSorter> iterator = sortedWords.iterator();
            for (int position = 0; position < limit; ++position) {
                IDSorter info = iterator.next();
                double probability = info.getWeight() / (double) tokensPerTopic[topic];
                cumulativeProbability += probability;
                formatter.format(
                        "<word rank='%d' count='%.0f' prob='%.5f' cumulative='%.5f' docs='%d'",
                        position + 1, info.getWeight(), probability, cumulativeProbability,
                        matrix[position][position]);
                for (TopicScores scores : this.diagnostics) {
                    if (!scores.wordScoresDefined) {
                        continue;
                    }
                    formatter.format(" %s='%.4f'", scores.name, scores.topicWordScores[topic][position]);
                }
                formatter.format(">%s</word>\n", escapeXml(this.topicTopWords[topic][position]));
            }
            out.append("</topic>\n");
        }
        out.append("</model>\n");
        return out.toString();
    }

    /**
     * Escapes the characters that are markup-significant in XML element content.
     *
     * @param text the raw text; may be {@code null}
     * @return the text with {@code &}, {@code <} and {@code >} replaced by entity
     *         references, or {@code null} if {@code text} is {@code null}
     */
    private static String escapeXml(String text) {
        if (text == null) {
            return null;
        }
        StringBuilder escaped = new StringBuilder(text.length() + 8);
        for (int i = 0; i < text.length(); ++i) {
            char c = text.charAt(i);
            switch (c) {
                case '&':
                    escaped.append("&amp;");
                    break;
                case '<':
                    escaped.append("&lt;");
                    break;
                case '>':
                    escaped.append("&gt;");
                    break;
                default:
                    escaped.append(c);
                    break;
            }
        }
        return escaped.toString();
    }

    /**
     * Command-line entry point: trains a topic model on a serialized instance list
     * and, when an output path is given, writes the diagnostics as XML.
     *
     * <p>Arguments: {@code <instance-list-file> <num-topics> [<output-xml-file>]}.
     *
     * @param args the command-line arguments described above
     * @throws IllegalArgumentException if fewer than two arguments are supplied
     * @throws Exception if loading, training or writing fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            throw new IllegalArgumentException(
                    "Usage: TopicModelDiagnostics <instance-list-file> <num-topics> [<output-xml-file>]");
        }

        InstanceList instances = InstanceList.load(new File(args[0]));
        int numTopics = Integer.parseInt(args[1]);
        ParallelTopicModel model = new ParallelTopicModel(numTopics, 5.0, 0.01);
        model.addInstances(instances);
        model.setNumIterations(1000);
        model.estimate();

        TopicModelDiagnostics diagnostics = new TopicModelDiagnostics(model, 20);
        if (args.length == 3) {
            // The XML prolog declares UTF-8, so the file must be written as UTF-8
            // rather than in the platform default charset.
            PrintWriter out = new PrintWriter(args[2], "UTF-8");
            try {
                out.println(diagnostics.toXML());
                // PrintWriter swallows I/O errors; surface them explicitly.
                if (out.checkError()) {
                    throw new java.io.IOException("Failed to write diagnostics to " + args[2]);
                }
            } finally {
                out.close();
            }
        }
    }

    /**
     * Holds the values of one named diagnostic: one score per topic and, optionally,
     * one score per top word of each topic.
     */
    public class TopicScores {
        /** Name of the diagnostic, used as the label in reports. */
        public String name;
        /** Topic-level scores, indexed by topic. */
        public double[] scores;
        /** Word-level scores, indexed {@code [topic][wordPosition]}. */
        public double[][] topicWordScores;
        /** Whether the word-level scores are meaningful; set by {@code setTopicWordScore}. */
        public boolean wordScoresDefined;

        /**
         * Creates a zero-initialized score table.
         *
         * @param name      the diagnostic's name
         * @param numTopics the number of topics
         * @param numWords  the number of word positions per topic
         */
        public TopicScores(String name, int numTopics, int numWords) {
            this.wordScoresDefined = false;
            this.name = name;
            this.scores = new double[numTopics];
            this.topicWordScores = new double[numTopics][numWords];
        }

        /** Overwrites the score of {@code topic}. */
        public void setTopicScore(int topic, double score) {
            scores[topic] = score;
        }

        /** Adds {@code score} to the current score of {@code topic}. */
        public void addToTopicScore(int topic, double score) {
            scores[topic] += score;
        }

        /** Stores a word-level score and marks word scores as defined. */
        public void setTopicWordScore(int topic, int wordPosition, double score) {
            double[] row = topicWordScores[topic];
            row[wordPosition] = score;
            wordScoresDefined = true;
        }
    }
}

