/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.MultiSegmentationEvaluator;
import cc.mallet.fst.NoopTransducerTrainer;
import cc.mallet.fst.TokenAccuracyEvaluator;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.ViterbiWriter;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Sequence;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Random;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
 * Command-line driver for training, evaluating and applying a CRF sequence tagger.
 *
 * <p>Behavior is controlled by the {@code CommandOption} fields declared below; the
 * remaining command-line arguments name the data file(s). Each option's own
 * description string documents its meaning.
 */
public class SimpleTagger {
    private static final Logger logger = MalletLogger.getLogger(SimpleTagger.class.getName());

    /**
     * Matches empty or whitespace-only lines. Passed to every {@code LineGroupIterator}
     * as its boundary pattern (presumably the separator between instances in a data
     * file -- confirm against the LineGroupIterator documentation).
     */
    private static final Pattern INSTANCE_SEPARATOR = Pattern.compile("^\\s*$");

    // ---- Command-line options (all registered in commandOptions below) ----
    private static final CommandOption.Double gaussianVarianceOption = new CommandOption.Double(SimpleTagger.class, "gaussian-variance", "DECIMAL", true, 10.0, "The gaussian prior variance used for training.", null);
    private static final CommandOption.Boolean trainOption = new CommandOption.Boolean(SimpleTagger.class, "train", "true|false", true, false, "Whether to train", null);
    private static final CommandOption.String testOption = new CommandOption.String(SimpleTagger.class, "test", "lab or seg=start-1.continue-1,...,start-n.continue-n", true, null, "Test measuring labeling or segmentation (start-i, continue-i) accuracy", null);
    private static final CommandOption.File modelOption = new CommandOption.File(SimpleTagger.class, "model-file", "FILENAME", true, null, "The filename for reading (train/run) or saving (train) the model.", null);
    private static final CommandOption.Double trainingFractionOption = new CommandOption.Double(SimpleTagger.class, "training-proportion", "DECIMAL", true, 0.5, "Fraction of data to use for training in a random split.", null);
    private static final CommandOption.Integer randomSeedOption = new CommandOption.Integer(SimpleTagger.class, "random-seed", "INTEGER", true, 0, "The random seed for randomly selecting a proportion of the instance list for training", null);
    private static final CommandOption.IntegerArray ordersOption = new CommandOption.IntegerArray(SimpleTagger.class, "orders", "COMMA-SEP-DECIMALS", true, new int[]{1}, "List of label Markov orders (main and backoff) ", null);
    private static final CommandOption.String forbiddenOption = new CommandOption.String(SimpleTagger.class, "forbidden", "REGEXP", true, "\\s", "label1,label2 transition forbidden if it matches this", null);
    private static final CommandOption.String allowedOption = new CommandOption.String(SimpleTagger.class, "allowed", "REGEXP", true, ".*", "label1,label2 transition allowed only if it matches this", null);
    private static final CommandOption.String defaultOption = new CommandOption.String(SimpleTagger.class, "default-label", "STRING", true, "O", "Label for initial context and uninteresting tokens", null);
    private static final CommandOption.Integer iterationsOption = new CommandOption.Integer(SimpleTagger.class, "iterations", "INTEGER", true, 500, "Number of training iterations", null);
    private static final CommandOption.Boolean viterbiOutputOption = new CommandOption.Boolean(SimpleTagger.class, "viterbi-output", "true|false", true, false, "Print Viterbi periodically during training", null);
    private static final CommandOption.Boolean connectedOption = new CommandOption.Boolean(SimpleTagger.class, "fully-connected", "true|false", true, true, "Include all allowed transitions, even those not in training data", null);
    private static final CommandOption.Boolean continueTrainingOption = new CommandOption.Boolean(SimpleTagger.class, "continue-training", "true|false", false, false, "Continue training from model specified by --model-file", null);
    private static final CommandOption.Integer nBestOption = new CommandOption.Integer(SimpleTagger.class, "n-best", "INTEGER", true, 1, "How many answers to output", null);
    private static final CommandOption.Integer cacheSizeOption = new CommandOption.Integer(SimpleTagger.class, "cache-size", "INTEGER", true, 100000, "How much state information to memoize in n-best decoding", null);
    private static final CommandOption.Boolean includeInputOption = new CommandOption.Boolean(SimpleTagger.class, "include-input", "true|false", true, false, "Whether to include the input features when printing decoding output", null);
    private static final CommandOption.Boolean featureInductionOption = new CommandOption.Boolean(SimpleTagger.class, "feature-induction", "true|false", true, false, "Whether to perform feature induction during training", null);
    private static final CommandOption.List commandOptions = new CommandOption.List("Training, testing and running a generic tagger.", new CommandOption[]{gaussianVarianceOption, trainOption, iterationsOption, testOption, trainingFractionOption, modelOption, randomSeedOption, ordersOption, forbiddenOption, allowedOption, defaultOption, viterbiOutputOption, connectedOption, continueTrainingOption, nBestOption, cacheSizeOption, includeInputOption, featureInductionOption});

