/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;

public class FeatureConstraintUtil {
    private static Logger logger = MalletLogger.getLogger(FeatureConstraintUtil.class.getName());

    public static HashMap<Integer, double[]> readConstraintsFromFile(String filename, InstanceList data) {
        if (FeatureConstraintUtil.testConstraintsFileIndexBased(filename)) {
            return FeatureConstraintUtil.readConstraintsFromFileIndex(filename, data);
        }
        return FeatureConstraintUtil.readConstraintsFromFileString(filename, data);
    }

    public static HashMap<Integer, double[]> readConstraintsFromFileString(String filename, InstanceList data) {
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        File file = new File(filename);
        try {
            BufferedReader reader = new BufferedReader(new FileReader(file));
            String line = reader.readLine();
            while (line != null) {
                String[] split = line.split("\\s+");
                String featureName = split[0];
                int featureIndex = data.getDataAlphabet().lookupIndex(featureName, false);
                assert (split.length - 1 == data.getTargetAlphabet().size());
                double[] probs = new double[split.length - 1];
                for (int index = 1; index < split.length; ++index) {
                    double prob;
                    String[] labelSplit = split[index].split(":");
                    int li = data.getTargetAlphabet().lookupIndex(labelSplit[0], false);
                    probs[li] = prob = Double.parseDouble(labelSplit[1]);
                }
                constraints.put(featureIndex, probs);
                line = reader.readLine();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return constraints;
    }

    public static HashMap<Integer, double[]> readConstraintsFromFileIndex(String filename, InstanceList data) {
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        File file = new File(filename);
        try {
            BufferedReader reader = new BufferedReader(new FileReader(file));
            String line = reader.readLine();
            while (line != null) {
                String[] split = line.split("\\s+");
                int featureIndex = Integer.parseInt(split[0]);
                assert (split.length - 1 == data.getTargetAlphabet().size());
                double[] probs = new double[split.length - 1];
                for (int index = 1; index < split.length; ++index) {
                    double prob;
                    probs[index - 1] = prob = Double.parseDouble(split[index]);
                }
                constraints.put(featureIndex, probs);
                line = reader.readLine();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return constraints;
    }

    private static boolean testConstraintsFileIndexBased(String filename) {
        File file = new File(filename);
        String firstLine = "";
        try {
            BufferedReader reader = new BufferedReader(new FileReader(file));
            firstLine = reader.readLine();
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return !firstLine.contains(":");
    }

    public static ArrayList<Integer> selectFeaturesByInfoGain(InstanceList list, int numFeatures) {
        ArrayList<Integer> features = new ArrayList<Integer>();
        InfoGain infogain = new InfoGain(list);
        for (int rank = 0; rank < numFeatures; ++rank) {
            features.add(infogain.getIndexAtRank(rank));
        }
        return features;
    }

    public static ArrayList<Integer> selectTopLDAFeatures(int numSelFeatures, ParallelTopicModel lda, Alphabet alphabet) {
        ArrayList<Integer> features = new ArrayList<Integer>();
        Alphabet seqAlphabet = lda.getAlphabet();
        int numTopics = lda.getNumTopics();
        Object[][] sorted = lda.getTopWords(seqAlphabet.size());
        for (int pos = 0; pos < seqAlphabet.size(); ++pos) {
            for (int ti = 0; ti < numTopics; ++ti) {
                String feat = sorted[ti][pos].toString();
                int fi = alphabet.lookupIndex(feat, false);
                if (fi < 0 || features.contains(fi)) continue;
                logger.info("Selected feature: " + feat);
                features.add(fi);
                if (features.size() != numSelFeatures) continue;
                return features;
            }
        }
        return features;
    }

    public static HashMap<Integer, double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features) {
        return FeatureConstraintUtil.setTargetsUsingData(list, features, true);
    }

    public static HashMap<Integer, double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean normalize) {
        return FeatureConstraintUtil.setTargetsUsingData(list, features, false, normalize);
    }

    public static HashMap<Integer, double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean useValues, boolean normalize) {
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        double[][] featureLabelCounts = FeatureConstraintUtil.getFeatureLabelCounts(list, useValues);
        for (int i = 0; i < features.size(); ++i) {
            int fi = features.get(i);
            if (fi == list.getDataAlphabet().size()) continue;
            double[] prob = featureLabelCounts[fi];
            if (normalize) {
                MatrixOps.plusEquals(prob, 1.0E-8);
                MatrixOps.timesEquals(prob, 1.0 / MatrixOps.sum(prob));
            }
            constraints.put(fi, prob);
        }
        return constraints;
    }

    public static HashMap<Integer, double[]> setTargetsUsingHeuristic(HashMap<Integer, ArrayList<Integer>> labeledFeatures, int numLabels, double majorityProb) {
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        for (int fi : labeledFeatures.keySet()) {
            ArrayList<Integer> labels = labeledFeatures.get(fi);
            constraints.put(fi, FeatureConstraintUtil.getHeuristicPrior(labels, numLabels, majorityProb));
        }
        return constraints;
    }

    public static HashMap<Integer, double[]> setTargetsUsingFeatureVoting(HashMap<Integer, ArrayList<Integer>> labeledFeatures, InstanceList trainingData) {
        HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
        int numLabels = trainingData.getTargetAlphabet().size();
        Iterator<Integer> keyIter = labeledFeatures.keySet().iterator();
        double[][] featureCounts = new double[labeledFeatures.size()][numLabels];
        for (int ii = 0; ii < trainingData.size(); ++ii) {
            Instance instance = (Instance)trainingData.get(ii);
            FeatureVector fv = (FeatureVector)instance.getData();
            Labeling labeling = ((Instance)trainingData.get(ii)).getLabeling();
            double[] labelDist = new double[numLabels];
            if (labeling == null) {
                FeatureConstraintUtil.labelByVoting(labeledFeatures, instance, labelDist);
            } else {
                int li = labeling.getBestIndex();
                labelDist[li] = 1.0;
            }
            keyIter = labeledFeatures.keySet().iterator();
            int i = 0;
            while (keyIter.hasNext()) {
                int fi = keyIter.next();
                if (fv.location(fi) >= 0) {
                    for (int li = 0; li < numLabels; ++li) {
                        double[] dArray = featureCounts[i];
                        int n = li;
                        dArray[n] = dArray[n] + labelDist[li] * fv.valueAtLocation(fv.location(fi));
                    }
                }
                ++i;
            }
        }
        keyIter = labeledFeatures.keySet().iterator();
        int i = 0;
        while (keyIter.hasNext()) {
            int fi = keyIter.next();
            MatrixOps.plusEquals(featureCounts[i], 1.0E-8);
            MatrixOps.timesEquals(featureCounts[i], 1.0 / MatrixOps.sum(featureCounts[i]));
            constraints.put(fi, featureCounts[i]);
            ++i;
        }
        return constraints;
    }

    public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features, boolean reject) {
        HashMap<Integer, ArrayList<Integer>> labeledFeatures = new HashMap<Integer, ArrayList<Integer>>();
        double[][] featureLabelCounts = FeatureConstraintUtil.getFeatureLabelCounts(list, true);
        int numLabels = list.getTargetAlphabet().size();
        int minRank = 100 * numLabels;
        InfoGain infogain = new InfoGain(list);
        double sum = 0.0;
        for (int rank = 0; rank < minRank; ++rank) {
            sum += infogain.getValueAtRank(rank);
        }
        double mean = sum / (double)minRank;
        for (int i = 0; i < features.size(); ++i) {
            int fi = features.get(i);
            if (reject && infogain.value(fi) < mean) {
                logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
                continue;
            }
            double[] prob = featureLabelCounts[fi];
            MatrixOps.plusEquals(prob, 1.0E-8);
            MatrixOps.timesEquals(prob, 1.0 / MatrixOps.sum(prob));
            int[] sortedIndices = FeatureConstraintUtil.getMaxIndices(prob);
            ArrayList<Integer> labels = new ArrayList<Integer>();
            if (numLabels > 2) {
                boolean discard = false;
                double threshold = prob[sortedIndices[0]] / 2.0;
                for (int li = 0; li < numLabels; ++li) {
                    if (prob[li] > threshold) {
                        labels.add(li);
                    }
                    if (!reject || labels.size() <= numLabels / 2) continue;
                    logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
                    discard = true;
                    break;
                }
                if (discard) {
                    continue;
                }
            } else {
                labels.add(sortedIndices[0]);
            }
            labeledFeatures.put(fi, labels);
        }
        return labeledFeatures;
    }

    public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features) {
        return FeatureConstraintUtil.labelFeatures(list, features, true);
    }

