cc.mallet.types
Class CrossValidationIterator

java.lang.Object
  extended by cc.mallet.types.CrossValidationIterator
All Implemented Interfaces:
java.io.Serializable, java.util.Iterator<InstanceList[]>

public class CrossValidationIterator
extends java.lang.Object
implements java.util.Iterator<InstanceList[]>, java.io.Serializable

An iterator that splits an InstanceList into n folds and iterates over them for use in n-fold cross-validation. On each iteration, list[0] contains an InstanceList with n-1 folds, typically used for training, and list[1] contains an InstanceList with 1 fold, typically used for validation. This class uses MultiInstanceList to avoid creating a new InstanceList on each iteration. TODO: the distribution is currently completely random; an improvement would be to provide a stratified random distribution.
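
The splitting scheme described above can be illustrated without MALLET. The following is a minimal sketch in plain Java, not this class's implementation: FoldSketch, makeFolds and split are hypothetical names, the round-robin dealing of shuffled items is an assumption, and the sketch copies elements into new lists, whereas this class uses MultiInstanceList to avoid that.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical stand-in for illustration only; not part of MALLET. */
class FoldSketch {

    /** Shuffles a copy of the items and deals them round-robin into n folds. */
    public static <T> List<List<T>> makeFolds(List<T> items, int n, Random r) {
        List<T> shuffled = new ArrayList<>(items);
        Collections.shuffle(shuffled, r);
        List<List<T>> folds = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            folds.add(new ArrayList<>());
        }
        for (int i = 0; i < shuffled.size(); i++) {
            folds.get(i % n).add(shuffled.get(i));
        }
        return folds;
    }

    /** Returns [training, testing], where the fold at index held is the testing part. */
    public static <T> List<List<T>> split(List<List<T>> folds, int held) {
        List<T> training = new ArrayList<>();
        for (int i = 0; i < folds.size(); i++) {
            if (i != held) {
                training.addAll(folds.get(i));
            }
        }
        List<List<T>> result = new ArrayList<>();
        result.add(training);                          // n-1 folds
        result.add(new ArrayList<>(folds.get(held)));  // 1 fold
        return result;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            data.add(i);
        }
        List<List<Integer>> folds = makeFolds(data, 5, new Random(1));
        // One iteration per fold: each fold is held out exactly once.
        for (int held = 0; held < folds.size(); held++) {
            List<List<Integer>> pair = split(folds, held);
            System.out.println("train=" + pair.get(0) + " test=" + pair.get(1));
        }
    }
}
```

With 10 items and 5 folds, every iteration of the sketch yields 8 training items and 2 testing items, and each item appears in the testing part exactly once across the 5 iterations.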

Author:
Aron Culotta culotta@cs.umass.edu
See Also:
MultiInstanceList, InstanceList, Serialized Form

Constructor Summary
CrossValidationIterator(InstanceList ilist, int _nfolds)
          Constructs a new n-fold cross-validation iterator
CrossValidationIterator(InstanceList ilist, int nfolds, java.util.Random r)
          Constructs a new n-fold cross-validation iterator
 
Method Summary
 void clear()
          Calls clear on each fold.
 boolean hasNext()
           
 InstanceList[] next()
          Returns the next training/testing split.
 InstanceList[] nextSplit()
          Returns the next training/testing split.
 InstanceList[] nextSplit(int numTrainFolds)
          Returns the next training/testing split.
 void remove()
           
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

CrossValidationIterator

public CrossValidationIterator(InstanceList ilist,
                               int nfolds,
                               java.util.Random r)
Constructs a new n-fold cross-validation iterator

Parameters:
ilist - instance list to split into folds and iterate over
nfolds - number of folds to split InstanceList into
r - The source of randomness to use in shuffling.
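
Supplying an explicitly seeded java.util.Random is the usual way to make a shuffle repeatable between runs. The sketch below shows this with plain Java collections only; SeededShuffleSketch and shuffled are hypothetical names. Whether this constructor yields identical folds for identical seeds also depends on how it consumes r, which this page does not specify.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical illustration only; not part of MALLET. */
class SeededShuffleSketch {

    /** Returns a shuffled copy of items, using a Random created from the given seed. */
    public static List<String> shuffled(List<String> items, long seed) {
        List<String> copy = new ArrayList<>(items);
        Collections.shuffle(copy, new Random(seed));
        return copy;
    }

    public static void main(String[] args) {
        List<String> items = Arrays.asList("a", "b", "c", "d", "e", "f");
        // Two generators built from the same seed drive the shuffle identically.
        List<String> first = shuffled(items, 42L);
        List<String> second = shuffled(items, 42L);
        System.out.println(first.equals(second)); // prints "true"
    }
}
```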

CrossValidationIterator

public CrossValidationIterator(InstanceList ilist,
                               int _nfolds)
Constructs a new n-fold cross-validation iterator

Parameters:
ilist - instance list to split into folds and iterate over
_nfolds - number of folds to split InstanceList into
Method Detail

clear

public void clear()
Calls clear on each fold. It is recommended that this always be called when the iterator is no longer needed, so that implementations of InstanceList such as PagedInstanceList can clean up any temporary data they may have outside the JVM.


hasNext

public boolean hasNext()
Specified by:
hasNext in interface java.util.Iterator<InstanceList[]>

nextSplit

public InstanceList[] nextSplit()
Returns the next training/testing split.

Returns:
A two-element array of InstanceList, where InstanceList[0] contains n-1 folds for training and InstanceList[1] contains 1 fold for testing.

nextSplit

public InstanceList[] nextSplit(int numTrainFolds)
Returns the next training/testing split.

Returns:
A two-element array of InstanceList, where InstanceList[0] contains numTrainFolds folds for training and InstanceList[1] contains n - numTrainFolds folds for testing.
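
To make the fold arithmetic concrete, the sketch below divides n fold indices into numTrainFolds training folds and n - numTrainFolds testing folds. It is an illustration under assumptions, not this method's implementation: TrainFoldsSketch, splitIndices and the start parameter are hypothetical, and taking the training folds as a contiguous window that wraps around is an assumed rule, so the actual selection of folds may differ.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical illustration only; not part of MALLET. */
class TrainFoldsSketch {

    /**
     * Returns [trainingFoldIndices, testingFoldIndices] for n folds.
     * Assumption: training folds form a window of numTrainFolds folds
     * beginning at start and wrapping past the last fold.
     */
    public static List<List<Integer>> splitIndices(int n, int start, int numTrainFolds) {
        if (numTrainFolds < 0 || numTrainFolds > n) {
            throw new IllegalArgumentException("numTrainFolds must be between 0 and n");
        }
        List<Integer> train = new ArrayList<>();
        List<Integer> test = new ArrayList<>();
        for (int k = 0; k < n; k++) {
            int fold = (start + k) % n;
            if (k < numTrainFolds) {
                train.add(fold);
            } else {
                test.add(fold);
            }
        }
        return Arrays.asList(train, test);
    }

    public static void main(String[] args) {
        // 5 folds, 2 used for training, window starting at fold 4.
        List<List<Integer>> pair = splitIndices(5, 4, 2);
        System.out.println("train folds=" + pair.get(0)); // prints "train folds=[4, 0]"
        System.out.println("test folds=" + pair.get(1));  // prints "test folds=[1, 2, 3]"
    }
}
```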

next

public InstanceList[] next()
Returns the next training/testing split.

Specified by:
next in interface java.util.Iterator<InstanceList[]>
Returns:
A two-element array of InstanceList, where InstanceList[0] contains n-1 folds for training and InstanceList[1] contains 1 fold for testing.
See Also:
Iterator.next()

remove

public void remove()
Specified by:
remove in interface java.util.Iterator<InstanceList[]>