/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CacheStaleIndicator;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Level;
import java.util.logging.Logger;

public class ThreadedOptimizable
implements Optimizable.ByGradientValue {
    private static Logger logger = MalletLogger.getLogger(ThreadedOptimizable.class.getName());
    protected InstanceList trainingSet;
    protected Optimizable.ByCombiningBatchGradient optimizable;
    protected double[] batchCachedValue;
    protected List<double[]> batchCachedGradient;
    protected CacheStaleIndicator cacheIndicator;
    private transient Collection<Callable<Double>> valueTasks;
    private transient Collection<Callable<Boolean>> gradientTasks;
    private transient ThreadPoolExecutor executor;
    public static final int SLEEP_TIME = 100;

    public ThreadedOptimizable(Optimizable.ByCombiningBatchGradient optimizable, InstanceList trainingSet, int numFactors, CacheStaleIndicator cacheIndicator) {
        this.trainingSet = trainingSet;
        this.optimizable = optimizable;
        int numBatches = optimizable.getNumBatches();
        assert (numBatches > 0) : "Invalid number of batches: " + numBatches;
        this.batchCachedValue = new double[numBatches];
        this.batchCachedGradient = new ArrayList<double[]>(numBatches);
        for (int i = 0; i < numBatches; ++i) {
            this.batchCachedGradient.add(new double[numFactors]);
        }
        this.cacheIndicator = cacheIndicator;
        logger.info("Creating " + numBatches + " threads for updating gradient...");
        this.executor = (ThreadPoolExecutor)Executors.newFixedThreadPool(numBatches);
        this.createTasks();
    }

    public Optimizable.ByCombiningBatchGradient getOptimizable() {
        return this.optimizable;
    }

    public void shutdown() {
        assert (this.executor.shutdownNow().size() == 0) : "All tasks didn't finish";
    }

    public double getValue() {
        if (this.cacheIndicator.isValueStale()) {
            try {
                List<Future<Double>> results = this.executor.invokeAll(this.valueTasks);
                int batch = 0;
                for (Future<Double> f : results) {
                    try {
                        this.batchCachedValue[batch++] = f.get();
                    }
                    catch (ExecutionException ee) {
                        ee.printStackTrace();
                    }
                }
            }
            catch (InterruptedException ie) {
                ie.printStackTrace();
            }
        }
        double cachedValue = MatrixOps.sum(this.batchCachedValue);
        logger.info("getValue() (loglikelihood, optimizable by label likelihood) =" + cachedValue);
        return cachedValue;
    }

    public void getValueGradient(double[] buffer) {
        if (this.cacheIndicator.isGradientStale()) {
            this.getValue();
            try {
                this.executor.invokeAll(this.gradientTasks);
            }
            catch (InterruptedException ie) {
                ie.printStackTrace();
            }
        }
        this.optimizable.combineGradients(this.batchCachedGradient, buffer);
    }

    protected void createTasks() {
        int numBatches = this.optimizable.getNumBatches();
        this.valueTasks = new ArrayList<Callable<Double>>(numBatches);
        this.gradientTasks = new ArrayList<Callable<Boolean>>(numBatches);
        int numBatchInstances = this.trainingSet.size() / numBatches;
        int start = -1;
        int end = -1;
        for (int i = 0; i < numBatches; ++i) {
            if (i == 0) {
                start = 0;
                end = start + numBatchInstances;
            } else if (i == numBatches - 1) {
                start = end;
                end = this.trainingSet.size();
            } else {
                start = end;
                end = start + numBatchInstances;
            }
            this.valueTasks.add(new ValueHandler(i, new int[]{start, end}));
            this.gradientTasks.add(new GradientHandler(i, new int[]{start, end}));
        }
    }

    public int getNumParameters() {
        return this.optimizable.getNumParameters();
    }

    public void getParameters(double[] buffer) {
        this.optimizable.getParameters(buffer);
    }

    public double getParameter(int index) {
        return this.optimizable.getParameter(index);
    }

    public void setParameters(double[] buff) {
        this.optimizable.setParameters(buff);
    }

    public void setParameter(int index, double value) {
        this.optimizable.setParameter(index, value);
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class GradientHandler
    implements Callable<Boolean> {
        private int batchIndex;
        private int[] batchAssignments;

        public GradientHandler(int batchIndex, int[] batchAssignments) {
            this.batchIndex = batchIndex;
            this.batchAssignments = batchAssignments;
        }

        @Override
        public Boolean call() {
            ThreadedOptimizable.this.optimizable.getBatchValueGradient(ThreadedOptimizable.this.batchCachedGradient.get(this.batchIndex), this.batchIndex, this.batchAssignments);
            return true;
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class ValueHandler
    implements Callable<Double> {
        private int batchIndex;
        private int[] batchAssignments;

        public ValueHandler(int batchIndex, int[] batchAssignments) {
            this.batchIndex = batchIndex;
            this.batchAssignments = batchAssignments;
        }

        @Override
        public Double call() {
            return ThreadedOptimizable.this.optimizable.getBatchValue(this.batchIndex, this.batchAssignments);
        }
    }
}