    private static double[][] getFeatureLabelCounts(InstanceList list, boolean useValues) {
        int numFeatures = list.getDataAlphabet().size();
        int numLabels = list.getTargetAlphabet().size();
        double[][] featureLabelCounts = new double[numFeatures][numLabels];
        for (int ii = 0; ii < list.size(); ++ii) {
            Instance instance = (Instance)list.get(ii);
            FeatureVector featureVector = (FeatureVector)instance.getData();
            for (int li = 0; li < numLabels; ++li) {
                double py = instance.getLabeling().value(li);
                for (int loc = 0; loc < featureVector.numLocations(); ++loc) {
                    int fi = featureVector.indexAtLocation(loc);
                    double val = useValues ? featureVector.valueAtLocation(loc) : 1.0;
                    double[] dArray = featureLabelCounts[fi];
                    int n = li;
                    dArray[n] = dArray[n] + py * val;
                }
            }
        }
        return featureLabelCounts;
    }

    private static double[] getHeuristicPrior(ArrayList<Integer> labeledFeatures, int numLabels, double majorityProb) {
        int numIndices = labeledFeatures.size();
        double[] dist = new double[numLabels];
        if (numIndices == numLabels) {
            for (int i = 0; i < dist.length; ++i) {
                dist[i] = 1.0 / (double)numLabels;
            }
            return dist;
        }
        double keywordProb = majorityProb / (double)numIndices;
        double otherProb = (1.0 - majorityProb) / (double)(numLabels - numIndices);
        for (int i = 0; i < labeledFeatures.size(); ++i) {
            int li = labeledFeatures.get(i);
            dist[li] = keywordProb;
        }
        for (int li = 0; li < numLabels; ++li) {
            if (dist[li] != 0.0) continue;
            dist[li] = otherProb;
        }
        assert (Maths.almostEquals(MatrixOps.sum(dist), 1.0));
        return dist;
    }

