/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify.tui;

import bsh.EvalError;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.NaiveBayesTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.classify.evaluate.ConfusionMatrix;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import cc.mallet.util.ProgressMessageLogFormatter;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Random;
import java.util.logging.ConsoleHandler;
import java.util.logging.Handler;
import java.util.logging.Logger;

public abstract class Vectors2Classify {
    private static Logger logger = MalletLogger.getLogger(Vectors2Classify.class.getName());
    private static Logger progressLogger = MalletProgressMessageLogger.getLogger(Vectors2Classify.class.getName() + "-pl");
    private static ArrayList<ClassifierTrainer> classifierTrainers = new ArrayList();
    private static boolean[][] ReportOptions = new boolean[3][4];
    private static String[][] ReportOptionArgs = new String[3][4];
    static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings(Vectors2Classify.class, "report", "[train|test|validation]:[accuracy|f1:label|confusion|raw]", true, new String[]{"test:accuracy", "test:confusion", "train:accuracy"}, "", null){

        public void postParsing(CommandOption.List list) {
            String defaultRawFormatting = "siw";
            for (int argi = 0; argi < this.value.length; ++argi) {
                int j;
                String arg = this.value[argi];
                String[] fields = arg.split("[:=]");
                String dataSet = fields[0];
                String reportOption = fields[1];
                String reportOptionArg = null;
                if (fields.length >= 3) {
                    reportOptionArg = fields[2];
                }
                boolean foundDataSource = false;
                for (int i = 0; i < ReportOption.dataOptions.length; ++i) {
                    if (!dataSet.equals(ReportOption.dataOptions[i])) continue;
                    foundDataSource = true;
                    break;
                }
                if (!foundDataSource) {
                    throw new IllegalArgumentException("Unknown argument = " + dataSet + " in --report " + this.value[argi]);
                }
                boolean foundReportOption = false;
                for (j = 0; j < ReportOption.reportOptions.length; ++j) {
                    if (!reportOption.equals(ReportOption.reportOptions[j])) continue;
                    foundReportOption = true;
                    break;
                }
                if (!foundReportOption) {
                    throw new IllegalArgumentException("Unknown argument = " + reportOption + " in --report " + this.value[argi]);
                }
                ReportOptions[i][j] = true;
                if (j == 1) {
                    if (reportOptionArg == null) {
                        throw new IllegalArgumentException("F1 must have label argument in --report " + this.value[argi]);
                    }
                    ReportOptionArgs[i][j] = reportOptionArg;
                    continue;
                }
                if (reportOptionArg == null) continue;
                throw new IllegalArgumentException("No arguments after = allowed in --report " + this.value[argi]);
            }
        }
    };
    static CommandOption.Object trainerConstructor = new CommandOption.Object(Vectors2Classify.class, "trainer", "ClassifierTrainer constructor", true, new NaiveBayesTrainer(), "Java code for the constructor used to create a ClassifierTrainer.  If no '(' appears, then \"new \" will be prepended and \"Trainer()\" will be appended.You may use this option mutiple times to compare multiple classifiers.", null){

        public void parseArg(String arg) {
            String[] fields = arg.split(",");
            String constructorName = fields[0];
            if (constructorName.indexOf(40) != -1) {
                super.parseArg(arg);
            } else if (constructorName.endsWith("Trainer")) {
                super.parseArg("new " + constructorName + "()");
            } else {
                super.parseArg("new " + constructorName + "Trainer()");
            }
            Method[] methods = this.value.getClass().getMethods();
            for (int i = 1; i < fields.length; ++i) {
                int j;
                Object parameterValueObject;
                String[] nameValuePair = fields[i].split("=");
                String parameterName = nameValuePair[0];
                String parameterValue = nameValuePair[1];
                try {
                    parameterValueObject = 2.getInterpreter().eval(parameterValue);
                }
                catch (EvalError e) {
                    throw new IllegalArgumentException("Java interpreter eval error on parameter " + parameterName + "\n" + (Object)((Object)e));
                }
                boolean foundSetter = false;
                for (j = 0; j < methods.length; ++j) {
                    if (!("set" + Character.toUpperCase(parameterName.charAt(0)) + parameterName.substring(1)).equals(methods[j].getName()) || methods[j].getParameterTypes().length != 1) continue;
                    try {
                        Object[] parameterList = new Object[]{parameterValueObject};
                        methods[j].invoke(this.value, parameterList);
                    }
                    catch (IllegalAccessException e) {
                        System.out.println("IllegalAccessException " + e);
                        throw new IllegalArgumentException("Java access error calling setter\n" + e);
                    }
                    catch (InvocationTargetException e) {
                        System.out.println("IllegalTargetException " + e);
                        throw new IllegalArgumentException("Java target error calling setter\n" + e);
                    }
                    foundSetter = true;
                    break;
                }
                if (foundSetter) continue;
                System.out.println("Parameter " + parameterName + " not found on trainer " + constructorName);
                System.out.println("Available parameters for " + constructorName);
                for (j = 0; j < methods.length; ++j) {
                    if (!methods[j].getName().startsWith("set") || methods[j].getParameterTypes().length != 1) continue;
                    System.out.println(Character.toLowerCase(methods[j].getName().charAt(3)) + methods[j].getName().substring(4));
                }
                throw new IllegalArgumentException("no setter found for parameter " + parameterName);
            }
        }

        public void postParsing(CommandOption.List list) {
            assert (this.value instanceof ClassifierTrainer);
            classifierTrainers.add((ClassifierTrainer)this.value);
        }
    };
    static CommandOption.String outputFile = new CommandOption.String(Vectors2Classify.class, "output-classifier", "FILENAME", true, "classifier.mallet", "The filename in which to write the classifier after it has been trained.", null);
    static CommandOption.String inputFile = new CommandOption.String(Vectors2Classify.class, "input", "FILENAME", true, "text.vectors", "The filename from which to read the list of training instances.  Use - for stdin.", null);
    static CommandOption.String trainingFile = new CommandOption.String(Vectors2Classify.class, "training-file", "FILENAME", true, "text.vectors", "Read the training set instance list from this file. If this is specified, the input file parameter is ignored", null);
    static CommandOption.String testFile = new CommandOption.String(Vectors2Classify.class, "testing-file", "FILENAME", true, "text.vectors", "Read the test set instance list to this file. If this option is specified, the training-file parameter must be specified and  the input-file parameter is ignored", null);
    static CommandOption.String validationFile = new CommandOption.String(Vectors2Classify.class, "validation-file", "FILENAME", true, "text.vectors", "Read the validation set instance list to this file.If this option is specified, the training-file parameter must be specified and the input-file parameter is ignored", null);
    static CommandOption.Double trainingProportionOption = new CommandOption.Double(Vectors2Classify.class, "training-portion", "DECIMAL", true, 1.0, "The fraction of the instances that should be used for training.", null);
    static CommandOption.Double validationProportionOption = new CommandOption.Double(Vectors2Classify.class, "validation-portion", "DECIMAL", true, 0.0, "The fraction of the instances that should be used for validation.", null);
    static CommandOption.Double unlabeledProportionOption = new CommandOption.Double(Vectors2Classify.class, "unlabeled-portion", "DECIMAL", true, 0.0, "The fraction of the training instances that should have their labels hidden.  Note that these are taken out of the training-portion, not allocated separately.", null);
    static CommandOption.Integer randomSeedOption = new CommandOption.Integer(Vectors2Classify.class, "random-seed", "INTEGER", true, 0, "The random seed for randomly selecting a proportion of the instance list for training", null);
    static CommandOption.Integer numTrialsOption = new CommandOption.Integer(Vectors2Classify.class, "num-trials", "INTEGER", true, 1, "The number of random train/test splits to perform", null);
    static CommandOption.Object classifierEvaluatorOption = new CommandOption.Object(Vectors2Classify.class, "classifier-evaluator", "CONSTRUCTOR", true, null, "Java code for constructing a ClassifierEvaluating object", null);
    static CommandOption.Integer verbosityOption = new CommandOption.Integer(Vectors2Classify.class, "verbosity", "INTEGER", true, -1, "The level of messages to print: 0 is silent, 8 is most verbose. Levels 0-8 correspond to the java.logger predefined levels off, severe, warning, info, config, fine, finer, finest, all. The default value is taken from the mallet logging.properties file, which currently defaults to INFO level (3)", null);
    static CommandOption.Boolean noOverwriteProgressMessagesOption = new CommandOption.Boolean(Vectors2Classify.class, "noOverwriteProgressMessages", "true|false", false, false, "Suppress writing-in-place on terminal for progess messages - repetitive messages of which only the latest is generally of interest", null);