    /** Not instantiable: this class only exposes static entry points. */
    private SimpleTagger() {
    }

    /**
     * Trains a CRF on {@code training}, creating a new model when {@code crf} is null
     * and continuing from the supplied one otherwise.
     *
     * @param training     instances to train on
     * @param testing      held-out instances; may be null (only its size is logged here,
     *                     and it is forwarded to feature induction / Viterbi output)
     * @param eval         evaluator invoked after every training iteration; may be null
     * @param orders       label Markov orders used when building a new CRF
     * @param defaultLabel default label used when building a new CRF
     * @param forbidden    regular expression of forbidden label transitions
     * @param allowed      regular expression of allowed label transitions
     * @param connected    forwarded to {@code addOrderNStates} when building a new CRF
     * @param iterations   maximum number of training iterations
     * @param var          Gaussian prior variance applied to the trainer
     * @param crf          existing model to continue training, or null to build one
     * @return the trained CRF (the same object as {@code crf} when one was supplied)
     * @throws java.util.regex.PatternSyntaxException if {@code forbidden} or
     *         {@code allowed} is not a valid regular expression
     */
    public static CRF train(InstanceList training, InstanceList testing, TransducerEvaluator eval, int[] orders, String defaultLabel, String forbidden, String allowed, boolean connected, int iterations, double var, CRF crf) {
        // Compiled up front (even when a model is supplied) so bad patterns fail fast.
        Pattern forbiddenPat = Pattern.compile(forbidden);
        Pattern allowedPat = Pattern.compile(allowed);
        if (crf == null) {
            crf = new CRF(training.getPipe(), (Pipe)null);
            String startName = crf.addOrderNStates(training, orders, null, defaultLabel, forbiddenPat, allowedPat, connected);
            // Give every state an initial weight of -infinity, then reset the state
            // returned by addOrderNStates to 0 so it is the only finite one.
            for (int s = 0; s < crf.numStates(); ++s) {
                crf.getState(s).setInitialWeight(Double.NEGATIVE_INFINITY);
            }
            crf.getState(startName).setInitialWeight(0.0);
        }
        logger.info("Training on " + training.size() + " instances");
        if (testing != null) {
            logger.info("Testing on " + testing.size() + " instances");
        }
        CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
        // The variance must be set on the trainer that actually runs; previously it was
        // set on a throwaway trainer, so the requested value never took effect.
        crft.setGaussianPriorVariance(var);
        if (featureInductionOption.value) {
            crft.trainWithFeatureInduction(training, null, testing, eval, iterations, 10, 20, 500, 0.5, false, null);
        } else {
            // Train one iteration at a time so evaluation can run between iterations.
            for (int iteration = 1; iteration <= iterations; ++iteration) {
                boolean converged = crft.train(training, 1);
                if (eval != null) {
                    eval.evaluate(crft);
                }
                if (viterbiOutputOption.value && iteration % 10 == 0) {
                    new ViterbiWriter("", new InstanceList[]{training, testing}, new String[]{"training", "testing"}).evaluate(crft);
                }
                if (converged) {
                    break;
                }
            }
        }
        return crf;
    }

