/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify.tests;

import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.DecisionTreeTrainer;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.NaiveBayesTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.pipe.iterator.RandomTokenSequenceIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import java.util.Random;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * Smoke tests that train classifiers on synthetic, seeded data and print their accuracies.
 *
 * <p>NOTE(review): none of the test methods contain assertions, so they fail only if training or
 * evaluation throws; the printed accuracies must be inspected manually.
 */
public class TestClassifiers extends TestCase {

    /**
     * Creates a named test case, as required by JUnit 3's reflective {@link TestSuite}.
     *
     * @param name name of the test method to run
     */
    public TestClassifiers(String name) {
        super(name);
    }

    /**
     * Builds a new alphabet by looking up {@code "feature0"} .. {@code "feature<size-1>"} in order.
     *
     * @param size number of entries to look up
     * @return the newly created alphabet
     */
    private static Alphabet dictOfSize(int size) {
        Alphabet ret = new Alphabet();
        for (int i = 0; i < size; ++i) {
            ret.lookupIndex("feature" + i);
        }
        return ret;
    }

    /**
     * Prints {@code header}, then one {@code "<classifier class name>: <accuracy>"} line per
     * classifier, where the accuracy comes from a {@link Trial} of that classifier on {@code data}.
     *
     * @param header      line printed before the per-classifier results
     * @param classifiers classifiers to evaluate, reported in array order
     * @param data        instances to evaluate each classifier on
     */
    private static void printAccuracies(String header, Classifier[] classifiers, InstanceList data) {
        System.out.println(header);
        for (Classifier classifier : classifiers) {
            System.out.println(classifier.getClass().getName() + ": " + new Trial(classifier, data).getAccuracy());
        }
    }

    /**
     * Trains Naive Bayes, MaxEnt and decision-tree classifiers on one part of a seeded synthetic
     * instance list, then prints their accuracy on the training part and on the held-out part.
     */
    public void testRandomTrained() {
        ClassifierTrainer[] trainers = new ClassifierTrainer[]{new NaiveBayesTrainer(), new MaxEntTrainer(), new DecisionTreeTrainer()};
        Alphabet fd = TestClassifiers.dictOfSize(3);
        String[] classNames = new String[]{"class0", "class1", "class2"};
        // Fixed seeds (Randoms(1), Random(2L)) keep the generated data and the split reproducible.
        InstanceList ilist = new InstanceList(new Randoms(1), fd, classNames, 200);
        // Presumably proportions: two halves, lists[0] for training and lists[1] for testing.
        InstanceList[] lists = ilist.split(new Random(2L), new double[]{0.5, 0.5});

        Classifier[] classifiers = new Classifier[trainers.length];
        for (int i = 0; i < trainers.length; ++i) {
            classifiers[i] = trainers[i].train(lists[0]);
        }

        printAccuracies("Accuracy on training set:", classifiers, lists[0]);
        printAccuracies("Accuracy on testing set:", classifiers, lists[1]);
    }

    /**
     * Generates training data while the alphabet holds 3 entries, grows the alphabet to 25 entries
     * before training MaxEnt, then evaluates on data generated over the grown alphabet.
     *
     * <p>NOTE(review): the name starts with {@code "tets"}, not {@code "test"}, so JUnit 3's
     * {@link TestSuite} does not pick this method up and it never runs via {@link #suite()}. The name
     * is kept as-is because this is a public method and enabling it would change which tests run;
     * rename it to {@code testNewFeatures} only after confirming it passes.
     */
    public void tetsNewFeatures() {
        ClassifierTrainer[] trainers = new ClassifierTrainer[]{new MaxEntTrainer()};
        Alphabet fd = TestClassifiers.dictOfSize(3);
        String[] classNames = new String[]{"class0", "class1", "class2"};
        Randoms r = new Randoms(1);
        InstanceList training = new InstanceList(r, fd, classNames, 50);
        // Grow the shared alphabet after the training data already exists; presumably so the
        // testing data below can contain features never seen during data generation for training.
        this.expandDict(fd, 25);

        Classifier[] classifiers = new Classifier[trainers.length];
        for (int i = 0; i < trainers.length; ++i) {
            classifiers[i] = trainers[i].train(training);
        }
        printAccuracies("Accuracy on training set:", classifiers, training);

        // Reuse the training pipe; r continues the same seeded sequence used for the training data.
        InstanceList testing = new InstanceList(training.getPipe());
        RandomTokenSequenceIterator iter = new RandomTokenSequenceIterator(r, new Dirichlet(fd, 2.0), 30.0, 0.0, 10.0, 50.0, classNames);
        testing.addThruPipe(iter);
        for (int i = 0; i < testing.size(); ++i) {
            Instance inst = (Instance)testing.get(i);
            System.out.println("DATA:" + inst.getData());
        }
        printAccuracies("Accuracy on testing set:", classifiers, testing);
    }

    /**
     * Re-enables growth on {@code fd} and looks up {@code "feature0"} .. {@code "feature<size-1>"},
     * passing {@code true} as the add-if-not-present flag, so missing entries are added.
     * This method does not call {@code stopGrowth()}, so growth stays enabled afterwards.
     *
     * @param fd   alphabet to extend (mutated in place)
     * @param size number of {@code "feature<i>"} entries to ensure
     */
    private void expandDict(Alphabet fd, int size) {
        fd.startGrowth();
        for (int i = 0; i < size; ++i) {
            fd.lookupIndex("feature" + i, true);
        }
    }

    /**
     * Builds the JUnit 3 suite from this class's public {@code test*} methods.
     *
     * @return the suite to run
     */
    public static Test suite() {
        return new TestSuite(TestClassifiers.class);
    }

    /** No shared fixture: each test builds its own data. */
    @Override
    protected void setUp() {
    }

    /**
     * Runs the suite with JUnit's text runner.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        TestRunner.run(TestClassifiers.suite());
    }
}