    private static void labelByVoting(HashMap<Integer, ArrayList<Integer>> labeledFeatures, Instance instance, double[] scores) {
        FeatureVector fv = (FeatureVector)instance.getData();
        int numFeatures = instance.getDataAlphabet().size() + 1;
        int[] numLabels = new int[instance.getTargetAlphabet().size()];
        Iterator<Integer> keyIterator = labeledFeatures.keySet().iterator();
        while (keyIterator.hasNext()) {
            ArrayList<Integer> majorityClassList = labeledFeatures.get(keyIterator.next());
            for (int i = 0; i < majorityClassList.size(); ++i) {
                int li;
                int n = li = majorityClassList.get(i).intValue();
                numLabels[n] = numLabels[n] + 1;
            }
        }
        for (int next : labeledFeatures.keySet()) {
            assert (next < numFeatures);
            int loc = fv.location(next);
            if (loc < 0) continue;
            ArrayList<Integer> majorityClassList = labeledFeatures.get(next);
            for (int i = 0; i < majorityClassList.size(); ++i) {
                int li;
                int n = li = majorityClassList.get(i).intValue();
                scores[n] = scores[n] + 1.0;
            }
        }
        double sum = MatrixOps.sum(scores);
        if (sum == 0.0) {
            MatrixOps.plusEquals(scores, 1.0);
            sum = MatrixOps.sum(scores);
        }
        int li = 0;
        while (li < scores.length) {
            int n = li++;
            scores[n] = scores[n] / sum;
        }
    }

    private static int[] getMaxIndices(double[] x) {
        ArrayList<Element> list = new ArrayList<Element>();
        for (int i = 0; i < x.length; ++i) {
            Element element = new Element(i, x[i]);
            list.add(element);
        }
        Collections.sort(list);
        Collections.reverse(list);
        int[] sortedIndices = new int[x.length];
        for (int i = 0; i < x.length; ++i) {
            sortedIndices[i] = ((Element)list.get(i)).index;
        }
        return sortedIndices;
    }

    private static class Element
    implements Comparable<Element> {
        private int index;
        private double value;

        public Element(int index, double value) {
            this.index = index;
            this.value = value;
        }

        @Override
        public int compareTo(Element element) {
            return Double.compare(this.value, element.value);
        }
    }
}