    /**
     * Evaluates a trained transducer on {@code testing}, reporting under the
     * description "Testing".
     *
     * @param tt      trainer wrapping the transducer to evaluate
     * @param eval    evaluator to run
     * @param testing instances to evaluate on
     */
    public static void test(TransducerTrainer tt, TransducerEvaluator eval, InstanceList testing) {
        eval.evaluateInstanceList(tt, testing, "Testing");
    }

    /**
     * Applies {@code model} to {@code input} and returns the output sequence(s).
     *
     * @param model transducer to apply
     * @param input input sequence
     * @param k     number of answers requested; {@code 1} uses {@code model.transduce},
     *              any other value asks a {@code MaxLatticeDefault} for {@code k}
     *              best output sequences
     * @return the output sequences, one per answer
     */
    public static Sequence[] apply(Transducer model, Sequence input, int k) {
        if (k == 1) {
            return new Sequence[]{model.transduce(input)};
        }
        MaxLatticeDefault lattice = new MaxLatticeDefault(model, input, null, cacheSizeOption.value());
        return lattice.bestOutputSequences(k).toArray(new Sequence[0]);
    }

    /**
     * Command-line entry point: parses options, reads the data file(s), then trains,
     * evaluates, or decodes depending on {@code --train} and {@code --test}.
     *
     * @param args options followed by one or two data file names
     * @throws IllegalArgumentException if data files, the model file option, or the
     *         {@code --test} specification are missing or malformed
     * @throws Exception on I/O or deserialization failures
     */
    public static void main(String[] args) throws Exception {
        int restArgs = commandOptions.processOptions(args);
        if (restArgs == args.length) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Missing data file(s)");
        }
        FileReader trainingFile = null;
        FileReader testFile = null;
        try {
            if (trainOption.value) {
                trainingFile = new FileReader(new File(args[restArgs]));
                // A second data file is only used as test data when --test is given.
                if (testOption.value != null && restArgs < args.length - 1) {
                    testFile = new FileReader(new File(args[restArgs + 1]));
                }
            } else {
                testFile = new FileReader(new File(args[restArgs]));
            }

            Pipe p;
            CRF crf = null;
            if (continueTrainingOption.value || !trainOption.value) {
                // Reuse the stored model and the input pipe it was built with.
                crf = loadModel();
                p = crf.getInputPipe();
            } else {
                p = new SimpleTaggerSentence2FeatureVectorSequence();
                p.getTargetAlphabet().lookupIndex(defaultOption.value);
            }

            InstanceList trainingData = null;
            InstanceList testData = null;
            if (trainOption.value) {
                p.setTargetProcessing(true);
                trainingData = readInstances(p, trainingFile);
                logger.info("Number of features in training data: " + p.getDataAlphabet().size());
                if (testOption.value != null) {
                    if (testFile != null) {
                        testData = readInstances(p, testFile);
                    } else {
                        // No separate test file: carve test data out of the training data.
                        Random r = new Random(randomSeedOption.value);
                        InstanceList[] trainingLists = trainingData.split(r, new double[]{trainingFractionOption.value, 1.0 - trainingFractionOption.value});
                        trainingData = trainingLists[0];
                        testData = trainingLists[1];
                    }
                }
            } else {
                // Labels are only read from the input when it will be scored.
                p.setTargetProcessing(testOption.value != null);
                testData = readInstances(p, testFile);
            }
            logger.info("Number of predicates: " + p.getDataAlphabet().size());

            TransducerEvaluator eval = null;
            if (testOption.value != null) {
                eval = createEvaluator(testOption.value, trainingData, testData);
            }
            if (p.isTargetProcessing()) {
                logLabels(p.getTargetAlphabet());
            }

            if (trainOption.value) {
                crf = train(trainingData, testData, eval, ordersOption.value, defaultOption.value, forbiddenOption.value, allowedOption.value, connectedOption.value, iterationsOption.value, gaussianVarianceOption.value, crf);
                if (modelOption.value != null) {
                    saveModel(crf);
                }
            } else if (eval != null) {
                // crf is always non-null here: it was loaded above because --train is false.
                test(new NoopTransducerTrainer(crf), eval, testData);
            } else {
                printDecodings(crf, testData);
            }
        } finally {
            closeQuietly(trainingFile);
            closeQuietly(testFile);
        }
    }

    /** Reads the CRF serialized in the {@code --model-file} file, failing if the option is unset. */
    private static CRF loadModel() throws java.io.IOException, ClassNotFoundException {
        if (modelOption.value == null) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Missing model file option");
        }
        FileInputStream fileIn = new FileInputStream(modelOption.value);
        try {
            ObjectInputStream in = new ObjectInputStream(fileIn);
            return (CRF)in.readObject();
        } finally {
            // Closing the underlying file stream releases the handle on every path.
            fileIn.close();
        }
    }

    /** Serializes {@code crf} to the {@code --model-file} file. */
    private static void saveModel(CRF crf) throws java.io.IOException {
        FileOutputStream fileOut = new FileOutputStream(modelOption.value);
        try {
            ObjectOutputStream out = new ObjectOutputStream(fileOut);
            out.writeObject(crf);
            // Flushes buffered object data before the file is released.
            out.close();
        } finally {
            fileOut.close();
        }
    }

    /** Builds an instance list by pushing the groups of lines read from {@code reader} through {@code pipe}. */
    private static InstanceList readInstances(Pipe pipe, FileReader reader) {
        InstanceList instances = new InstanceList(pipe);
        instances.addThruPipe(new LineGroupIterator(reader, INSTANCE_SEPARATOR, true));
        return instances;
    }

    /**
     * Creates the evaluator described by a {@code --test} value: {@code lab...} gives a
     * token-accuracy evaluator, {@code seg=s1.c1,...,sn.cn} a multi-segmentation one.
     */
    private static TransducerEvaluator createEvaluator(String testSpec, InstanceList trainingData, InstanceList testData) {
        InstanceList[] lists = new InstanceList[]{trainingData, testData};
        String[] descriptions = new String[]{"Training", "Testing"};
        if (testSpec.startsWith("lab")) {
            return new TokenAccuracyEvaluator(lists, descriptions);
        }
        if (!testSpec.startsWith("seg=")) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Invalid test option: " + testSpec);
        }
        String[] pairs = testSpec.substring(4).split(",");
        if (pairs.length < 1) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Missing segment start/continue labels: " + testSpec);
        }
        String[] startTags = new String[pairs.length];
        String[] continueTags = new String[pairs.length];
        for (int i = 0; i < pairs.length; ++i) {
            // Each pair is written "start.continue".
            String[] pair = pairs[i].split("\\.");
            if (pair.length != 2) {
                commandOptions.printUsage(true);
                throw new IllegalArgumentException("Incorrectly-specified segment start and end labels: " + pairs[i]);
            }
            startTags[i] = pair[0];
            continueTags[i] = pair[1];
        }
        return new MultiSegmentationEvaluator(lists, descriptions, startTags, continueTags);
    }

    /** Logs every entry of the target alphabet on a single line. */
    private static void logLabels(Alphabet targets) {
        StringBuilder buf = new StringBuilder("Labels:");
        for (int i = 0; i < targets.size(); ++i) {
            buf.append(" ").append(targets.lookupObject(i).toString());
        }
        logger.info(buf.toString());
    }

    /**
     * Decodes every instance of {@code testData} with {@code crf} and prints one line per
     * input position to standard output, followed by a blank line per instance.
     * Instances for which any answer has the wrong length are reported on standard
     * error and skipped.
     */
    private static void printDecodings(CRF crf, InstanceList testData) {
        boolean includeInput = includeInputOption.value();
        for (int i = 0; i < testData.size(); ++i) {
            Sequence input = (Sequence)((Instance)testData.get(i)).getData();
            Sequence[] outputs = apply(crf, input, nBestOption.value);
            int k = outputs.length;
            boolean error = false;
            for (int a = 0; a < k; ++a) {
                if (outputs[a].size() != input.size()) {
                    System.err.println("Failed to decode input sequence " + i + ", answer " + a);
                    error = true;
                }
            }
            if (error) {
                continue;
            }
            for (int j = 0; j < input.size(); ++j) {
                // One output line per position: the k answers, then optionally the features.
                StringBuilder buf = new StringBuilder();
                for (int a = 0; a < k; ++a) {
                    buf.append(outputs[a].get(j).toString()).append(" ");
                }
                if (includeInput) {
                    FeatureVector fv = (FeatureVector)input.get(j);
                    buf.append(fv.toString(true));
                }
                System.out.println(buf.toString());
            }
            System.out.println();
        }
    }

    /** Closes {@code resource} if non-null; a failure to close is logged, not propagated. */
    private static void closeQuietly(Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (java.io.IOException e) {
            // The data has already been read (or a more relevant exception is in flight).
            logger.warning("Failed to close input file: " + e);
        }
    }

    /**
     * Pipe that converts a sentence into a {@code FeatureVectorSequence}.
     *
     * <p>Input data is either a {@code String} (one token per line, fields separated by
     * single spaces) or an already-split {@code String[][]}. When target processing is
     * on, the last field of each line is taken as that token's label and the preceding
     * fields as its features; otherwise every field is a feature.
     */
    public static class SimpleTaggerSentence2FeatureVectorSequence
    extends Pipe {
        /** Creates the pipe with a fresh data alphabet and a fresh label alphabet. */
        public SimpleTaggerSentence2FeatureVectorSequence() {
            super(new Alphabet(), new LabelAlphabet());
        }

        /** Splits a sentence into lines on {@code "\n"} and each line into fields on single spaces. */
        private String[][] parseSentence(String sentence) {
            String[] lines = sentence.split("\n");
            String[][] tokens = new String[lines.length][];
            for (int i = 0; i < lines.length; ++i) {
                tokens[i] = lines[i].split(" ");
            }
            return tokens;
        }

        /**
         * Replaces the carrier's data with a {@code FeatureVectorSequence} and its target
         * with a {@code LabelSequence} (empty when target processing is off).
         *
         * @param carrier instance whose data is a {@code String} or {@code String[][]}
         * @return the same instance, modified in place
         * @throws IllegalArgumentException if the data is of any other type
         * @throws IllegalStateException if target processing is on and a line has no fields
         */
        public Instance pipe(Instance carrier) {
            Object inputData = carrier.getData();
            String[][] tokens;
            if (inputData instanceof String) {
                tokens = this.parseSentence((String)inputData);
            } else if (inputData instanceof String[][]) {
                tokens = (String[][])inputData;
            } else {
                throw new IllegalArgumentException("Not a String or String[][]; got " + inputData);
            }

            Alphabet features = this.getDataAlphabet();
            FeatureVector[] fvs = new FeatureVector[tokens.length];
            LabelSequence target = null;
            if (this.isTargetProcessing()) {
                LabelAlphabet labels = (LabelAlphabet)this.getTargetAlphabet();
                target = new LabelSequence(labels, tokens.length);
            }
            for (int l = 0; l < tokens.length; ++l) {
                int nFeatures;
                if (this.isTargetProcessing()) {
                    if (tokens[l].length < 1) {
                        throw new IllegalStateException("Missing label at line " + l + " instance " + carrier.getName());
                    }
                    // The last field is the label; everything before it is a feature.
                    nFeatures = tokens[l].length - 1;
                    target.add(tokens[l][nFeatures]);
                } else {
                    nFeatures = tokens[l].length;
                }
                int[] featureIndices = new int[nFeatures];
                for (int f = 0; f < nFeatures; ++f) {
                    featureIndices[f] = features.lookupIndex(tokens[l][f]);
                }
                // An augmentable vector is built instead when feature induction is enabled.
                fvs[l] = featureInductionOption.value ? new AugmentableFeatureVector(features, featureIndices, null, featureIndices.length) : new FeatureVector(features, featureIndices);
            }
            carrier.setData(new FeatureVectorSequence(fvs));
            if (this.isTargetProcessing()) {
                carrier.setTarget(target);
            } else {
                carrier.setTarget(new LabelSequence(this.getTargetAlphabet()));
            }
            return carrier;
        }
    }
}

