/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.CRFOptimizableByGE;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import java.util.ArrayList;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trains a {@link CRF} on an unlabeled {@link InstanceList} by maximizing a
 * {@link CRFOptimizableByGE} objective built from a list of {@link GEConstraint}s,
 * using {@link LimitedMemoryBFGS}.
 *
 * <p>After each optimization pass the optimizer is reset and training resumes.
 * The pass continues from the same iteration index, so {@code numIterations}
 * given to {@link #train(InstanceList, int)} is a total budget shared by all passes.
 * See {@link #setNumResets(int)}.
 */
public class CRFTrainerByGE
extends TransducerTrainer {
    private static final Logger logger = MalletLogger.getLogger(CRFTrainerByGE.class.getName());
    /** Default number of additional optimization passes after the first. */
    private static final int DEFAULT_NUM_RESETS = 1;
    /** Default Gaussian prior variance handed to the GE objective. */
    private static final int DEFAULT_GPV = 10;
    /** Result of the most recent {@link #train} call. */
    private boolean converged = false;
    /** Optimizer steps that completed without throwing, accumulated across all train calls. */
    private int iteration = 0;
    private final int numThreads;
    private int numResets;
    private double gaussianPriorVariance;
    private final ArrayList<GEConstraint> constraints;
    private final CRF crf;
    private StateLabelMap stateLabelMap;

    /**
     * Creates a single-threaded trainer.
     *
     * @param crf the model to train
     * @param constraints the GE constraints defining the objective; must be non-empty when
     *     {@link #train} is called
     * @throws NullPointerException if {@code crf} or {@code constraints} is null
     */
    public CRFTrainerByGE(CRF crf, ArrayList<GEConstraint> constraints) {
        this(crf, constraints, 1);
    }

    /**
     * Creates a trainer whose objective is computed with {@code numThreads} threads.
     *
     * @param crf the model to train
     * @param constraints the GE constraints defining the objective; must be non-empty when
     *     {@link #train} is called
     * @param numThreads thread count forwarded to {@link CRFOptimizableByGE}
     * @throws NullPointerException if {@code crf} or {@code constraints} is null
     */
    public CRFTrainerByGE(CRF crf, ArrayList<GEConstraint> constraints, int numThreads) {
        this.constraints = Objects.requireNonNull(constraints, "constraints");
        this.crf = Objects.requireNonNull(crf, "crf");
        this.numThreads = numThreads;
        this.numResets = DEFAULT_NUM_RESETS;
        this.gaussianPriorVariance = DEFAULT_GPV;
        // Default label mapping derived from the CRF's output alphabet; replaceable via
        // setStateLabelMap before training.
        this.stateLabelMap = new StateLabelMap(crf.getOutputAlphabet(), true);
    }

    /**
     * Returns the number of optimizer steps that completed without throwing,
     * accumulated across every call to {@link #train}.
     */
    @Override
    public int getIteration() {
        return this.iteration;
    }

    /** Returns the CRF being trained. */
    @Override
    public Transducer getTransducer() {
        return this.crf;
    }

    /**
     * Returns the outcome of the most recent {@link #train} call. This is {@code true} if
     * the optimizer reported convergence, or if an exception was caught and training
     * stopped.
     */
    @Override
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /**
     * Sets the Gaussian prior variance passed to the GE objective on the next
     * {@link #train} call.
     */
    public void setGaussianPriorVariance(double gpv) {
        this.gaussianPriorVariance = gpv;
    }

    /**
     * Sets how many times the optimizer is reset and training resumed after a pass ends.
     * Training runs at most {@code numResets + 1} passes.
     *
     * @throws IllegalArgumentException if {@code numResets} is negative
     */
    public void setNumResets(int numResets) {
        if (numResets < 0) {
            throw new IllegalArgumentException("numResets must be >= 0, got " + numResets);
        }
        this.numResets = numResets;
    }

    /** Replaces the state-to-label mapping used by the GE objective on the next train call. */
    public void setStateLabelMap(StateLabelMap map) {
        this.stateLabelMap = map;
    }

    /**
     * Optimizes the CRF's parameters against the GE constraints on {@code unlabeledSet}.
     *
     * <p>Each loop step runs one L-BFGS iteration and then the registered evaluators. An
     * exception from either step is logged and treated as convergence. When a pass
     * converges, the optimizer is reset and the next pass resumes at the same iteration
     * index.
     *
     * @param unlabeledSet the data the GE objective is evaluated on
     * @param numIterations total iteration budget shared by all passes
     * @return whether the last optimizer step reported convergence (or failed)
     * @throws RuntimeException if no constraints were supplied
     */
    @Override
    public boolean train(InstanceList unlabeledSet, int numIterations) {
        if (this.constraints.isEmpty()) {
            throw new RuntimeException("No constraints specified!");
        }
        CRFOptimizableByGE ge = new CRFOptimizableByGE(this.crf, this.constraints, unlabeledSet, this.stateLabelMap, this.numThreads);
        // Always release ge, even if an Error or a failure in bfgs.reset() escapes below.
        try {
            ge.setGaussianPriorVariance(this.gaussianPriorVariance);
            LimitedMemoryBFGS bfgs = new LimitedMemoryBFGS(ge);
            this.converged = false;
            logger.info("CRF about to train with " + numIterations + " iterations");
            // long arithmetic so numResets == Integer.MAX_VALUE cannot overflow.
            final long totalPasses = (long) this.numResets + 1L;
            // iter is intentionally shared across passes: numIterations is a total budget.
            int iter = 0;
            for (long pass = 0; pass < totalPasses; ++pass) {
                for (; iter < numIterations; ++iter) {
                    try {
                        this.converged = bfgs.optimize(1);
                        ++this.iteration;
                        logger.info("CRF finished one iteration of maximizer, i=" + iter);
                        this.runEvaluators();
                    } catch (Exception e) {
                        logger.log(Level.WARNING, "Catching exception; saying converged.", e);
                        this.converged = true;
                    }
                    if (this.converged) {
                        // break skips ++iter, so the next pass resumes at this index.
                        logger.info("CRF training has converged, i=" + iter);
                        break;
                    }
                }
                bfgs.reset();
                if (iter >= numIterations) {
                    // Budget exhausted; any further pass would run zero iterations.
                    break;
                }
            }
        } finally {
            ge.shutdown();
        }
        return this.converged;
    }
}

