/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntOptimizableByGE;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.logging.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Trains a {@link MaxEnt} classifier by optimizing a {@link MaxEntOptimizableByGE} objective
 * with L-BFGS.
 *
 * <p>The reference distributions ("constraints") are either supplied directly as a map from
 * feature index to a per-label probability array, or read lazily from a constraints file the
 * first time {@link #train(InstanceList, int)} is called. Two file formats are accepted,
 * chosen by inspecting the first line:
 * <ul>
 *   <li>index-based: {@code featureIndex p0 p1 ... pK} (no {@code ':'} on the first line);</li>
 *   <li>name-based: {@code featureName label:prob label:prob ...}.</li>
 * </ul>
 * Blank lines in either format are ignored.
 */
public class MaxEntGETrainer
extends ClassifierTrainer<MaxEnt>
implements ClassifierTrainer.ByOptimization<MaxEnt>,
Boostable,
Serializable {
    private static final long serialVersionUID = 1L;
    private static Logger logger = MalletLogger.getLogger(MaxEntGETrainer.class.getName());
    private static Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntGETrainer.class.getName() + "-pl");
    // Integer.MAX_VALUE means "optimize until convergence" (see train(InstanceList, int)).
    private int numIterations = Integer.MAX_VALUE;
    private double temperature = 1.0;
    private double gaussianPriorVariance = 1.0;
    // Path of a constraints file; only read when refDist has not been supplied.
    private String constraintsFile;
    // Feature index -> reference probability for each label.
    private HashMap<Integer, double[]> refDist;
    private InstanceList trainingList = null;
    private MaxEnt classifier = null;
    private MaxEntOptimizableByGE ge = null;
    private Optimizer opt = null;

    /**
     * Closes {@code reader} if it is non-null. A failure to close is logged, not propagated,
     * because the caller has already finished reading from it.
     */
    private static void closeQuietly(BufferedReader reader) {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        }
        catch (IOException e) {
            logger.warning("Could not close constraints file: " + e);
        }
    }

    /**
     * Decides whether a constraints file uses the index-based format.
     *
     * @param filename path of the constraints file
     * @return {@code true} if the first line contains no {@code ':'}, or if the file is empty
     *         (an empty file yields no constraints with either parser)
     */
    private boolean testConstraintsFileIndexBased(String filename) {
        String firstLine = null;
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(new File(filename)));
            firstLine = reader.readLine();
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        finally {
            closeQuietly(reader);
        }
        // readLine() returns null for an empty file; guard against the NPE the old code hit.
        return firstLine == null || !firstLine.contains(":");
    }

    /**
     * Reads name-based constraints ({@code featureName label:prob ...}) into {@link #refDist}.
     * Feature and label names are resolved against the training list's alphabets without
     * adding new entries. A feature absent from the data alphabet is skipped with a warning.
     * An unknown label, or any I/O or parse error, prints the stack trace and exits the JVM,
     * as before.
     *
     * @param filename path of the constraints file
     */
    private void readConstraintsFromFile(String filename) {
        this.refDist = new HashMap<Integer, double[]>();
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(new File(filename)));
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.length() == 0) {
                    continue; // tolerate blank lines such as a trailing newline
                }
                String[] split = line.split("\\s+");
                String featureName = split[0];
                int featureIndex = this.trainingList.getDataAlphabet().lookupIndex(featureName, false);
                if (featureIndex < 0) {
                    // lookupIndex(..., false) signals "not present" with -1; storing that key
                    // would constrain a non-existent feature.
                    logger.warning("Feature \"" + featureName + "\" in constraints file " + filename
                            + " is not in the training data alphabet; skipping its constraint.");
                    continue;
                }
                assert (split.length - 1 == this.trainingList.getTargetAlphabet().size());
                double[] probs = new double[split.length - 1];
                for (int index = 1; index < split.length; ++index) {
                    String[] labelSplit = split[index].split(":");
                    int li = this.trainingList.getTargetAlphabet().lookupIndex(labelSplit[0], false);
                    if (li < 0 || li >= probs.length) {
                        // Handled by the catch below (stack trace + exit), but with a useful message.
                        throw new IllegalArgumentException("Label \"" + labelSplit[0]
                                + "\" for feature \"" + featureName + "\" in constraints file "
                                + filename + " is unknown or out of range (index " + li + ")");
                    }
                    probs[li] = Double.parseDouble(labelSplit[1]);
                }
                this.refDist.put(featureIndex, probs);
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        finally {
            closeQuietly(reader);
        }
    }

    /**
     * Reads index-based constraints ({@code featureIndex p0 p1 ...}) into {@link #refDist}.
     * The probabilities are taken positionally, in label-index order. Any I/O or parse error
     * prints the stack trace and exits the JVM.
     *
     * @param filename path of the constraints file
     */
    private void readConstraintsFromFileIndex(String filename) {
        this.refDist = new HashMap<Integer, double[]>();
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(new File(filename)));
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.length() == 0) {
                    continue; // a blank line previously crashed Integer.parseInt("")
                }
                String[] split = line.split("\\s+");
                int featureIndex = Integer.parseInt(split[0]);
                assert (split.length - 1 == this.trainingList.getTargetAlphabet().size());
                double[] probs = new double[split.length - 1];
                for (int index = 1; index < split.length; ++index) {
                    probs[index - 1] = Double.parseDouble(split[index]);
                }
                this.refDist.put(featureIndex, probs);
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        finally {
            closeQuietly(reader);
        }
    }

    /** Creates a trainer; constraints must be supplied via {@link #setConstraintsFile}. */
    public MaxEntGETrainer() {
    }

    /**
     * Creates a trainer with explicit reference distributions.
     *
     * @param refDist map from feature index to per-label reference probabilities
     */
    public MaxEntGETrainer(HashMap<Integer, double[]> refDist) {
        this.refDist = refDist;
    }

    /**
     * Creates a trainer with explicit reference distributions and an initial classifier.
     *
     * @param refDist map from feature index to per-label reference probabilities
     * @param classifier classifier passed to {@link MaxEntOptimizableByGE} at training time
     */
    public MaxEntGETrainer(HashMap<Integer, double[]> refDist, MaxEnt classifier) {
        this.refDist = refDist;
        this.classifier = classifier;
    }

    /** Sets the constraints file; it is only read if no reference distributions were given. */
    public void setConstraintsFile(String filename) {
        this.constraintsFile = filename;
    }

    /** Sets the temperature forwarded to {@link MaxEntOptimizableByGE#setTemperature}. */
    public void setTemperature(double temp) {
        this.temperature = temp;
    }

    /** Sets the Gaussian prior variance forwarded to the GE objective. */
    public void setGaussianPriorVariance(double variance) {
        this.gaussianPriorVariance = variance;
    }

    @Override
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    /** @return the GE objective from the last training run, or {@code null} before training */
    public Optimizable getOptimizable() {
        return this.ge;
    }

    @Override
    public Optimizer getOptimizer() {
        return this.opt;
    }

    /** Sets the default iteration budget used by {@link #train(InstanceList)}. */
    public void setNumIterations(int i) {
        this.numIterations = i;
    }

    /**
     * @return 0 before training and {@code Integer.MAX_VALUE} afterwards; the actual number of
     *         optimizer iterations is not tracked
     */
    @Override
    public int getIteration() {
        if (this.ge == null) {
            return 0;
        }
        return Integer.MAX_VALUE;
    }

    @Override
    public MaxEnt train(InstanceList trainingList) {
        return this.train(trainingList, this.numIterations);
    }

    /**
     * Trains the classifier.
     *
     * @param train training instances; their alphabets resolve constraint feature/label names
     * @param numIterations maximum number of single-step L-BFGS iterations; with
     *        {@code Integer.MAX_VALUE} a second, fresh L-BFGS run follows the first
     * @return the trained classifier (also kept as this trainer's classifier)
     */
    @Override
    public MaxEnt train(InstanceList train, int numIterations) {
        this.trainingList = train;
        // Constraints are loaded lazily because name-based files need the training alphabets.
        if (this.refDist == null && this.constraintsFile != null) {
            if (this.testConstraintsFileIndexBased(this.constraintsFile)) {
                this.readConstraintsFromFileIndex(this.constraintsFile);
            } else {
                this.readConstraintsFromFile(this.constraintsFile);
            }
        }
        this.ge = new MaxEntOptimizableByGE(this.trainingList, this.refDist, this.classifier);
        this.ge.setTemperature(this.temperature);
        this.ge.setGaussianPriorVariance(this.gaussianPriorVariance);
        this.opt = new LimitedMemoryBFGS(this.ge);
        logger.fine("trainingList.size() = " + this.trainingList.size());
        for (int i = 0; i < numIterations; ++i) {
            boolean converged;
            try {
                converged = this.opt.optimize(1);
            }
            catch (Exception e) {
                // An optimizer failure is treated as convergence rather than aborting training.
                e.printStackTrace();
                logger.info("Catching exception; saying converged.");
                converged = true;
            }
            if (converged) break;
        }
        if (numIterations == Integer.MAX_VALUE) {
            // NOTE(review): presumably restarts L-BFGS with an empty history so it can continue
            // past the first reported convergence; confirm intent.
            this.opt = new LimitedMemoryBFGS(this.ge);
            try {
                this.opt.optimize();
            }
            catch (Exception e) {
                e.printStackTrace();
                logger.info("Catching exception; saying converged.");
            }
        }
        progressLogger.info("\n");
        this.classifier = this.ge.getClassifier();
        return this.classifier;
    }
}

