/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntOptimizableByLabelLikelihood;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletProgressMessageLogger;
import cc.mallet.util.Maths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Objective function that scores a {@link MaxEnt} classifier against per-feature
 * reference label distributions (the "GE" objective, presumably generalized
 * expectation) plus a Gaussian prior on the parameters.
 *
 * <p>For every constrained feature the model's expected label distribution is
 * estimated from the instances of the training list that have no target (unlabeled
 * instances) and contain the feature. The objective is the negated, weighted
 * KL divergence between the reference distribution and that model distribution,
 * summed over constrained features, plus the regularization term computed by
 * {@link #getRegularization()}. Value and gradient are cached until a parameter or
 * one of the tunable settings changes.
 *
 * <p>The Gaussian prior variance has no usable default (the field starts at 0.0,
 * which makes the regularization term non-finite); callers are expected to call
 * {@link #setGaussianPriorVariance(double)} before evaluating the objective.
 */
public class MaxEntOptimizableByGE implements Optimizable.ByGradientValue {
    // NOTE(review): the logger is named after MaxEntOptimizableByLabelLikelihood, which looks
    // like a copy-paste; kept as is because the name is visible to logging configuration.
    private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntOptimizableByLabelLikelihood.class.getName() + "-pl");
    private boolean cacheStale = true;
    private boolean useValues = false;
    private final int defaultFeatureIndex;
    private double temperature = 1.0;
    private double objWeight = 1.0;
    private double cachedValue;
    private double gaussianPriorVariance;
    private final double[] cachedGradient;
    private final double[] parameters;
    private final InstanceList trainingList;
    private final MaxEnt classifier;
    // Constrained feature index -> reference label distribution.
    private final HashMap<Integer, double[]> refEx;
    // Constrained feature index -> dense row index used by the per-constraint arrays; built lazily.
    private HashMap<Integer, Integer> mapping;

    /**
     * Creates the objective.
     *
     * @param trainingList instances to evaluate; only instances whose target is {@code null} are used
     * @param refDist      reference label distribution for each constrained feature index; the
     *                     index equal to the data alphabet size denotes the default (bias) feature
     * @param classifier   classifier to evaluate, or {@code null} to create a new {@link MaxEnt}
     *                     backed by this object's parameter array
     */
    public MaxEntOptimizableByGE(InstanceList trainingList, HashMap<Integer, double[]> refDist, MaxEnt classifier) {
        this.trainingList = trainingList;
        int numFeatures = trainingList.getDataAlphabet().size();
        int numLabels = trainingList.getTargetAlphabet().size();
        this.defaultFeatureIndex = numFeatures;
        // One extra column per label for the default feature.
        this.parameters = new double[(numFeatures + 1) * numLabels];
        this.cachedGradient = new double[(numFeatures + 1) * numLabels];
        this.cachedValue = 0.0;
        // NOTE(review): when a classifier is supplied, this.parameters is still a freshly
        // allocated array and nothing visible here ties it to that classifier's weights —
        // confirm this is intended.
        this.classifier = classifier != null ? classifier : new MaxEnt(trainingList.getPipe(), this.parameters);
        this.refEx = refDist;
    }

    /**
     * Sets the variance of the Gaussian prior and invalidates the cached value and gradient.
     *
     * @param variance prior variance; {@link Double#POSITIVE_INFINITY} drops the log-normalizer term
     */
    public void setGaussianPriorVariance(double variance) {
        this.gaussianPriorVariance = variance;
        this.cacheStale = true;
    }

    /**
     * Sets the temperature passed to the classifier when scoring instances and invalidates
     * the cached value and gradient.
     *
     * @param temp temperature
     */
    public void setTemperature(double temp) {
        this.temperature = temp;
        this.cacheStale = true;
    }

    /**
     * Sets the multiplier applied to the constraint term and invalidates the cached value
     * and gradient. A weight of exactly 0 makes {@link #getValue()} return 0.
     *
     * @param weight objective weight
     */
    public void setWeight(double weight) {
        this.objWeight = weight;
        this.cacheStale = true;
    }

    /**
     * @return the classifier being evaluated
     */
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    /**
     * Computes (or returns the cached) objective value and refreshes the cached gradient.
     *
     * @return the constraint term plus the regularization term, or 0 when the objective
     *         weight is 0
     */
    @Override
    public double getValue() {
        if (!this.cacheStale) {
            return this.cachedValue;
        }
        Arrays.fill(this.cachedGradient, 0.0);
        if (this.objWeight == 0.0) {
            // The objective is identically zero here; the gradient was cleared above so that
            // getValueGradient cannot return values left over from an earlier evaluation.
            return 0.0;
        }
        int numRefDist = this.refEx.size();
        int numFeatures = this.trainingList.getDataAlphabet().size() + 1;
        int numLabels = this.trainingList.getTargetAlphabet().size();
        double scalingFactor = this.objWeight;
        if (this.mapping == null) {
            this.setMapping();
        }
        double[][] modelExScores = new double[numRefDist][numLabels];
        double[][] modelExDists = new double[numRefDist][numLabels];
        double[][] ratio = new double[numRefDist][numLabels];
        double[] featureCounts = new double[numRefDist];
        double[][] scores = new double[this.trainingList.size()][numLabels];

        // Pass 1: score every unlabeled instance and accumulate, per constrained feature,
        // the summed label scores and the feature count.
        int row = 0;
        for (Instance instance : this.trainingList) {
            final int ii = row++;
            if (instance.getTarget() != null) {
                continue;
            }
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector) instance.getData();
            double[] instanceScores = scores[ii];
            this.classifier.getClassificationScoresWithTemperature(instance, this.temperature, instanceScores);
            for (int loc = 0; loc < fv.numLocations(); ++loc) {
                int featureIndex = fv.indexAtLocation(loc);
                if (!this.refEx.containsKey(featureIndex)) {
                    continue;
                }
                int cIndex = this.mapping.get(featureIndex);
                double val = this.useValues ? fv.valueAtLocation(loc) : 1.0;
                featureCounts[cIndex] += val;
                for (int l = 0; l < numLabels; ++l) {
                    modelExScores[cIndex][l] += instanceScores[l] * val * instanceWeight;
                }
            }
            // The default feature is implicitly present in every instance.
            if (this.refEx.containsKey(this.defaultFeatureIndex)) {
                int cIndex = this.mapping.get(this.defaultFeatureIndex);
                featureCounts[cIndex] += 1.0;
                for (int l = 0; l < numLabels; ++l) {
                    modelExScores[cIndex][l] += instanceScores[l] * instanceWeight;
                }
            }
        }

        // Pass 2: normalize to model distributions and precompute reference / model-score ratios.
        for (int featureIndex : this.refEx.keySet()) {
            int cIndex = this.mapping.get(featureIndex);
            if (!(featureCounts[cIndex] > 0.0)) {
                continue;
            }
            double[] reference = this.refEx.get(featureIndex);
            for (int label = 0; label < numLabels; ++label) {
                modelExDists[cIndex][label] = modelExScores[cIndex][label] / featureCounts[cIndex];
                // A zero reference probability contributes nothing; forcing 0 avoids 0/0 = NaN
                // when the model score is zero as well.
                ratio[cIndex][label] = reference[label] == 0.0 ? 0.0 : reference[label] / modelExScores[cIndex][label];
            }
            assert (Maths.almostEquals(MatrixOps.sum(modelExDists[cIndex]), 1.0));
        }

        // Pass 3: accumulate the gradient of the constraint term over unlabeled instances.
        row = 0;
        for (Instance instance : this.trainingList) {
            final int ii = row++;
            if (instance.getTarget() != null) {
                continue;
            }
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector) instance.getData();
            double[] instanceScores = scores[ii];
            // The extra iteration (loc == numLocations) handles the default feature.
            for (int loc = 0; loc < fv.numLocations() + 1; ++loc) {
                int featureIndex = loc == fv.numLocations() ? this.defaultFeatureIndex : fv.indexAtLocation(loc);
                if (!this.refEx.containsKey(featureIndex)) {
                    continue;
                }
                int cIndex = this.mapping.get(featureIndex);
                // Constraints whose feature never occurred have an all-zero model distribution.
                if (MatrixOps.sum(modelExDists[cIndex]) == 0.0) {
                    continue;
                }
                double val = featureIndex == this.defaultFeatureIndex || !this.useValues ? 1.0 : fv.valueAtLocation(loc);
                double x = 0.0;
                for (int label = 0; label < numLabels; ++label) {
                    x += ratio[cIndex][label] * instanceScores[label];
                }
                for (int label = 0; label < numLabels; ++label) {
                    if (instanceScores[label] == 0.0) {
                        continue;
                    }
                    assert (!Double.isInfinite(instanceScores[label]));
                    double weight = scalingFactor * instanceWeight * this.temperature * val * instanceScores[label] * (ratio[cIndex][label] - x);
                    MatrixOps.rowPlusEquals(this.cachedGradient, numFeatures, label, fv, weight);
                    this.cachedGradient[numFeatures * label + this.defaultFeatureIndex] += weight;
                }
            }
        }

        // Pass 4: objective value = -sum over constraints of weight * KL(reference || model).
        double totalValue = 0.0;
        for (int featureIndex : this.refEx.keySet()) {
            int cIndex = this.mapping.get(featureIndex);
            if (MatrixOps.sum(modelExDists[cIndex]) == 0.0) {
                continue;
            }
            double[] reference = this.refEx.get(featureIndex);
            double value = 0.0;
            // Zero reference entries are skipped in both sums: 0 * log(0) would evaluate to
            // NaN in floating point, while the conventional value of that term is 0.
            for (int label = 0; label < numLabels; ++label) {
                if (reference[label] == 0.0) {
                    continue;
                }
                value -= scalingFactor * reference[label] * Math.log(modelExDists[cIndex][label]);
            }
            for (int label = 0; label < numLabels; ++label) {
                if (reference[label] == 0.0) {
                    continue;
                }
                value += scalingFactor * reference[label] * Math.log(reference[label]);
            }
            totalValue -= value;
        }
        this.cachedValue = totalValue;
        this.cacheStale = false;
        // Adds the prior to both cachedValue and cachedGradient.
        double reg = this.getRegularization();
        progressLogger.info("Value (GE=" + totalValue + " Gaussian prior= " + reg + ") = " + this.cachedValue);
        // Return the regularized value so that it matches both the cached gradient (which
        // includes the prior term) and what a subsequent cache hit returns.
        return this.cachedValue;
    }

    /**
     * Computes the Gaussian prior term and, as a side effect, adds it to the cached value and
     * subtracts {@code p / variance} from each cached gradient entry.
     *
     * <p>{@link #getValue()} already calls this once per evaluation; calling it again applies
     * the prior to the cached value and gradient a second time.
     *
     * @return the regularization term that was added to the cached value
     */
    public double getRegularization() {
        double regularization = Double.isInfinite(this.gaussianPriorVariance)
                ? 0.0
                : Math.log(this.gaussianPriorVariance * Math.sqrt(Math.PI * 2));
        for (int pi = 0; pi < this.parameters.length; ++pi) {
            double p = this.parameters[pi];
            regularization -= p * p / (2.0 * this.gaussianPriorVariance);
            this.cachedGradient[pi] -= p / this.gaussianPriorVariance;
        }
        this.cachedValue += regularization;
        return regularization;
    }

    /**
     * Copies the gradient of the objective into {@code buffer}, recomputing it first if the
     * cache is stale.
     *
     * @param buffer destination; expected to have {@link #getNumParameters()} entries
     */
    @Override
    public void getValueGradient(double[] buffer) {
        if (this.cacheStale) {
            this.getValue();
        }
        assert (buffer.length == this.cachedGradient.length);
        System.arraycopy(this.cachedGradient, 0, buffer, 0, buffer.length);
    }

    /**
     * @return the number of parameters, i.e. (number of features + 1) * number of labels
     */
    @Override
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * @param index parameter position
     * @return the parameter at {@code index}
     */
    @Override
    public double getParameter(int index) {
        return this.parameters[index];
    }

    /**
     * Copies the current parameters into {@code buffer}.
     *
     * @param buffer destination; expected to have {@link #getNumParameters()} entries
     */
    @Override
    public void getParameters(double[] buffer) {
        assert (buffer.length == this.parameters.length);
        System.arraycopy(this.parameters, 0, buffer, 0, buffer.length);
    }

    /**
     * Sets one parameter and invalidates the cached value and gradient.
     *
     * @param index parameter position
     * @param value new value
     */
    @Override
    public void setParameter(int index, double value) {
        this.cacheStale = true;
        this.parameters[index] = value;
    }

    /**
     * Replaces all parameters and invalidates the cached value and gradient.
     *
     * @param params new values; expected to have {@link #getNumParameters()} entries
     */
    @Override
    public void setParameters(double[] params) {
        assert (params.length == this.parameters.length);
        this.cacheStale = true;
        System.arraycopy(params, 0, this.parameters, 0, this.parameters.length);
    }

    /**
     * Assigns each constrained feature index a dense row index (0, 1, 2, ...) in the
     * iteration order of the reference map's key set.
     */
    private void setMapping() {
        this.mapping = new HashMap<Integer, Integer>();
        int cCounter = 0;
        for (int featureIndex : this.refEx.keySet()) {
            this.mapping.put(featureIndex, cCounter);
            ++cCounter;
        }
    }
}

