/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntOptimizableByGE;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletProgressMessageLogger;
import cc.mallet.util.Maths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;

public class MaxEntOptimizableByKLGE
extends MaxEntOptimizableByGE {
    private static Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntOptimizableByKLGE.class.getName() + "-pl");

    public MaxEntOptimizableByKLGE(InstanceList trainingList, HashMap<Integer, double[]> constraints, MaxEnt initClassifier) {
        super(trainingList, constraints, initClassifier);
    }

    @Override
    public double getValue() {
        if (!this.cacheStale) {
            return this.cachedValue;
        }
        if (this.objWeight == 0.0) {
            return 0.0;
        }
        Arrays.fill(this.cachedGradient, 0.0);
        int numRefDist = this.constraints.size();
        int numFeatures = this.trainingList.getDataAlphabet().size() + 1;
        int numLabels = this.trainingList.getTargetAlphabet().size();
        double scalingFactor = this.objWeight;
        if (this.mapping == null) {
            this.setMapping();
        }
        double[][] modelExpectations = new double[numRefDist][numLabels];
        double[][] ratio = new double[numRefDist][numLabels];
        double[] featureCounts = new double[numRefDist];
        double[][] scores = new double[this.trainingList.size()][numLabels];
        double[] constraintValue = new double[numLabels];
        for (int ii = 0; ii < this.trainingList.size(); ++ii) {
            int cIndex;
            Instance instance = (Instance)this.trainingList.get(ii);
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            if (instance.getTarget() != null) continue;
            FeatureVector fv = (FeatureVector)instance.getData();
            this.classifier.getClassificationScoresWithTemperature(instance, this.temperature, scores[ii]);
            for (int loc = 0; loc < fv.numLocations(); ++loc) {
                int featureIndex = fv.indexAtLocation(loc);
                if (!this.constraints.containsKey(featureIndex)) continue;
                int cIndex2 = (Integer)this.mapping.get(featureIndex);
                double val = !this.useValues ? 1.0 : fv.valueAtLocation(loc);
                int n = cIndex2;
                featureCounts[n] = featureCounts[n] + val;
                for (int l = 0; l < numLabels; ++l) {
                    double[] dArray = modelExpectations[cIndex2];
                    int n2 = l;
                    dArray[n2] = dArray[n2] + scores[ii][l] * val * instanceWeight;
                }
            }
            if (!this.constraints.containsKey(this.defaultFeatureIndex)) continue;
            int n = cIndex = ((Integer)this.mapping.get(this.defaultFeatureIndex)).intValue();
            featureCounts[n] = featureCounts[n] + 1.0;
            for (int l = 0; l < numLabels; ++l) {
                double[] dArray = modelExpectations[cIndex];
                int n3 = l;
                dArray[n3] = dArray[n3] + scores[ii][l] * instanceWeight;
            }
        }
        double value = 0.0;
        Iterator i$ = this.constraints.keySet().iterator();
        while (i$.hasNext()) {
            int featureIndex = (Integer)i$.next();
            int cIndex = (Integer)this.mapping.get(featureIndex);
            if (!(featureCounts[cIndex] > 0.0)) continue;
            for (int label = 0; label < numLabels; ++label) {
                double cProb = ((double[])this.constraints.get(featureIndex))[label];
                double[] dArray = modelExpectations[cIndex];
                int n = label;
                dArray[n] = dArray[n] / featureCounts[cIndex];
                ratio[cIndex][label] = cProb / modelExpectations[cIndex][label];
                value += scalingFactor * cProb * Math.log(modelExpectations[cIndex][label]);
                if (!(cProb > 0.0)) continue;
                value -= scalingFactor * cProb * Math.log(cProb);
            }
            assert (Maths.almostEquals(MatrixOps.sum(modelExpectations[cIndex]), 1.0));
        }
        for (int ii = 0; ii < this.trainingList.size(); ++ii) {
            Instance instance = (Instance)this.trainingList.get(ii);
            if (instance.getTarget() != null) continue;
            Arrays.fill(constraintValue, 0.0);
            double instanceExpectation = 0.0;
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector)instance.getData();
            for (int loc = 0; loc < fv.numLocations() + 1; ++loc) {
                int label;
                int cIndex;
                int featureIndex = loc == fv.numLocations() ? this.defaultFeatureIndex : fv.indexAtLocation(loc);
                if (!this.constraints.containsKey(featureIndex) || featureCounts[cIndex = ((Integer)this.mapping.get(featureIndex)).intValue()] == 0.0) continue;
                double val = featureIndex == this.defaultFeatureIndex || !this.useValues ? 1.0 : fv.valueAtLocation(loc);
                for (label = 0; label < numLabels; ++label) {
                    int n = label;
                    constraintValue[n] = constraintValue[n] + val / featureCounts[cIndex] * ratio[cIndex][label];
                }
                for (label = 0; label < numLabels; ++label) {
                    instanceExpectation += val / featureCounts[cIndex] * ratio[cIndex][label] * scores[ii][label];
                }
            }
            for (int label = 0; label < numLabels; ++label) {
                if (scores[ii][label] == 0.0) continue;
                assert (!Double.isInfinite(scores[ii][label]));
                double weight = scalingFactor * instanceWeight * this.temperature * scores[ii][label] * (constraintValue[label] - instanceExpectation);
                MatrixOps.rowPlusEquals(this.cachedGradient, numFeatures, label, fv, weight);
                int n = numFeatures * label + this.defaultFeatureIndex;
                this.cachedGradient[n] = this.cachedGradient[n] + weight;
            }
        }
        this.cachedValue = value;
        this.cacheStale = false;
        double reg = this.getRegularization();
        progressLogger.info("Value (GE=" + value + " Gaussian prior= " + reg + ") = " + this.cachedValue);
        return value;
    }
}

