/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.MaxEnt;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.InstanceList;
import java.util.HashMap;

/**
 * Base class for optimizing the parameters of a {@link MaxEnt} classifier against a set of
 * per-feature constraints. The name suggests a generalized-expectation (GE) objective; the
 * objective itself is supplied by concrete subclasses through {@link #getValue()}.
 *
 * <p>The parameter array is shared with {@link #classifier} (it is never copied), so
 * {@link #setParameter} and {@link #setParameters} write directly into the array the
 * classifier uses.
 *
 * <p>Caching: {@link #cacheStale} is set whenever the parameters or any objective
 * hyper-parameter change, and {@link #getValueGradient} calls {@link #getValue()} while it is
 * set. Subclasses are expected to fill {@link #cachedValue} and {@link #cachedGradient} in
 * {@code getValue()} and to clear {@code cacheStale} — NOTE(review): confirm in subclasses.
 */
public abstract class MaxEntOptimizableByGE
implements Optimizable.ByGradientValue {
    /** True when {@link #cachedValue}/{@link #cachedGradient} may not reflect the current state. */
    protected boolean cacheStale = true;
    /** Flag set through {@link #setUseValues}; how it is interpreted is up to subclasses. */
    protected boolean useValues = false;
    /** Feature index equal to the data-alphabet size, i.e. the one extra slot per label. */
    protected int defaultFeatureIndex;
    /** Temperature set through {@link #setTemperature}; consumed by subclasses. */
    protected double temperature = 1.0;
    /** Objective weight set through {@link #setWeight}; consumed by subclasses. */
    protected double objWeight = 1.0;
    /** Most recently computed objective value. */
    protected double cachedValue;
    /** Variance of the zero-mean Gaussian prior applied in {@link #getRegularization()}. */
    protected double gaussianPriorVariance = 1.0;
    /** Most recently computed gradient; same length as a freshly allocated parameter array. */
    protected double[] cachedGradient;
    /** Parameter vector, shared by reference with {@link #classifier}. */
    protected double[] parameters;
    /** Training instances; their alphabets fix the parameter dimensions. */
    protected InstanceList trainingList;
    /** Classifier whose parameters are being optimized. */
    protected MaxEnt classifier;
    /** Constraint vectors keyed by feature index (stored by reference, not copied). */
    protected HashMap<Integer, double[]> constraints;
    /** Dense index 0..k-1 assigned to each constrained feature by {@link #setMapping()}. */
    protected HashMap<Integer, Integer> mapping;

    /**
     * Creates an optimizer over {@code trainingList}.
     *
     * @param trainingList   training instances; the sizes of their data and target alphabets
     *                       determine the parameter count {@code (numFeatures + 1) * numLabels}
     * @param constraints    constraint vectors keyed by feature index (kept by reference)
     * @param initClassifier classifier whose parameter array is adopted (shared, not copied), or
     *                       {@code null} to start from a new {@link MaxEnt} whose parameters are all zero
     */
    public MaxEntOptimizableByGE(InstanceList trainingList, HashMap<Integer, double[]> constraints, MaxEnt initClassifier) {
        this.trainingList = trainingList;
        int numFeatures = trainingList.getDataAlphabet().size();
        int numLabels = trainingList.getTargetAlphabet().size();
        // The "+ 1" below reserves one feature slot per label, addressed by defaultFeatureIndex.
        this.defaultFeatureIndex = numFeatures;
        int numParameters = (numFeatures + 1) * numLabels;
        this.cachedGradient = new double[numParameters];
        this.cachedValue = 0.0;
        if (initClassifier != null) {
            // NOTE(review): assumes initClassifier.parameters.length == numParameters; otherwise
            // the length assertion in getValueGradient will fail.
            this.parameters = initClassifier.parameters;
            this.classifier = initClassifier;
        } else {
            this.parameters = new double[numParameters];
            this.classifier = new MaxEnt(trainingList.getPipe(), this.parameters);
        }
        this.constraints = constraints;
    }

    /**
     * Sets the variance of the Gaussian prior and invalidates the cached value and gradient.
     *
     * @param variance new prior variance
     */
    public void setGaussianPriorVariance(double variance) {
        this.gaussianPriorVariance = variance;
        this.cacheStale = true;
    }

    /**
     * Sets the temperature and invalidates the cached value and gradient.
     *
     * @param temp new temperature
     */
    public void setTemperature(double temp) {
        this.temperature = temp;
        this.cacheStale = true;
    }

    /**
     * Sets the objective weight and invalidates the cached value and gradient.
     *
     * @param weight new weight
     */
    public void setWeight(double weight) {
        this.objWeight = weight;
        this.cacheStale = true;
    }

    /** Returns the classifier whose (shared) parameters this object optimizes. */
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    @Override
    public abstract double getValue();

    /**
     * Applies a zero-mean Gaussian prior with variance {@link #gaussianPriorVariance}.
     *
     * <p>Side effects: adds the returned penalty to {@link #cachedValue}, and subtracts
     * {@code p / variance} from each entry of {@link #cachedGradient}. Call it only after the
     * data term has been written into both fields (presumably from {@code getValue()}).
     *
     * @return {@code -sum(p * p) / (2 * variance)} over all parameters
     */
    protected double getRegularization() {
        double regularization = 0.0;
        for (int pi = 0; pi < this.parameters.length; ++pi) {
            double p = this.parameters[pi];
            regularization -= p * p / (2.0 * this.gaussianPriorVariance);
            this.cachedGradient[pi] -= p / this.gaussianPriorVariance;
        }
        this.cachedValue += regularization;
        return regularization;
    }

    /**
     * Copies the gradient into {@code buffer}, first recomputing it through {@link #getValue()}
     * if the cache is stale.
     *
     * @param buffer destination array; expected to have the same length as the gradient
     */
    @Override
    public void getValueGradient(double[] buffer) {
        if (this.cacheStale) {
            this.getValue();
        }
        assert (buffer.length == this.cachedGradient.length);
        System.arraycopy(this.cachedGradient, 0, buffer, 0, buffer.length);
    }

    /**
     * Sets the {@link #useValues} flag and invalidates the cached value and gradient.
     *
     * @param flag new flag value
     */
    public void setUseValues(boolean flag) {
        this.useValues = flag;
        this.cacheStale = true;
    }

    @Override
    public int getNumParameters() {
        return this.parameters.length;
    }

    @Override
    public double getParameter(int index) {
        return this.parameters[index];
    }

    @Override
    public void getParameters(double[] buffer) {
        assert (buffer.length == this.parameters.length);
        System.arraycopy(this.parameters, 0, buffer, 0, buffer.length);
    }

    @Override
    public void setParameter(int index, double value) {
        this.cacheStale = true;
        this.parameters[index] = value;
    }

    @Override
    public void setParameters(double[] params) {
        assert (params.length == this.parameters.length);
        this.cacheStale = true;
        System.arraycopy(params, 0, this.parameters, 0, this.parameters.length);
    }

    /**
     * Rebuilds {@link #mapping} so that each constrained feature index gets a dense position
     * 0..k-1. The positions follow the iteration order of {@link #constraints}'
     * {@code HashMap} key set, so they are not sorted by feature index.
     */
    protected void setMapping() {
        this.mapping = new HashMap<Integer, Integer>();
        int cCounter = 0;
        for (int featureIndex : this.constraints.keySet()) {
            this.mapping.put(featureIndex, cCounter++);
        }
    }
}

