/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.MCMaxEntTrainer;
import cc.mallet.classify.NaiveBayes;
import cc.mallet.classify.NaiveBayesTrainer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Multinomial;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Trains a {@link NaiveBayes} classifier from a mixture of labeled and unlabeled
 * instances using an iterative, EM-style loop.
 *
 * <p>An initial classifier is trained on the supplied instance list. Each iteration
 * then builds a new instance list in which instances whose labeling is {@code null}
 * receive the labeling produced by the current classifier (added with weight
 * {@link #getUnlabeledDataWeight()}), retrains on that list, and stops once the
 * relative change in data log-likelihood drops below a fixed tolerance.
 *
 * <p>The estimator and document-length settings held by this object are forwarded
 * to the internal {@link NaiveBayesTrainer.Factory} that creates the per-iteration
 * trainers, both at construction time and whenever a setter is called.
 */
public class NaiveBayesEMTrainer
extends ClassifierTrainer<NaiveBayes> {
    private static final Logger logger = MalletLogger.getLogger(NaiveBayesEMTrainer.class.getName());
    /** Relative log-likelihood change below which training is considered converged. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-4;
    private static final double DEFAULT_DOC_LENGTH_NORMALIZATION = -1.0;
    private static final double DEFAULT_UNLABELED_DATA_WEIGHT = 1.0;

    Multinomial.Estimator featureEstimator = new Multinomial.LaplaceEstimator();
    Multinomial.Estimator priorEstimator = new Multinomial.LaplaceEstimator();
    double docLengthNormalization = DEFAULT_DOC_LENGTH_NORMALIZATION;
    double unlabeledDataWeight = DEFAULT_UNLABELED_DATA_WEIGHT;
    /** Number of EM iterations completed by the most recent call to {@link #train}. */
    int iteration = 0;
    /** Factory for the per-iteration trainers; kept in sync with the fields above. */
    NaiveBayesTrainer.Factory nbTrainer = new NaiveBayesTrainer.Factory();
    /** Classifier produced by the most recent training iteration, or {@code null} before training. */
    NaiveBayes classifier;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 1;

    /** Creates a trainer with the default estimators and pushes them into the internal factory. */
    public NaiveBayesEMTrainer() {
        this.configureBaseTrainer();
    }

    /** Copies this object's estimator and normalization settings into {@link #nbTrainer}. */
    private void configureBaseTrainer() {
        this.nbTrainer.setDocLengthNormalization(this.docLengthNormalization);
        this.nbTrainer.setFeatureMultinomialEstimator(this.featureEstimator);
        this.nbTrainer.setPriorMultinomialEstimator(this.priorEstimator);
    }

    /** @return the estimator configured for the per-class feature multinomials */
    public Multinomial.Estimator getFeatureMultinomialEstimator() {
        return this.featureEstimator;
    }

    /**
     * Sets the estimator for the per-class feature multinomials.
     *
     * @param me the estimator to use in subsequent calls to {@link #train}
     */
    public void setFeatureMultinomialEstimator(Multinomial.Estimator me) {
        this.featureEstimator = me;
        // Forward to the factory; otherwise train() would keep using the old estimator.
        this.nbTrainer.setFeatureMultinomialEstimator(me);
    }

    /** @return the estimator configured for the class prior multinomial */
    public Multinomial.Estimator getPriorMultinomialEstimator() {
        return this.priorEstimator;
    }

    /**
     * Sets the estimator for the class prior multinomial.
     *
     * @param me the estimator to use in subsequent calls to {@link #train}
     */
    public void setPriorMultinomialEstimator(Multinomial.Estimator me) {
        this.priorEstimator = me;
        // Forward to the factory; otherwise train() would keep using the old estimator.
        this.nbTrainer.setPriorMultinomialEstimator(me);
    }

    /**
     * Sets the document-length normalization value passed on to the underlying
     * naive Bayes trainers.
     *
     * @param d the normalization value to use in subsequent calls to {@link #train}
     */
    public void setDocLengthNormalization(double d) {
        this.docLengthNormalization = d;
        // Forward to the factory; otherwise train() would keep using the old value.
        this.nbTrainer.setDocLengthNormalization(d);
    }

    /** @return the configured document-length normalization value */
    public double getDocLengthNormalization() {
        return this.docLengthNormalization;
    }

    /** @return the instance weight given to instances that were labeled by the classifier */
    public double getUnlabeledDataWeight() {
        return this.unlabeledDataWeight;
    }

    /**
     * Sets the instance weight given to originally-unlabeled instances when they are
     * added to the per-iteration training list.
     *
     * @param unlabeledDataWeight the weight to apply
     */
    public void setUnlabeledDataWeight(double unlabeledDataWeight) {
        this.unlabeledDataWeight = unlabeledDataWeight;
    }

    /** @return the number of EM iterations completed by the most recent call to {@link #train} */
    public int getIteration() {
        return this.iteration;
    }

    @Override
    public boolean isFinishedTraining() {
        return false;
    }

    /** @return the classifier from the most recent training iteration, or {@code null} if never trained */
    @Override
    public NaiveBayes getClassifier() {
        return this.classifier;
    }

    /**
     * Trains on {@code trainingSet}, repeatedly labeling the instances whose labeling
     * is {@code null} with the current classifier and retraining until the data
     * log-likelihood stops changing.
     *
     * @param trainingSet labeled and unlabeled instances; an instance counts as
     *                    unlabeled when {@code getLabeling()} returns {@code null}
     * @return the classifier produced by the final iteration
     */
    @Override
    public NaiveBayes train(InstanceList trainingSet) {
        NaiveBayes c = ((NaiveBayesTrainer)this.nbTrainer.newClassifierTrainer()).train(trainingSet);
        this.classifier = c;
        this.iteration = 0;
        double prevLogLikelihood = 0.0;
        boolean converged = false;
        while (!converged) {
            InstanceList trainingSet2 = this.labelWithClassifier(trainingSet, c);
            c = ((NaiveBayesTrainer)this.nbTrainer.newClassifierTrainer()).train(trainingSet2);
            this.classifier = c;
            double logLikelihood = c.dataLogLikelihood(trainingSet2);
            System.err.println("Loglikelihood = " + logLikelihood);
            converged = NaiveBayesEMTrainer.shouldStop(prevLogLikelihood, logLikelihood);
            prevLogLikelihood = logLikelihood;
            ++this.iteration;
        }
        return c;
    }

    /**
     * Builds the training list for one iteration: instances that already have a
     * labeling are added unchanged with weight 1.0; the others are shallow-copied,
     * given the labeling predicted by {@code current}, and added with
     * {@link #unlabeledDataWeight}.
     */
    private InstanceList labelWithClassifier(InstanceList trainingSet, NaiveBayes current) {
        InstanceList relabeled = new InstanceList(trainingSet.getPipe());
        for (int ii = 0; ii < trainingSet.size(); ++ii) {
            Instance inst = (Instance)trainingSet.get(ii);
            if (inst.getLabeling() != null) {
                relabeled.add(inst, 1.0);
                continue;
            }
            // Work on a copy so the caller's instance keeps its null labeling.
            Instance inst2 = inst.shallowCopy();
            inst2.unLock();
            inst2.setLabeling(current.classify(inst).getLabeling());
            inst2.lock();
            relabeled.add(inst2, this.unlabeledDataWeight);
        }
        return relabeled;
    }

    /**
     * Decides whether the EM loop should terminate.
     *
     * <p>Besides the relative-change test, this stops when the log-likelihood is
     * unchanged (covers 0 and infinite values, where the ratio would be NaN) or is
     * itself NaN; without these guards the NaN comparison is always false and the
     * loop would never exit.
     */
    private static boolean shouldStop(double prevLogLikelihood, double logLikelihood) {
        if (Double.isNaN(logLikelihood) || logLikelihood == prevLogLikelihood) {
            return true;
        }
        return Math.abs((logLikelihood - prevLogLikelihood) / logLikelihood) < CONVERGENCE_TOLERANCE;
    }

    /**
     * Returns the trainer name followed by any settings that differ from 1.0.
     */
    public String toString() {
        String ret = "NaiveBayesEMTrainer";
        // NOTE(review): the default docLengthNormalization is -1.0, so this branch is
        // taken for a default-configured trainer; confirm whether -1.0 was intended here.
        if (this.docLengthNormalization != 1.0) {
            ret = ret + ",docLengthNormalization=" + this.docLengthNormalization;
        }
        if (this.unlabeledDataWeight != 1.0) {
            ret = ret + ",unlabeledDataWeight=" + this.unlabeledDataWeight;
        }
        return ret;
    }

    /**
     * Writes the serial version followed by the feature and prior estimators.
     * No other fields are written.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.featureEstimator);
        out.writeObject(this.priorEstimator);
    }

    /**
     * Reads the version-1 stream written by {@link #writeObject} and rebuilds the
     * state that the stream does not carry.
     *
     * @throws ClassNotFoundException if the stream's version number is not the one
     *                                this class writes
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched NaiveBayesEMTrainer versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.featureEstimator = (Multinomial.Estimator)in.readObject();
        this.priorEstimator = (Multinomial.Estimator)in.readObject();
        // The stream holds only the two estimators. Restore the remaining settings to
        // their defaults and rebuild the factory so the object can train again
        // (presumably field initializers are skipped on deserialization - confirm
        // against how the superclass participates in serialization).
        this.docLengthNormalization = DEFAULT_DOC_LENGTH_NORMALIZATION;
        this.unlabeledDataWeight = DEFAULT_UNLABELED_DATA_WEIGHT;
        this.iteration = 0;
        this.nbTrainer = new NaiveBayesTrainer.Factory();
        this.configureBaseTrainer();
    }
}

