/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.scoringFunctions.mix;

import de.jstacs.NonParsableException;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.XMLParser;
import de.jstacs.scoringFunctions.AbstractNormalizableScoringFunction;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.NormalizedScoringFunction;
import de.jstacs.scoringFunctions.mix.motifSearch.DurationScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.util.Arrays;

public abstract class AbstractMixtureScoringFunction
extends AbstractNormalizableScoringFunction {
    // Number of recommended optimization restarts (validated to be >= 1 in the constructor).
    private int starts;
    // Parameter offset table built by initWithLength: paramRef[i] is the first index of the
    // parameters of function[i], paramRef[function.length] the first hidden parameter, and the
    // last entry the total number of parameters. null while some component reports -1 parameters.
    protected int[] paramRef;
    // Whether the hidden (component-weight) parameters are part of the parameter vector;
    // the constructor forces this to false when there is only one component.
    protected boolean optimizeHidden;
    // If true, the last hidden parameter is fixed to 0 and is not exposed as a free parameter.
    protected boolean freeParams;
    // If true, initializeFunction(...) delegates to initializeUsingPlugIn(...) instead of random init.
    private boolean plugIn;
    // The component scoring functions (clones owned by this instance).
    protected NormalizableScoringFunction[] function;
    // Hidden parameters on log scale: setHiddenParameters sets hiddenPotential[i] = exp(hiddenParameter[i]).
    protected double[] hiddenParameter;
    // Logarithm of hiddenPotential; shifted by -logHiddenNorm when isNormalized.
    protected double[] logHiddenPotential;
    // exp(hiddenParameter), divided by its sum when isNormalized.
    protected double[] hiddenPotential;
    // Per-component scores written by fillComponentScores and combined by Normalisation.getLogSum
    // in getLogScore; its length defines getNumberOfComponents().
    protected double[] componentScore;
    // Per-component contribution to norm (written by precomputeNorm and setHiddenParameters).
    protected double[] partNorm;
    // Cached normalization constant; a negative value means it must be recomputed.
    protected double norm;
    // Log of the sum of exp(hidden parameters); only updated when isNormalized.
    protected double logHiddenNorm;
    // log Gamma(sum h_i) - sum log Gamma(h_i) over the hidden hyperparameters (see computeLogGammaSum).
    protected double logGammaSum;
    // Scratch lists allocated in initWithLength (size max(function.length, hiddenParameter.length));
    // not otherwise used in this class, presumably used by subclasses.
    protected DoubleList[] dList;
    protected IntList[] iList;
    // Result of the static isNormalized(...) check applied to the components.
    protected boolean isNormalized;

    /**
     * Returns the position of the largest value in {@code w}.
     *
     * <p>Ties resolve to the smallest index because only a strictly larger value
     * replaces the current best. An empty array yields {@code 0}.
     *
     * @param w the values to scan
     * @return index of the first maximal entry
     */
    protected static final int getMaxIndex(double[] w) {
        int best = 0;
        for (int k = 1; k < w.length; k++) {
            if (w[k] > w[best]) {
                best = k;
            }
        }
        return best;
    }

    /**
     * Creates a mixture over clones of the given component functions.
     *
     * <p>The hidden parameters start at 0 for every component. The Dirichlet term
     * {@code logGammaSum} is not computed here; presumably subclasses call
     * {@link #computeLogGammaSum()} once their hyperparameters are available.
     *
     * @param length         sequence length passed to the superclass
     * @param starts         recommended number of starts, must be at least 1
     * @param dimension      number of mixture components, must be at least 1
     * @param optimizeHidden whether the hidden parameters are optimized (ignored for one component)
     * @param plugIn         whether {@link #initializeFunction} uses plug-in initialization
     * @param function       the component functions; they are cloned
     * @throws IllegalArgumentException   if {@code starts < 1} or {@code dimension < 1}
     * @throws CloneNotSupportedException if a component cannot be cloned
     */
    protected AbstractMixtureScoringFunction(int length, int starts, int dimension, boolean optimizeHidden, boolean plugIn, NormalizableScoringFunction ... function) throws CloneNotSupportedException {
        super(function[0].getAlphabetContainer(), length);
        this.function = (NormalizableScoringFunction[])ArrayHandler.clone((Cloneable[])function);
        if (starts < 1) {
            throw new IllegalArgumentException("The number of recommended starts has to be positive.");
        }
        this.starts = starts;
        // Reject every non-positive count. Checking only for 0 let negative values reach the array
        // allocations below and fail with NegativeArraySizeException instead of this message.
        if (dimension < 1) {
            throw new IllegalArgumentException("The number of components has to be positive.");
        }
        this.isNormalized = AbstractMixtureScoringFunction.isNormalized(function);
        this.hiddenParameter = new double[dimension];
        this.logHiddenPotential = new double[dimension];
        this.hiddenPotential = new double[dimension];
        this.partNorm = new double[dimension];
        // all-zero log potentials give equal weights to all components
        this.setHiddenParameters(this.hiddenParameter, 0);
        this.componentScore = new double[dimension];
        // a single component has no free mixture weight to optimize
        this.optimizeHidden = optimizeHidden && dimension > 1;
        this.plugIn = plugIn;
        this.paramRef = null;
        this.init(this.freeParams);
    }

    /**
     * Recomputes {@link #logGammaSum} as
     * {@code logGamma(sum_i h_i) - sum_i logGamma(h_i)}, where
     * {@code h_i = getHyperparameterForHiddenParameter(i)}. This is the log normalizer
     * of a Dirichlet density with parameters {@code h_i}.
     *
     * <p>The value stays {@code 0} when there is at most one component or when
     * {@code getEss()} is not strictly positive.
     */
    protected void computeLogGammaSum() {
        this.logGammaSum = 0.0;
        final int components = this.getNumberOfComponents();
        if (components <= 1 || !(this.getEss() > 0.0)) {
            return;
        }
        double hyperSum = 0.0;
        for (int c = 0; c < components; ++c) {
            final double alpha = this.getHyperparameterForHiddenParameter(c);
            hyperSum += alpha;
            this.logGammaSum -= Gamma.logOfGamma(alpha);
        }
        this.logGammaSum += Gamma.logOfGamma(hyperSum);
    }

    /**
     * Restores an instance from its XML representation.
     *
     * <p>NOTE(review): presumably the superclass constructor dispatches to
     * {@link #fromXML(StringBuffer)}; confirm in AbstractNormalizableScoringFunction.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    protected AbstractMixtureScoringFunction(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Creates a deep copy: the components are cloned via {@link #cloneFunctions} and
     * all per-component arrays are copied.
     *
     * <p>The scratch lists and the offset table are reset and rebuilt by
     * {@code init(...)}, so the copy shares no mutable state with this instance.
     *
     * @return an independent copy
     * @throws CloneNotSupportedException if a component cannot be cloned
     */
    @Override
    public AbstractMixtureScoringFunction clone() throws CloneNotSupportedException {
        final AbstractMixtureScoringFunction copy = (AbstractMixtureScoringFunction) super.clone();
        copy.cloneFunctions(this.function);

        copy.hiddenParameter = this.hiddenParameter.clone();
        copy.logHiddenPotential = this.logHiddenPotential.clone();
        copy.hiddenPotential = this.hiddenPotential.clone();
        copy.componentScore = this.componentScore.clone();
        copy.partNorm = this.partNorm.clone();

        // Null these so init(...) allocates fresh ones instead of sharing this instance's arrays.
        copy.iList = null;
        copy.paramRef = null;
        copy.init(this.freeParams);
        return copy;
    }

    /**
     * Replaces {@link #function} with clones of the given components.
     *
     * @param originalFunctions the components to clone
     * @throws CloneNotSupportedException if a component cannot be cloned
     */
    protected void cloneFunctions(NormalizableScoringFunction[] originalFunctions) throws CloneNotSupportedException {
        final Cloneable[] source = (Cloneable[]) originalFunctions;
        this.function = (NormalizableScoringFunction[]) ArrayHandler.clone(source);
    }

    /**
     * Returns the hyperparameter of the hidden (component-weight) parameter with the
     * given index.
     *
     * <p>This class uses the value as a Dirichlet parameter and pseudo-count. See
     * computeLogGammaSum, getLogPriorTerm, initializeHiddenPotentialRandomly and
     * computeHiddenParameter.
     *
     * @param index the component index
     * @return the hyperparameter for that component
     */
    public abstract double getHyperparameterForHiddenParameter(int index);

    /**
     * Returns the log prior term of the whole mixture.
     *
     * <p>The value is the sum of:
     * <ul>
     *   <li>{@code sum_i hiddenParameter[i] * h_i};</li>
     *   <li>minus {@code (sum_i h_i) * logHiddenNorm}, only if normalized;</li>
     *   <li>the log prior terms of all components;</li>
     *   <li>{@link #logGammaSum}.</li>
     * </ul>
     * Here {@code h_i = getHyperparameterForHiddenParameter(i)}.
     *
     * @return the log prior term
     */
    public double getLogPriorTerm() {
        double hyperSum = 0.0;
        double logPrior = 0.0;
        for (int k = 0; k < this.hiddenParameter.length; ++k) {
            final double alpha = this.getHyperparameterForHiddenParameter(k);
            hyperSum += alpha;
            logPrior += this.hiddenParameter[k] * alpha;
        }
        if (this.isNormalized) {
            logPrior -= hyperSum * this.logHiddenNorm;
        }
        for (NormalizableScoringFunction component : this.function) {
            logPrior += component.getLogPriorTerm();
        }
        return logPrior + this.logGammaSum;
    }

    /**
     * Adds the gradient of the log prior term to {@code grad}, starting at offset {@code start}.
     *
     * <p>Each component adds its own part at its offset in {@link #paramRef}. For every
     * exposed hidden parameter {@code i} this method then adds {@code h_i}, minus
     * {@code getEss() * hiddenPotential[i]} when normalized. NOTE(review): this
     * presumably assumes the hyperparameters sum to {@code getEss()}.
     *
     * @param grad  the gradient array to add into
     * @param start offset of this function's first parameter in {@code grad}
     * @throws Exception if a component fails to add its gradient
     */
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        final int nf = this.function.length;
        for (int k = 0; k < nf; ++k) {
            this.function[k].addGradientOfLogPriorTerm(grad, start + this.paramRef[k]);
        }
        final int first = start + this.paramRef[nf];
        final int end = start + this.paramRef[nf + 1];
        final double ess = this.getEss();
        for (int pos = first, comp = 0; pos < end; ++pos, ++comp) {
            final double correction = this.isNormalized ? ess * this.hiddenPotential[comp] : 0.0;
            grad[pos] += this.getHyperparameterForHiddenParameter(comp) - correction;
        }
    }

    /**
     * Returns the index of the component with the highest score for the subsequence
     * of {@code seq} starting at {@code start}. Ties go to the lowest index.
     *
     * @param seq   the sequence
     * @param start start position in {@code seq}
     * @return index of the best-scoring component
     */
    public int getIndexOfMaximalComponentFor(Sequence seq, int start) {
        fillComponentScores(seq, start);
        final double[] scores = componentScore;
        return getMaxIndex(scores);
    }

    /**
     * Returns all current parameters in the layout given by {@link #paramRef}: the
     * parameters of each component, followed by the exposed hidden parameters.
     *
     * @return a new array with the current parameter values
     * @throws Exception if the parameter layout is not yet known, or a component fails
     */
    public double[] getCurrentParameterValues() throws Exception {
        final int total = this.getNumberOfParameters();
        if (total == -1) {
            throw new Exception("No parameters exists, yet.");
        }
        final double[] values = new double[total];
        final int nf = this.function.length;
        for (int k = 0; k < nf; ++k) {
            final double[] componentValues = this.function[k].getCurrentParameterValues();
            System.arraycopy(componentValues, 0, values, this.paramRef[k], componentValues.length);
        }
        final int hiddenOffset = this.paramRef[nf];
        System.arraycopy(this.hiddenParameter, 0, values, hiddenOffset, this.paramRef[nf + 1] - hiddenOffset);
        return values;
    }

    /**
     * Returns the mixture log score: the per-component scores from
     * {@link #fillComponentScores} combined with {@code Normalisation.getLogSum}.
     *
     * @param seq   the sequence
     * @param start start position in {@code seq}
     * @return the log score
     */
    public double getLogScore(Sequence seq, int start) {
        this.fillComponentScores(seq, start);
        final double[] scores = this.componentScore;
        return Normalisation.getLogSum(scores);
    }

    /**
     * Returns the normalization constant. It is recomputed via
     * {@link #precomputeNorm()} only when the cached value is marked stale (negative).
     *
     * @return the normalization constant
     */
    public final double getNormalizationConstant() {
        final boolean stale = this.norm < 0.0;
        if (stale) {
            precomputeNorm();
        }
        return this.norm;
    }

    /**
     * Returns the number of mixture components, i.e. the length of {@link #componentScore}.
     *
     * @return number of components
     */
    public final int getNumberOfComponents() {
        final double[] perComponent = componentScore;
        return perComponent.length;
    }

    /**
     * Returns the total number of parameters (the last offset in {@link #paramRef}).
     *
     * @return the parameter count, or {@code -1} if the layout is not yet known
     */
    public final int getNumberOfParameters() {
        return paramRef == null ? -1 : paramRef[paramRef.length - 1];
    }

    /**
     * Returns the recommended number of starts.
     *
     * @return the number of starts given at construction (or restored from XML)
     */
    public final int getNumberOfRecommendedStarts() {
        final int recommended = starts;
        return recommended;
    }

    /**
     * Returns the per-component probabilities for {@code seq}, scored from position 0.
     * The component log scores are converted with {@code Normalisation.logSumNormalisation}.
     *
     * @param seq the sequence
     * @return a new array with one probability per component
     */
    public double[] getProbsForComponent(Sequence seq) {
        this.fillComponentScores(seq, 0);
        final int k = this.componentScore.length;
        final double[] probs = new double[k];
        Normalisation.logSumNormalisation(this.componentScore, 0, k, probs, 0);
        return probs;
    }

    /**
     * Returns clones of all component functions.
     *
     * @return a new array of cloned components
     * @throws CloneNotSupportedException if a component cannot be cloned
     */
    public NormalizableScoringFunction[] getScoringFunctions() throws CloneNotSupportedException {
        final Cloneable[] components = (Cloneable[]) this.function;
        return (NormalizableScoringFunction[]) ArrayHandler.clone(components);
    }

    /**
     * Returns the event-space size of the random variable behind the parameter with the
     * given global index.
     *
     * <p>For hidden parameters this is the number of hidden parameters. Otherwise the
     * query is delegated to the owning component with the local index.
     *
     * @param index global parameter index
     * @return size of the event space
     */
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        final int[] location = this.getIndices(index);
        final int owner = location[0];
        return owner == this.function.length
                ? this.hiddenParameter.length
                : this.function[owner].getSizeOfEventSpaceForRandomVariablesOfParameter(location[1]);
    }

    /**
     * Initializes the function.
     *
     * <p>With plug-in initialization enabled this delegates to
     * {@link #initializeUsingPlugIn} and then rebuilds the parameter layout. Otherwise
     * it falls back to {@link #initializeFunctionRandomly}.
     *
     * @param index      index passed through to the plug-in initialization
     * @param freeParams whether the last hidden parameter is fixed
     * @param data       data passed through to the plug-in initialization
     * @param weights    weights passed through to the plug-in initialization
     * @throws Exception if initialization fails
     */
    public void initializeFunction(int index, boolean freeParams, Sample[] data, double[][] weights) throws Exception {
        if (!this.plugIn) {
            this.initializeFunctionRandomly(freeParams);
            return;
        }
        this.initializeUsingPlugIn(index, freeParams, data, weights);
        this.init(freeParams);
    }

    /**
     * Plug-in initialization hook. {@link #initializeFunction} calls it when plug-in mode
     * is enabled and then calls {@code init(freeParams)}.
     *
     * @param index      index forwarded from {@code initializeFunction}
     * @param freeParams whether the last hidden parameter is fixed
     * @param data       data forwarded from {@code initializeFunction}
     * @param weights    weights forwarded from {@code initializeFunction}
     * @throws Exception if initialization fails
     */
    protected abstract void initializeUsingPlugIn(int index, boolean freeParams, Sample[] data, double[][] weights) throws Exception;

    /**
     * Randomly initializes all components and, if hidden parameters are optimized,
     * draws random hidden potentials. Finally rebuilds the parameter layout.
     *
     * @param freeParams whether the last hidden parameter is fixed to 0
     * @throws Exception if a component fails to initialize
     */
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        for (int i = 0; i < this.function.length; ++i) {
            this.function[i].initializeFunctionRandomly(freeParams);
        }
        // Adopt the requested parameterization before drawing the hidden parameters.
        // computeHiddenParameter and setHiddenParameters branch on this.freeParams, and
        // init(...) below would switch the flag only afterwards. hiddenParameter would then
        // be in the old representation, e.g. a non-zero last entry although it is now
        // treated as fixed at 0.
        this.freeParams = freeParams;
        if (this.optimizeHidden) {
            this.initializeHiddenPotentialRandomly();
        }
        this.init(freeParams);
    }

    /**
     * Draws {@link #hiddenPotential} from a Dirichlet distribution and derives the
     * hidden parameters from it via {@link #computeHiddenParameter}.
     *
     * <p>The Dirichlet parameters are the component hyperparameters. If
     * {@code getEss() == 0} they are all {@code 1}.
     *
     * <p>NOTE(review): computeHiddenParameter adds
     * {@code getHyperparameterForHiddenParameter(i)} to each sampled entry before taking
     * logs. For {@code ess > 0} the draw is therefore additionally smoothed toward the
     * prior mean. Confirm this is intended.
     */
    protected void initializeHiddenPotentialRandomly() {
        final int k = this.getNumberOfComponents();
        final double[] alpha = new double[k];
        if (this.getEss() != 0.0) {
            for (int c = 0; c < k; ++c) {
                alpha[c] = this.getHyperparameterForHiddenParameter(c);
            }
        } else {
            Arrays.fill(alpha, 1.0);
        }
        final DirichletMRGParams dirichlet = new DirichletMRGParams(alpha);
        DirichletMRG.DEFAULT_INSTANCE.generate(this.hiddenPotential, 0, this.hiddenPotential.length, dirichlet);
        this.computeHiddenParameter(this.hiddenPotential);
    }

    /**
     * Reports whether every component is initialized and the parameter layout is known.
     * Components are checked in order, stopping at the first uninitialized one.
     *
     * @return {@code true} if fully initialized
     */
    public boolean isInitialized() {
        boolean componentsReady = true;
        for (NormalizableScoringFunction component : this.function) {
            if (!component.isInitialized()) {
                componentsReady = false;
                break;
            }
        }
        return this.paramRef != null && componentsReady;
    }

    /**
     * Sets all parameters from {@code params}, using the layout in {@link #paramRef}
     * shifted by {@code start}.
     *
     * <p>If hidden parameters are not optimized, only the cached normalization constant
     * is reset: to {@code 1} when normalized, otherwise marked stale.
     *
     * @param params the parameter values
     * @param start  offset of this function's first parameter in {@code params}
     */
    public void setParameters(double[] params, int start) {
        final int nf = this.function.length;
        for (int k = 0; k < nf; ++k) {
            this.setParametersForFunction(k, params, start + this.paramRef[k]);
        }
        if (!this.optimizeHidden) {
            this.norm = this.isNormalized ? 1.0 : -1.0;
            return;
        }
        this.setHiddenParameters(params, start + this.paramRef[nf]);
    }

    /**
     * Initializes hidden parameters so that components are weighted uniformly.
     *
     * <p>Nested mixtures, {@code NormalizedScoringFunction}s and
     * {@code DurationScoringFunction}s among the components are reset recursively first.
     * If hidden parameters are optimized, component {@code i} then gets log potential
     * {@code log(d / Z_i)}, with {@code Z_i = getNormalizationConstantForComponent(i)}.
     * {@code d} is {@code Z} of the last component when {@code freeParams} is set,
     * otherwise {@code 1}. Each product {@code potential_i * Z_i} is therefore equal, and
     * with {@code freeParams} the last component's parameter is exactly {@code 0}.
     */
    public void initializeHiddenUniformly() {
        final int c = this.getNumberOfComponents();
        for (int i = 0; i < this.function.length; ++i) {
            final NormalizableScoringFunction f = this.function[i];
            if (f instanceof AbstractMixtureScoringFunction) {
                ((AbstractMixtureScoringFunction) f).initializeHiddenUniformly();
            } else if (f instanceof NormalizedScoringFunction) {
                ((NormalizedScoringFunction) f).initializeHiddenUniformly();
            } else if (f instanceof DurationScoringFunction) {
                ((DurationScoringFunction) f).initializeUniformly();
            }
        }
        if (this.optimizeHidden) {
            final double[] pars = new double[c];
            // With free parameters the reference is the last component (index c - 1), whose
            // hidden parameter setHiddenParameters fixes at 0. The previous index c was one
            // past the last component.
            final double d = this.freeParams ? this.getNormalizationConstantForComponent(c - 1) : 1.0;
            for (int i = 0; i < c; ++i) {
                pars[i] = Math.log(d / this.getNormalizationConstantForComponent(i));
            }
            this.setHiddenParameters(pars, 0);
        }
        this.init(this.freeParams);
    }

    /**
     * Copies the hidden parameters from {@code params} (starting at {@code start}) and
     * derives the potentials from them.
     *
     * <p>With {@code freeParams}, only the first {@code n - 1} values are read and the
     * last hidden parameter is fixed to {@code 0}. When normalized, the potentials are
     * divided by their sum and {@code logHiddenNorm}, {@code partNorm} and {@code norm = 1}
     * are updated. Otherwise {@code norm} is marked stale ({@code -1}).
     *
     * @param params source array
     * @param start  index of the first hidden parameter in {@code params}
     */
    protected void setHiddenParameters(double[] params, int start) {
        final int k = this.hiddenParameter.length;
        final int exposed = k - (this.freeParams ? 1 : 0);
        double total = 0.0;
        for (int c = 0; c < exposed; ++c) {
            final double v = params[start + c];
            this.logHiddenPotential[c] = v;
            this.hiddenParameter[c] = v;
            this.hiddenPotential[c] = Math.exp(v);
            total += this.hiddenPotential[c];
        }
        if (this.freeParams) {
            // reference component: log potential 0, potential 1
            this.logHiddenPotential[exposed] = 0.0;
            this.hiddenParameter[exposed] = 0.0;
            this.hiddenPotential[exposed] = 1.0;
            total += 1.0;
        }
        if (!this.isNormalized) {
            this.norm = -1.0;
            return;
        }
        this.logHiddenNorm = Math.log(total);
        for (int c = 0; c < k; ++c) {
            this.logHiddenPotential[c] -= this.logHiddenNorm;
            this.hiddenPotential[c] /= total;
            this.partNorm[c] = this.hiddenPotential[c];
        }
        this.norm = 1.0;
    }

    /**
     * Forwards a parameter update to the component at {@code index}.
     *
     * @param index  component index
     * @param params the parameter values
     * @param start  offset of the component's first parameter in {@code params}
     */
    protected void setParametersForFunction(int index, double[] params, int start) {
        final NormalizableScoringFunction target = this.function[index];
        target.setParameters(params, start);
    }

    /**
     * Serializes this function to XML.
     *
     * <p>Writes the fields read back by {@code fromXML}, appends
     * {@link #getFurtherInformation()}, and wraps everything in {@link #getXMLTag()}.
     *
     * @return the XML representation
     */
    public final StringBuffer toXML() {
        final StringBuffer xml = new StringBuffer(10000);
        XMLParser.appendIntWithTags(xml, this.length, "length");
        XMLParser.appendIntWithTags(xml, this.starts, "starts");
        XMLParser.appendBooleanWithTags(xml, this.freeParams, "freeParams");
        XMLParser.appendStorableArrayWithTags(xml, this.function, "function");
        XMLParser.appendBooleanWithTags(xml, this.optimizeHidden, "optimizeHidden");
        XMLParser.appendBooleanWithTags(xml, this.plugIn, "plugIn");
        XMLParser.appendDoubleArrayWithTags(xml, this.hiddenParameter, "hiddenParameter");
        final StringBuffer extra = this.getFurtherInformation();
        xml.append(extra);
        XMLParser.addTags(xml, this.getXMLTag());
        return xml;
    }

    /**
     * Restores the state written by {@link #toXML()}.
     *
     * <p>Statement order matters. {@code setHiddenParameters} reads {@code freeParams}
     * and {@code isNormalized} and writes into the freshly allocated potential and
     * partNorm arrays, so those are set up first. {@code computeLogGammaSum} runs last,
     * after {@code extractFurtherInformation}; presumably this is because the
     * hyperparameters may depend on subclass state restored there.
     *
     * @param b the XML representation
     * @throws NonParsableException propagated from XMLParser or extractFurtherInformation
     */
    protected final void fromXML(StringBuffer b) throws NonParsableException {
        StringBuffer xml = XMLParser.extractForTag(b, this.getXMLTag());
        this.length = XMLParser.extractIntForTag(xml, "length");
        this.starts = XMLParser.extractIntForTag(xml, "starts");
        this.freeParams = XMLParser.extractBooleanForTag(xml, "freeParams");
        this.function = (NormalizableScoringFunction[])ArrayHandler.cast(XMLParser.extractStorableArrayForTag(xml, "function"));
        // the alphabet is taken from the first component, as in the main constructor
        this.alphabets = this.function[0].getAlphabetContainer();
        this.isNormalized = AbstractMixtureScoringFunction.isNormalized(this.function);
        this.optimizeHidden = XMLParser.extractBooleanForTag(xml, "optimizeHidden");
        this.plugIn = XMLParser.extractBooleanForTag(xml, "plugIn");
        this.hiddenParameter = XMLParser.extractDoubleArrayForTag(xml, "hiddenParameter");
        // all per-component arrays are sized by the number of stored hidden parameters
        this.hiddenPotential = new double[this.hiddenParameter.length];
        this.logHiddenPotential = new double[this.hiddenParameter.length];
        this.partNorm = new double[this.logHiddenPotential.length];
        this.setHiddenParameters(this.hiddenParameter, 0);
        this.componentScore = new double[this.logHiddenPotential.length];
        this.extractFurtherInformation(xml);
        this.init(this.freeParams);
        this.computeLogGammaSum();
    }

    /**
     * Hook for subclasses to append extra XML to {@link #toXML()}.
     * The default contributes nothing (an empty buffer).
     *
     * @return additional XML, empty by default
     */
    protected StringBuffer getFurtherInformation() {
        return new StringBuffer(1);
    }

    /**
     * Hook for subclasses to read back what {@link #getFurtherInformation()} wrote.
     * Called from {@code fromXML} before the parameter layout is rebuilt.
     * The default does nothing.
     *
     * @param xml the XML content inside this function's tag
     * @throws NonParsableException if the subclass data cannot be parsed
     */
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
    }

    /**
     * Maps a global parameter index to {@code {owner, localIndex}}.
     *
     * <p>{@code owner} is the last slot of {@link #paramRef} whose offset is
     * {@code <= index}. The value {@code function.length} denotes the hidden parameters.
     * {@code localIndex} is {@code index} minus that offset.
     *
     * @param index global parameter index
     * @return a two-element array {@code {owner, localIndex}}
     */
    protected int[] getIndices(int index) {
        int owner = 0;
        while (index >= this.paramRef[owner]) {
            ++owner;
        }
        --owner;
        return new int[]{owner, index - this.paramRef[owner]};
    }

    /**
     * Returns the XML tag wrapping this function's serialized form: the simple class name.
     *
     * @return the tag name
     */
    protected String getXMLTag() {
        final Class<?> type = getClass();
        return type.getSimpleName();
    }

    /**
     * Rebuilds the parameter layout with one offset per component plus two entries:
     * the start and the end of the hidden parameters.
     *
     * @param freeParams whether the last hidden parameter is fixed
     */
    protected void init(boolean freeParams) {
        final int offsets = this.function.length + 2;
        initWithLength(freeParams, offsets);
    }

    /**
     * (Re)builds the offset table {@link #paramRef} and records {@code freeParams}.
     *
     * <p>After a successful call:
     * <ul>
     *   <li>{@code paramRef[i]} is the first index of the parameters of {@code function[i]};</li>
     *   <li>{@code paramRef[function.length]} is the first hidden parameter;</li>
     *   <li>the hidden parameters span {@code hiddenParameter.length} entries (one fewer with
     *       {@code freeParams}), or none if hidden parameters are not optimized.</li>
     * </ul>
     *
     * <p>If a component reports {@code -1} parameters, {@code paramRef} becomes {@code null}
     * and {@link #freeParams} is left unchanged.
     *
     * @param freeParams whether the last hidden parameter is fixed
     * @param len        length of the offset table ({@code init} uses {@code function.length + 2})
     */
    protected final void initWithLength(boolean freeParams, int len) {
        if (this.paramRef == null || this.paramRef.length != len) {
            this.paramRef = new int[len];
        }
        // scratch lists are allocated once and kept across re-initializations
        if (this.iList == null) {
            final int size = Math.max(this.function.length, this.hiddenParameter.length);
            this.iList = new IntList[size];
            this.dList = new DoubleList[size];
            for (int k = 0; k < size; ++k) {
                this.iList[k] = new IntList();
                this.dList[k] = new DoubleList();
            }
        }
        int i;
        for (i = 0; i < this.function.length; ++i) {
            // query once: the same value is the validity check and the offset increment
            final int h = this.function[i].getNumberOfParameters();
            if (h == -1) {
                this.paramRef = null;
                return;
            }
            this.paramRef[i + 1] = this.paramRef[i] + h;
        }
        this.paramRef[i + 1] = this.optimizeHidden
                ? this.paramRef[i] + this.hiddenParameter.length - (freeParams ? 1 : 0)
                : this.paramRef[i];
        this.freeParams = freeParams;
    }

    /**
     * Derives the hidden parameters from per-component statistics and applies them via
     * {@link #setHiddenParameters}.
     *
     * <p>The hyperparameters are first added to {@code statistic} in place, so the
     * caller's array is modified. With {@code freeParams} the parameters are logs relative
     * to the last component, which gets {@code 0}. Otherwise they are log relative
     * frequencies.
     *
     * @param statistic per-component values (modified in place)
     */
    protected void computeHiddenParameter(double[] statistic) {
        final int k = this.hiddenParameter.length;
        for (int c = 0; c < k; ++c) {
            statistic[c] += this.getHyperparameterForHiddenParameter(c);
        }
        if (this.freeParams) {
            final int ref = k - 1;
            final double logRef = Math.log(statistic[ref]);
            for (int c = 0; c < ref; ++c) {
                this.hiddenParameter[c] = Math.log(statistic[c]) - logRef;
            }
            this.hiddenParameter[ref] = 0.0;
        } else {
            double total = 0.0;
            for (int c = 0; c < k; ++c) {
                total += statistic[c];
            }
            final double logTotal = Math.log(total);
            for (int c = 0; c < k; ++c) {
                this.hiddenParameter[c] = Math.log(statistic[c]) - logTotal;
            }
        }
        this.setHiddenParameters(this.hiddenParameter, 0);
    }

    /**
     * Recomputes {@link #norm} as {@code sum_i hiddenPotential[i] * Z_i}, where
     * {@code Z_i = getNormalizationConstantForComponent(i)}. Each summand is stored in
     * {@link #partNorm}.
     */
    protected void precomputeNorm() {
        this.norm = 0.0;
        final int k = this.logHiddenPotential.length;
        for (int c = 0; c < k; ++c) {
            final double contribution = this.hiddenPotential[c] * this.getNormalizationConstantForComponent(c);
            this.partNorm[c] = contribution;
            this.norm += contribution;
        }
    }

    /**
     * Returns the normalization constant of component {@code component}. precomputeNorm
     * multiplies it by that component's hidden potential.
     *
     * @param component index in {@code [0, getNumberOfComponents())}
     * @return the component's normalization constant
     */
    protected abstract double getNormalizationConstantForComponent(int component);

    /**
     * Fills {@link #componentScore} with one score per component for the subsequence of
     * {@code seq} starting at {@code start}. getLogScore combines the entries with
     * {@code Normalisation.getLogSum}, so they are expected on log scale.
     *
     * @param seq   the sequence
     * @param start start position in {@code seq}
     */
    protected abstract void fillComponentScores(Sequence seq, int start);

    /**
     * Returns the result of the static normalization check on the components, computed
     * at construction or when restored from XML.
     *
     * @return {@code true} if normalized
     */
    public boolean isNormalized() {
        final boolean normalized = isNormalized;
        return normalized;
    }

    /**
     * Returns a clone of the component at {@code index}.
     *
     * @param index component index
     * @return a cloned component
     * @throws CloneNotSupportedException if the component cannot be cloned
     */
    public NormalizableScoringFunction getFunction(int index) throws CloneNotSupportedException {
        final NormalizableScoringFunction original = this.function[index];
        return (NormalizableScoringFunction) original.clone();
    }

    /**
     * Returns clones of all component functions.
     *
     * @return a new array of cloned components
     * @throws CloneNotSupportedException if a component cannot be cloned
     */
    public NormalizableScoringFunction[] getFunctions() throws CloneNotSupportedException {
        final Cloneable[] components = (Cloneable[]) this.function;
        return (NormalizableScoringFunction[]) ArrayHandler.clone(components);
    }
}