    public static void main(String[] args) throws EvalError, IOException {
        CommandOption.setSummary(Vectors2Classify.class, "A tool for training, saving and printing diagnostics from a classifier on vectors.");
        CommandOption.process(Vectors2Classify.class, args);
        if (!trainerConstructor.wasInvoked()) {
            classifierTrainers.add(new NaiveBayesTrainer());
        }
        if (!report.wasInvoked()) {
            report.postParsing(null);
        }
        int verbosity = Vectors2Classify.verbosityOption.value;
        Logger rootLogger = ((MalletLogger)progressLogger).getRootLogger();
        if (verbosityOption.wasInvoked()) {
            rootLogger.setLevel(MalletLogger.LoggingLevels[verbosity]);
        }
        if (!Vectors2Classify.noOverwriteProgressMessagesOption.value) {
            Handler[] handlers = rootLogger.getHandlers();
            for (int i = 0; i < handlers.length; ++i) {
                if (!(handlers[i] instanceof ConsoleHandler)) continue;
                handlers[i].setFormatter(new ProgressMessageLogFormatter());
            }
        }
        boolean separateIlists = testFile.wasInvoked() || trainingFile.wasInvoked() || validationFile.wasInvoked();
        InstanceList ilist = null;
        InstanceList testFileIlist = null;
        InstanceList trainingFileIlist = null;
        InstanceList validationFileIlist = null;
        if (!separateIlists) {
            ilist = InstanceList.load(new File(Vectors2Classify.inputFile.value));
        } else {
            trainingFileIlist = InstanceList.load(new File(Vectors2Classify.trainingFile.value));
            logger.info("Training vectors loaded from " + Vectors2Classify.trainingFile.value);
            if (testFile.wasInvoked()) {
                testFileIlist = InstanceList.load(new File(Vectors2Classify.testFile.value));
                logger.info("Testing vectors loaded from " + Vectors2Classify.testFile.value);
            }
            if (validationFile.wasInvoked()) {
                validationFileIlist = InstanceList.load(new File(Vectors2Classify.validationFile.value));
                logger.info("validation vectors loaded from " + Vectors2Classify.validationFile.value);
            }
        }
        int numTrials = Vectors2Classify.numTrialsOption.value;
        Random r = randomSeedOption.wasInvoked() ? new Random(Vectors2Classify.randomSeedOption.value) : new Random();
        ClassifierTrainer[] trainers = new ClassifierTrainer[classifierTrainers.size()];
        for (int i = 0; i < classifierTrainers.size(); ++i) {
            trainers[i] = classifierTrainers.get(i);
            logger.fine("Trainer specified = " + trainers[i].toString());
        }
        double[][] trainAccuracy = new double[trainers.length][numTrials];
        double[][] testAccuracy = new double[trainers.length][numTrials];
        double[][] validationAccuracy = new double[trainers.length][numTrials];
        String[][] trainConfusionMatrix = new String[trainers.length][numTrials];
        String[][] testConfusionMatrix = new String[trainers.length][numTrials];
        String[][] validationConfusionMatrix = new String[trainers.length][numTrials];
        double t = Vectors2Classify.trainingProportionOption.value;
        double v = Vectors2Classify.validationProportionOption.value;
        if (!separateIlists) {
            logger.info("Training portion = " + t);
            logger.info(" Unlabeled training sub-portion = " + Vectors2Classify.unlabeledProportionOption.value);
            logger.info("Validation portion = " + v);
            logger.info("Testing portion = " + (1.0 - v - t));
        }
        for (int trialIndex = 0; trialIndex < numTrials; ++trialIndex) {
            System.out.println("\n-------------------- Trial " + trialIndex + "  --------------------\n");
            BitSet unlabeledIndices = null;
            InstanceList[] ilists = !separateIlists ? ilist.split(r, new double[]{t, 1.0 - t - v, v}) : new InstanceList[]{trainingFileIlist, testFileIlist, validationFileIlist};
            if (Vectors2Classify.unlabeledProportionOption.value > 0.0) {
                unlabeledIndices = new Randoms(r.nextInt()).nextBitSet(ilists[0].size(), Vectors2Classify.unlabeledProportionOption.value);
            }
            long[] time = new long[trainers.length];
            for (int c = 0; c < trainers.length; ++c) {
                String label;
                time[c] = System.currentTimeMillis();
                System.out.println("Trial " + trialIndex + " Training " + trainers[c].toString() + " with " + ilists[0].size() + " instances");
                if (Vectors2Classify.unlabeledProportionOption.value > 0.0) {
                    ilists[0].hideSomeLabels(unlabeledIndices);
                }
                trainers[c].setValidationInstances(ilists[2]);
                Object classifier = trainers[c].train(ilists[0]);
                if (Vectors2Classify.unlabeledProportionOption.value > 0.0) {
                    ilists[0].unhideAllLabels();
                }
                System.out.println("Trial " + trialIndex + " Training " + trainers[c].toString() + " finished");
                time[c] = System.currentTimeMillis() - time[c];
                Trial trainTrial = new Trial((Classifier)classifier, ilists[0]);
                Trial testTrial = new Trial((Classifier)classifier, ilists[1]);
                Trial validationTrial = new Trial((Classifier)classifier, ilists[2]);
                if (ReportOptions[0][2] && ilists[0].size() > 0) {
                    trainConfusionMatrix[c][trialIndex] = new ConfusionMatrix(trainTrial).toString();
                }
                if (ReportOptions[1][2] && ilists[1].size() > 0) {
                    testConfusionMatrix[c][trialIndex] = new ConfusionMatrix(testTrial).toString();
                }
                if (ReportOptions[2][2] && ilists[2].size() > 0) {
                    validationConfusionMatrix[c][trialIndex] = new ConfusionMatrix(validationTrial).toString();
                }
                if (ReportOptions[0][0]) {
                    trainAccuracy[c][trialIndex] = trainTrial.getAccuracy();
                }
                if (ReportOptions[1][0]) {
                    testAccuracy[c][trialIndex] = testTrial.getAccuracy();
                }
                if (ReportOptions[2][0]) {
                    validationAccuracy[c][trialIndex] = validationTrial.getAccuracy();
                }
                if (outputFile.wasInvoked()) {
                    String filename = Vectors2Classify.outputFile.value;
                    if (trainers.length > 1) {
                        filename = filename + trainers[c].toString();
                    }
                    if (numTrials > 1) {
                        filename = filename + ".trial" + trialIndex;
                    }
                    try {
                        ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(filename));
                        oos.writeObject(classifier);
                        oos.close();
                    }
                    catch (Exception e) {
                        e.printStackTrace();
                        throw new IllegalArgumentException("Couldn't write classifier to filename " + filename);
                    }
                }
                if (ReportOptions[0][3]) {
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
                    System.out.println(" Raw Training Data");
                    Vectors2Classify.printTrialClassification(trainTrial);
                }
                if (ReportOptions[1][3]) {
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
                    System.out.println(" Raw Testing Data");
                    Vectors2Classify.printTrialClassification(testTrial);
                }
                if (ReportOptions[2][3]) {
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString());
                    System.out.println(" Raw Validation Data");
                    Vectors2Classify.printTrialClassification(validationTrial);
                }
                if (ReportOptions[0][2]) {
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Training Data Confusion Matrix");
                    if (ilists[0].size() > 0) {
                        System.out.println(trainConfusionMatrix[c][trialIndex]);
                    }
                }
                if (ReportOptions[0][0]) {
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data accuracy= " + trainAccuracy[c][trialIndex]);
                }
                if (ReportOptions[0][1]) {
                    label = ReportOptionArgs[0][1];
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " training data F1(" + label + ") = " + trainTrial.getF1(label));
                }
                if (ReportOptions[2][2]) {
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Validation Data Confusion Matrix");
                    if (ilists[2].size() > 0) {
                        System.out.println(validationConfusionMatrix[c][trialIndex]);
                    }
                }
                if (ReportOptions[2][0]) {
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data accuracy= " + validationAccuracy[c][trialIndex]);
                }
                if (ReportOptions[2][1]) {
                    label = ReportOptionArgs[2][1];
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " validation data F1(" + label + ") = " + validationTrial.getF1(label));
                }
                if (ReportOptions[1][2]) {
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " Test Data Confusion Matrix");
                    if (ilists[1].size() > 0) {
                        System.out.println(testConfusionMatrix[c][trialIndex]);
                    }
                }
                if (ReportOptions[1][0]) {
                    System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data accuracy= " + testAccuracy[c][trialIndex]);
                }
                if (!ReportOptions[1][1]) continue;
                label = ReportOptionArgs[1][1];
                System.out.println("Trial " + trialIndex + " Trainer " + trainers[c].toString() + " test data F1(" + label + ") = " + testTrial.getF1(label));
            }
        }
        for (int c = 0; c < trainers.length; ++c) {
            System.out.println("\n" + trainers[c].toString());
            if (ReportOptions[0][0]) {
                System.out.println("Summary. train accuracy mean = " + MatrixOps.mean(trainAccuracy[c]) + " stddev = " + MatrixOps.stddev(trainAccuracy[c]) + " stderr = " + MatrixOps.stderr(trainAccuracy[c]));
            }
            if (ReportOptions[2][0]) {
                System.out.println("Summary. validation accuracy mean = " + MatrixOps.mean(validationAccuracy[c]) + " stddev = " + MatrixOps.stddev(validationAccuracy[c]) + " stderr = " + MatrixOps.stderr(validationAccuracy[c]));
            }
            if (!ReportOptions[1][0]) continue;
            System.out.println("Summary. test accuracy mean = " + MatrixOps.mean(testAccuracy[c]) + " stddev = " + MatrixOps.stddev(testAccuracy[c]) + " stderr = " + MatrixOps.stderr(testAccuracy[c]));
        }
    }

    private static void printTrialClassification(Trial trial) {
        for (Classification c : trial) {
            Instance instance = c.getInstance();
            System.out.print(instance.getName() + " " + instance.getTarget() + " ");
            Labeling labeling = c.getLabeling();
            for (int j = 0; j < labeling.numLocations(); ++j) {
                System.out.print(labeling.getLabelAtRank(j).toString() + ":" + labeling.getValueAtRank(j) + " ");
            }
            System.out.println();
        }
    }

    private static class ReportOption {
        static final String[] dataOptions = new String[]{"train", "test", "validation"};
        static final String[] reportOptions = new String[]{"accuracy", "f1", "confusion", "raw"};
        static final int train = 0;
        static final int test = 1;
        static final int validation = 2;
        static final int accuracy = 0;
        static final int f1 = 1;
        static final int confusion = 2;
        static final int raw = 3;

        private ReportOption() {
        }
    }
}

