/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import java.io.Serializable;
import java.text.DecimalFormat;

/**
 * Maximum-entropy Markov model (MEMM) built on top of the {@link CRF} machinery.
 *
 * <p>The difference from a CRF is where normalization happens: every
 * {@link TransitionIterator} produced by a {@link State} rescales its outgoing transition
 * weights by subtracting their log-sum (the local log Z for that source state and input
 * position). A CRF instead normalizes globally over whole sequences.
 */
public class MEMM
extends CRF
implements Serializable {
    /**
     * Creates an MEMM by delegating to {@code CRF(Pipe, Pipe)}.
     *
     * @param inputPipe  pipe for the input side
     * @param outputPipe pipe for the output side
     */
    public MEMM(Pipe inputPipe, Pipe outputPipe) {
        super(inputPipe, outputPipe);
    }

    /**
     * Creates an MEMM by delegating to {@code CRF(Alphabet, Alphabet)}.
     *
     * @param inputAlphabet  alphabet of input features
     * @param outputAlphabet alphabet of output labels
     */
    public MEMM(Alphabet inputAlphabet, Alphabet outputAlphabet) {
        super(inputAlphabet, outputAlphabet);
    }

    /**
     * Creates an MEMM from an existing CRF by delegating to {@code CRF(CRF)}.
     *
     * @param crf the CRF to build from
     */
    public MEMM(CRF crf) {
        super(crf);
    }

    /**
     * State factory used by the CRF base class. It returns {@link MEMM.State} instances,
     * whose transition iterators apply the MEMM's local normalization.
     */
    protected CRF.State newState(String name, int index, double initialWeight, double finalWeight, String[] destinationNames, String[] labelNames, String[][] weightNames, CRF crf) {
        return new State(name, index, initialWeight, finalWeight, destinationNames, labelNames, weightNames, crf);
    }

    /**
     * CRF transition iterator whose weights are locally normalized on construction: the
     * log-sum of all outgoing weights is subtracted from each weight.
     */
    protected static class TransitionIterator
    extends CRF.TransitionIterator
    implements Serializable {
        /** Local log normalizer (log Z) of the outgoing weights; computed by {@link #normalizeCosts()}. */
        private double sum;

        public TransitionIterator(State source, FeatureVectorSequence inputSeq, int inputPosition, String output, CRF memm) {
            super(source, inputSeq, inputPosition, output, memm);
            this.normalizeCosts();
        }

        public TransitionIterator(State source, FeatureVector fv, String output, CRF memm) {
            super(source, fv, output, memm);
            this.normalizeCosts();
        }

        /**
         * Computes log Z over {@code weights} and subtracts it from every weight.
         *
         * <p>This local normalization is what separates the MEMM from the globally normalized
         * CRF. The arithmetic assumes {@code Transducer.sumLogProb(a, b)} returns
         * log(exp(a) + exp(b)).
         */
        private void normalizeCosts() {
            // NEGATIVE_INFINITY is log(0), the identity element for a log-space sum.
            this.sum = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < this.weights.length; ++i) {
                this.sum = Transducer.sumLogProb(this.sum, this.weights[i]);
            }
            assert !Double.isNaN(this.sum) : "local log normalizer is NaN";
            // An infinite sum can come from no weights, all weights -inf, or a +inf weight.
            // Subtracting it would give NaN or +/-inf, so in that case the weights are left as-is.
            if (!Double.isInfinite(this.sum)) {
                for (int i = 0; i < this.weights.length; ++i) {
                    // Subtract (not assign) log Z so each weight becomes a locally normalized score.
                    this.weights[i] -= this.sum;
                }
            }
        }

        /**
         * Returns the base class's description followed by the local log normalizer.
         *
         * @param cutoff forwarded unchanged to {@code CRF.TransitionIterator.describeTransition}
         */
        public String describeTransition(double cutoff) {
            // DecimalFormat is not thread-safe, so a fresh instance is created per call.
            DecimalFormat f = new DecimalFormat("0.###");
            return super.describeTransition(cutoff) + "Log Z = " + f.format(this.sum) + "\n";
        }
    }

    /**
     * CRF state whose {@link #transitionIterator} yields locally normalized
     * {@link MEMM.TransitionIterator}s.
     */
    public static class State
    extends CRF.State
    implements Serializable {
        /** Not read by any code in this class; kept so package-level users remain compatible. */
        InstanceList trainingSet;

        protected State(String name, int index, double initialCost, double finalCost, String[] destinationNames, String[] labelNames, String[][] weightNames, CRF crf) {
            super(name, index, initialCost, finalCost, destinationNames, labelNames, weightNames, crf);
        }

        /**
         * Creates a locally normalized transition iterator from this state at the given input
         * position.
         *
         * @param inputSequence  input sequence; must be a {@link FeatureVectorSequence} and non-null
         * @param inputPosition  position in the input; must be non-negative
         * @param outputSequence optional output sequence; when non-null, its element at
         *                       {@code outputPosition} is cast to {@code String} and passed on
         * @param outputPosition position in the output; must be non-negative
         * @return a new {@link MEMM.TransitionIterator}
         * @throws UnsupportedOperationException if either position is negative (epsilon
         *                                       transitions) or {@code inputSequence} is null
         */
        public Transducer.TransitionIterator transitionIterator(Sequence inputSequence, int inputPosition, Sequence outputSequence, int outputPosition) {
            if (inputPosition < 0 || outputPosition < 0) {
                throw new UnsupportedOperationException("Epsilon transitions not implemented.");
            }
            if (inputSequence == null) {
                throw new UnsupportedOperationException("CRFs are not generative models; must have an input sequence.");
            }
            String output = outputSequence == null ? null : (String) outputSequence.get(outputPosition);
            return new TransitionIterator(this, (FeatureVectorSequence) inputSequence, inputPosition, output, this.crf);
        }
    }
}

