/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.cluster.neighbor_evaluator;

import cc.mallet.classify.Classifier;
import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor;
import cc.mallet.cluster.neighbor_evaluator.ClassifyingNeighborEvaluator;
import cc.mallet.cluster.neighbor_evaluator.Neighbor;
import cc.mallet.cluster.util.PairwiseMatrix;
import cc.mallet.types.MatrixOps;

/**
 * Scores a candidate agglomerative merge by comparing the two clusters being merged
 * through their medoids.
 * <p>
 * For each of the two old clusters, a medoid is chosen: the member with the highest
 * average pairwise classifier score against the other members of its cluster.
 * <ul>
 *   <li>In single-link mode the merge score is the classifier score of the pair
 *       (medoid of cluster A, medoid of cluster B).</li>
 *   <li>Otherwise the merge score is a weighted average of pairwise classifier scores.
 *       Each instance is weighted by its score against its own cluster's medoid; the
 *       medoid itself has weight 1.</li>
 * </ul>
 * Pairwise scores are memoised in {@link #scoreCache}, which {@link #reset()} discards.
 */
public class MedoidEvaluator
extends ClassifyingNeighborEvaluator {
    private static final long serialVersionUID = 1L;
    /** If true, a merge is scored using only the two medoids. */
    boolean singleLink = false;
    /** NOTE(review): not read anywhere in this class; kept for compatibility. */
    CombiningStrategy combiningStrategy;
    /**
     * If true, the weighted average also includes pairs inside each old cluster
     * (i.e. all pairs of the merged cluster); otherwise only cross-cluster pairs are used.
     */
    boolean mergeFirst = true;
    /** Lazily created cache of pairwise classifier scores; cleared by {@link #reset()}. */
    PairwiseMatrix scoreCache;

    public MedoidEvaluator(Classifier classifier, String scoringLabel) {
        super(classifier, scoringLabel);
        System.out.println("Using Medoid Evaluator");
    }

    public MedoidEvaluator(Classifier classifier, String scoringLabel, boolean singleLink, boolean mergeFirst) {
        super(classifier, scoringLabel);
        this.singleLink = singleLink;
        this.mergeFirst = mergeFirst;
        System.out.println("Using Medoid Evaluator. Single link=" + singleLink + ".");
    }

    /**
     * Scores each neighbor independently.
     *
     * @param neighbors candidate merges; each must be an {@link AgglomerativeNeighbor}
     * @return one score per neighbor, in the same order
     */
    @Override
    public double[] evaluate(Neighbor[] neighbors) {
        double[] scores = new double[neighbors.length];
        for (int i = 0; i < neighbors.length; ++i) {
            scores[i] = this.evaluate(neighbors[i]);
        }
        return scores;
    }

    /**
     * Scores a single candidate merge of two clusters.
     *
     * @param neighbor the candidate merge; must be an {@link AgglomerativeNeighbor}
     * @return the medoid-pair score in single-link mode, otherwise the medoid-weighted
     *         average of pairwise classifier scores
     * @throws IllegalArgumentException if {@code neighbor} is not an AgglomerativeNeighbor
     */
    @Override
    public double evaluate(Neighbor neighbor) {
        if (!(neighbor instanceof AgglomerativeNeighbor)) {
            throw new IllegalArgumentException("Expect AgglomerativeNeighbor not " + neighbor.getClass().getName());
        }
        int[][] oldIndices = ((AgglomerativeNeighbor)neighbor).getOldClusters();
        int[] clusterA = oldIndices[0];
        int[] clusterB = oldIndices[1];
        Clustering original = neighbor.getOriginal();
        // Positions (within clusterA / clusterB) of each cluster's medoid.
        int medoidA = this.getCentroid(clusterA, original);
        int medoidB = this.getCentroid(clusterB, original);
        if (this.singleLink) {
            return this.pairScore(original, clusterA[medoidA], clusterB[medoidB]);
        }
        double[] medsA = this.getMedWeights(medoidA, clusterA, original);
        double[] medsB = this.getMedWeights(medoidB, clusterB, original);
        double numerator = 0.0;
        double denominator = 0.0;
        for (int i = 0; i < clusterA.length; ++i) {
            // Cross-cluster pairs (A_i, B_j).
            for (int j = 0; j < clusterB.length; ++j) {
                double interScore = this.pairScore(original, clusterA[i], clusterB[j]);
                numerator += interScore * medsA[i] * medsB[j];
                denominator += medsA[i] * medsB[j];
            }
            if (!this.mergeFirst) {
                continue;
            }
            // Within-A pairs (A_i, A_j), each unordered pair counted once.
            for (int j = i + 1; j < clusterA.length; ++j) {
                double interScore = this.pairScore(original, clusterA[i], clusterA[j]);
                numerator += interScore * medsA[i] * medsA[j];
                denominator += medsA[i] * medsA[j];
            }
        }
        if (this.mergeFirst) {
            // Within-B pairs (B_i, B_j), each unordered pair counted once.
            for (int i = 0; i < clusterB.length; ++i) {
                for (int j = i + 1; j < clusterB.length; ++j) {
                    double interScore = this.pairScore(original, clusterB[i], clusterB[j]);
                    numerator += interScore * medsB[i] * medsB[j];
                    denominator += medsB[i] * medsB[j];
                }
            }
        }
        return numerator / denominator;
    }

    /**
     * Computes per-member weights for a cluster relative to its medoid.
     *
     * @param medIdx  position of the medoid within {@code indices}
     * @param indices instance indices of the cluster's members
     * @param original the clustering the instances belong to
     * @return weights parallel to {@code indices}: 1.0 for the medoid, otherwise the
     *         classifier score of (medoid, member)
     */
    private double[] getMedWeights(int medIdx, int[] indices, Clustering original) {
        double[] result = new double[indices.length];
        for (int i = 0; i < result.length; ++i) {
            result[i] = (medIdx == i) ? 1.0 : this.pairScore(original, indices[medIdx], indices[i]);
        }
        return result;
    }

    /**
     * Finds the medoid of a cluster: the member with the highest average classifier
     * score against every other member.
     *
     * @param indices instance indices of the cluster's members
     * @param original the clustering the instances belong to
     * @return the position within {@code indices} of the medoid (0 for clusters with
     *         fewer than two members, or if no average score is comparable)
     */
    private int getCentroid(int[] indices, Clustering original) {
        if (indices.length < 2) {
            return 0;
        }
        double[] scores = new double[indices.length];
        for (int i = 0; i < indices.length; ++i) {
            double acc = 0.0;
            for (int k = 0; k < indices.length; ++k) {
                // Skip the self-pair but keep scanning the remaining members; the
                // previous loop condition (i != k) stopped the scan at k == i instead.
                if (k == i) {
                    continue;
                }
                acc += this.pairScore(original, indices[i], indices[k]);
            }
            scores[i] = acc / (double)(indices.length - 1);
        }
        double centDist = Double.NEGATIVE_INFINITY;
        // Default to the first member so that all-NaN scores cannot yield an invalid index.
        int centIdx = 0;
        for (int i = 0; i < scores.length; ++i) {
            if (scores[i] > centDist) {
                centDist = scores[i];
                centIdx = i;
            }
        }
        return centIdx;
    }

    /** Discards all cached pairwise scores. */
    @Override
    public void reset() {
        this.scoreCache = null;
    }

    @Override
    public String toString() {
        return "class=" + this.getClass().getName() + " classifier=" + this.classifier.getClass().getName();
    }

    /**
     * Classifier score for the pair of instances {@code a} and {@code b}, via the cache.
     */
    private double pairScore(Clustering original, int a, int b) {
        return this.getScore(new AgglomerativeNeighbor(original, original, a, b));
    }

    /**
     * Returns the classifier's value for {@link #scoringLabel} on a two-instance
     * neighbor, memoising it in {@link #scoreCache}.
     * <p>
     * NOTE(review): 0.0 doubles as the "not yet computed" sentinel, so a pair whose true
     * score is exactly 0.0 is re-classified on every call.
     */
    private double getScore(AgglomerativeNeighbor pwneighbor) {
        if (this.scoreCache == null) {
            this.scoreCache = new PairwiseMatrix(pwneighbor.getOriginal().getNumInstances());
        }
        int[] indices = pwneighbor.getNewCluster();
        if (this.scoreCache.get(indices[0], indices[1]) == 0.0) {
            this.scoreCache.set(indices[0], indices[1], this.classifier.classify(pwneighbor).getLabelVector().value(this.scoringLabel));
        }
        return this.scoreCache.get(indices[0], indices[1]);
    }

    /** Combines scores by taking their maximum. */
    public static class Maximum
    implements CombiningStrategy {
        @Override
        public double combine(double[] scores) {
            return MatrixOps.max(scores);
        }
    }

    /** Combines scores by taking their minimum. */
    public static class Minimum
    implements CombiningStrategy {
        @Override
        public double combine(double[] scores) {
            return MatrixOps.min(scores);
        }
    }

    /** Combines scores by taking their mean. */
    public static class Average
    implements CombiningStrategy {
        @Override
        public double combine(double[] scores) {
            return MatrixOps.mean(scores);
        }
    }

    /** Strategy for reducing an array of scores to a single value. */
    public static interface CombiningStrategy {
        public double combine(double[] var1);
    }
}

