cc.mallet.fst.semi_supervised
Class GELattice
java.lang.Object
cc.mallet.fst.semi_supervised.GELattice
public class GELattice
- extends java.lang.Object
Runs the dynamic programming algorithm of [Mann and McCallum 08] for
computing the gradient of a Generalized Expectation constraint that
considers a single label of a linear-chain CRF.
See:
"Generalized Expectation Criteria for Semi-Supervised Learning of Conditional Random Fields"
Gideon Mann and Andrew McCallum
ACL 2008
gdruck NOTE: This new version of GELattice computes the gradient
for all constraints simultaneously.
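For orientation, a single-label GE constraint in the sense of [Mann and McCallum 08] compares a target label distribution for an input feature against the model's predicted label distribution, averaged over the positions where that feature fires. The sketch below computes only that model-side quantity for one sequence, from per-position marginals laid out like the gammas constructor argument (gammas[ip][y]). It does not implement the lattice gradient computation. The class and method names are hypothetical, states stand in for labels, and the marginals are assumed to be probabilities (exponentiate first if they are stored in log space).

```java
// Hypothetical helper, not part of MALLET: the model's label distribution for one
// input feature, averaged over the positions of a single sequence where it fires.
final class LabelExpectation {
    /**
     * @param gammas gammas[ip][y] = P(y_ip = y | x), as probabilities (assumption)
     * @param fires  fires[ip] is true if the constraint's input feature is active at ip
     * @return p[y] = (sum over firing ip of gammas[ip][y]) / (number of firing ip),
     *         or null if the feature never fires
     */
    static double[] labelDistribution(double[][] gammas, boolean[] fires) {
        int numLabels = gammas[0].length;
        double[] p = new double[numLabels];
        int count = 0;
        for (int ip = 0; ip < gammas.length; ip++) {
            if (!fires[ip]) {
                continue; // only positions where the feature fires contribute
            }
            count++;
            for (int y = 0; y < numLabels; y++) {
                p[y] += gammas[ip][y];
            }
        }
        if (count == 0) {
            return null; // expectation undefined when the feature never fires
        }
        for (int y = 0; y < numLabels; y++) {
            p[y] /= count;
        }
        return p;
    }
}
```

Across a corpus, the per-label sums and the firing count would be accumulated over all sequences before dividing.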
- Author:
- Gregory Druck, Gaurav Chandalia, Gideon Mann
Nested Class Summary
protected class GELattice.LatticeNode
Contains forward-backward vectors corresponding to an input position and a
state index.
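The internal layout of LatticeNode is not documented on this page. As a point of reference, the standard log-space forward-backward recursion over a linear chain, which yields per-position, per-state vectors of the kind the description mentions (and the single-state marginals passed in as gammas), can be sketched as follows. The inputs logNode[ip][s] and logTrans[s1][s2], the class name, and the unweighted start are assumptions for illustration. The GE lattice itself runs further recursions on top of such marginals, which this sketch does not reproduce.

```java
// Hypothetical log-space forward-backward for a linear chain (not MALLET's API).
final class ForwardBackward {
    final double[][] alpha; // alpha[ip][s]: log-sum of scores of prefixes ending in s at ip
    final double[][] beta;  // beta[ip][s]: log-sum of scores of suffixes after ip, given s at ip
    final double logZ;      // log partition function

    /** logNode[ip][s]: per-position state score; logTrans[s1][s2]: transition score. */
    ForwardBackward(double[][] logNode, double[][] logTrans) {
        int n = logNode.length;
        int k = logNode[0].length;
        alpha = new double[n][k];
        beta = new double[n][k]; // Java zero-initializes, so beta[n-1][s] = log 1 = 0
        for (int s = 0; s < k; s++) {
            alpha[0][s] = logNode[0][s];
        }
        // Forward pass: sum over all predecessor states s1 of each state s2.
        for (int ip = 1; ip < n; ip++) {
            for (int s2 = 0; s2 < k; s2++) {
                double[] terms = new double[k];
                for (int s1 = 0; s1 < k; s1++) {
                    terms[s1] = alpha[ip - 1][s1] + logTrans[s1][s2];
                }
                alpha[ip][s2] = logSumExp(terms) + logNode[ip][s2];
            }
        }
        // Backward pass: sum over all successor states s2 of each state s1.
        for (int ip = n - 2; ip >= 0; ip--) {
            for (int s1 = 0; s1 < k; s1++) {
                double[] terms = new double[k];
                for (int s2 = 0; s2 < k; s2++) {
                    terms[s2] = logTrans[s1][s2] + logNode[ip + 1][s2] + beta[ip + 1][s2];
                }
                beta[ip][s1] = logSumExp(terms);
            }
        }
        logZ = logSumExp(alpha[n - 1]);
    }

    /** Marginal probability of state s at position ip. */
    double gamma(int ip, int s) {
        return Math.exp(alpha[ip][s] + beta[ip][s] - logZ);
    }

    /** Numerically stable log(sum_i exp(xs[i])). */
    static double logSumExp(double[] xs) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : xs) {
            max = Math.max(max, x);
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return max;
        }
        double sum = 0.0;
        for (double x : xs) {
            sum += Math.exp(x - max);
        }
        return max + Math.log(sum);
    }
}
```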
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
latticeLength
protected int latticeLength
transducer
protected Transducer transducer
numStates
protected int numStates
lattice
protected GELattice.LatticeNode[][] lattice
dotCache
protected LogNumber[][][] dotCache
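The dotCache field stores LogNumber values rather than plain doubles. Terms in a GE gradient can be negative (the gradient involves covariances), so a log-space representation plausibly has to carry a sign alongside the log magnitude. The sketch below shows one way such a signed log-space number can work. It is a hypothetical illustration: SignedLog and its methods are not MALLET's LogNumber API.

```java
/**
 * Hypothetical signed log-space number (not MALLET's LogNumber API):
 * represents the value (positive ? +1 : -1) * exp(logVal).
 */
final class SignedLog {
    final boolean positive; // sign of the value
    final double logVal;    // log of the magnitude; NEGATIVE_INFINITY encodes zero

    SignedLog(boolean positive, double logVal) {
        this.positive = positive;
        this.logVal = logVal;
    }

    static SignedLog of(double v) {
        return new SignedLog(v >= 0, Math.log(Math.abs(v))); // Math.log(0) = -Infinity
    }

    double value() {
        return (positive ? 1.0 : -1.0) * Math.exp(logVal);
    }

    /** Product: signs combine, log magnitudes add. */
    SignedLog times(SignedLog o) {
        return new SignedLog(positive == o.positive, logVal + o.logVal);
    }

    /** Sum computed without leaving log space, including opposite signs. */
    SignedLog plus(SignedLog o) {
        if (logVal == Double.NEGATIVE_INFINITY) return o;
        if (o.logVal == Double.NEGATIVE_INFINITY) return this;
        boolean thisLarger = logVal >= o.logVal;
        double hi = thisLarger ? logVal : o.logVal;
        double lo = thisLarger ? o.logVal : logVal;
        boolean hiPositive = thisLarger ? positive : o.positive;
        if (positive == o.positive) {
            // |a| + |b| = exp(hi) * (1 + exp(lo - hi))
            return new SignedLog(positive, hi + Math.log1p(Math.exp(lo - hi)));
        }
        if (lo == hi) {
            return new SignedLog(true, Double.NEGATIVE_INFINITY); // exact cancellation
        }
        // |big| - |small| = exp(hi) * (1 - exp(lo - hi)); sign follows the larger magnitude
        return new SignedLog(hiPositive, hi + Math.log1p(-Math.exp(lo - hi)));
    }
}
```

Subtraction of nearly equal magnitudes is where a sign-aware representation matters: an unsigned log-space sum cannot express 3 + (-5) at all.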
GELattice
public GELattice(FeatureVectorSequence fvs,
double[][] gammas,
double[][][] xis,
Transducer transducer,
int[][] reverseTrans,
int[][] reverseTransIndices,
CRF.Factors gradient,
java.util.ArrayList<GEConstraint> constraints,
boolean check)
- Parameters:
fvs - Input FeatureVectorSequence
gammas - Marginals over single states
xis - Marginals over pairs of states
transducer - Transducer
reverseTrans - Source state indices for each destination state
reverseTransIndices - Transition indices for each destination state
gradient - Gradient to increment
constraints - List of constraints
check - Whether to run the debugging test to verify correctness (will be much slower if true)
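The reverseTrans and reverseTransIndices arguments index transitions by their destination state, which presumably lets recursions that sum over incoming transitions visit the predecessors of each state directly. A minimal sketch of building such arrays, assuming each source state's outgoing transitions are stored as a list dest[s1] and that a "transition index" is a position in that list (both are assumptions; the class name is hypothetical and MALLET's Transducer API is not used):

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical builder for destination-indexed transition tables.
final class ReverseTransitions {
    final int[][] reverseTrans;        // reverseTrans[s2][j]: j-th source state with an edge into s2
    final int[][] reverseTransIndices; // reverseTransIndices[s2][j]: that edge's index in its source's list

    /** dest[s1][t] = destination state of the t-th outgoing transition of s1 (assumed layout). */
    ReverseTransitions(int[][] dest) {
        int numStates = dest.length;
        List<List<int[]>> incoming = new ArrayList<>();
        for (int s = 0; s < numStates; s++) {
            incoming.add(new ArrayList<>());
        }
        // Record each edge (s1, t) under its destination state.
        for (int s1 = 0; s1 < numStates; s1++) {
            for (int t = 0; t < dest[s1].length; t++) {
                incoming.get(dest[s1][t]).add(new int[] {s1, t});
            }
        }
        reverseTrans = new int[numStates][];
        reverseTransIndices = new int[numStates][];
        for (int s2 = 0; s2 < numStates; s2++) {
            List<int[]> in = incoming.get(s2);
            reverseTrans[s2] = new int[in.size()];
            reverseTransIndices[s2] = new int[in.size()];
            for (int j = 0; j < in.size(); j++) {
                reverseTrans[s2][j] = in.get(j)[0];
                reverseTransIndices[s2][j] = in.get(j)[1];
            }
        }
    }
}
```

With this layout, dest[reverseTrans[s2][j]][reverseTransIndices[s2][j]] == s2 holds for every j.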
check
public void check(java.util.ArrayList<GEConstraint> constraints,
double[][] gammas,
double[][][] xis,
FeatureVectorSequence fvs)
- Verifies the correctness of the lattice computations.
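This page does not say how check performs its verification. A common, generic way to validate an analytic gradient is a central finite-difference comparison; it is shown here as a swapped-in technique, not as what check does. The class and method names are hypothetical.

```java
import java.util.function.ToDoubleFunction;

// Generic central finite-difference gradient check (not GELattice.check).
final class GradientCheck {
    /**
     * Returns the largest absolute difference, over all coordinates i, between
     * analyticGrad[i] and the estimate (f(x + h*e_i) - f(x - h*e_i)) / (2h).
     */
    static double maxAbsError(ToDoubleFunction<double[]> f, double[] analyticGrad,
                              double[] x, double h) {
        double worst = 0.0;
        for (int i = 0; i < x.length; i++) {
            double[] plus = x.clone();
            double[] minus = x.clone();
            plus[i] += h;
            minus[i] -= h;
            double numeric = (f.applyAsDouble(plus) - f.applyAsDouble(minus)) / (2.0 * h);
            worst = Math.max(worst, Math.abs(numeric - analyticGrad[i]));
        }
        return worst;
    }
}
```

Each coordinate costs two extra objective evaluations, which is one reason debugging checks of this kind run much slower than the analytic computation.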
getAlpha
public LogNumber getAlpha(int ip,
int s1,
int s2)
getBeta
public LogNumber getBeta(int ip,
int s1,
int s2)