/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.types;

import cc.mallet.grmm.types.AbstractFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.ConstantFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.util.Maths;
import cc.mallet.util.Randoms;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.EVD;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.NotConvergedException;
import no.uib.cipr.matrix.Vector;

public class NormalFactor
extends AbstractFactor {
    private Vector mean;
    private Matrix variance;

    public NormalFactor(VarSet vars, Vector mean, Matrix variance) {
        super(vars);
        if (!this.isPosDef(variance)) {
            throw new IllegalArgumentException("Matrix " + variance + " not positive definite.");
        }
        this.mean = mean;
        this.variance = variance;
    }

    private boolean isPosDef(Matrix variance) {
        try {
            EVD evd = EVD.factorize((Matrix)variance);
            double[] vals = evd.getRealEigenvalues();
            return vals[vals.length - 1] > 0.0;
        }
        catch (NotConvergedException e) {
            throw new RuntimeException(e);
        }
    }

    protected Factor extractMaxInternal(VarSet varSet) {
        throw new UnsupportedOperationException();
    }

    public double value(Assignment assn) {
        return 1.0;
    }

    protected double lookupValueInternal(int i) {
        throw new UnsupportedOperationException();
    }

    protected Factor marginalizeInternal(VarSet varsToKeep) {
        throw new UnsupportedOperationException();
    }

    public Factor normalize() {
        return this;
    }

    public Assignment sample(Randoms r) {
        double[] vals = new double[this.mean.size()];
        for (int k = 0; k < vals.length; ++k) {
            vals[k] = r.nextGaussian();
        }
        DenseVector Z = new DenseVector(vals, false);
        DenseVector result = new DenseVector(vals.length);
        this.variance.mult((Vector)Z, (Vector)result);
        result = (DenseVector)result.add(this.mean);
        return new Assignment(this.vars.toVariableArray(), result.getData());
    }

    public boolean almostEquals(Factor p, double epsilon) {
        return this.equals(p);
    }

    public Factor duplicate() {
        return new NormalFactor(this.vars, this.mean, this.variance);
    }

    public boolean isNaN() {
        return false;
    }

    public String dumpToString() {
        return this.toString();
    }

    public String toString() {
        return "[NormalFactor " + this.vars + " " + this.mean + " ... " + this.variance + " ]";
    }

    public Factor slice(Assignment assn) {
        if (assn.varSet().containsAll(this.vars)) {
            return new ConstantFactor(this.value(assn));
        }
        throw new UnsupportedOperationException();
    }

    public void multiplyBy(Factor f) {
        double val;
        if (f instanceof ConstantFactor && Maths.almostEquals(val = f.value(new Assignment()), 1.0)) {
            return;
        }
        throw new UnsupportedOperationException("Can't multiply NormalFactor by " + f);
    }

    public void divideBy(Factor f) {
        double val;
        if (f instanceof ConstantFactor && Maths.almostEquals(val = f.value(new Assignment()), 1.0)) {
            return;
        }
        throw new UnsupportedOperationException("Can't divide NormalFactor by " + f);
    }
}

