/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.scoringFunctions.mix.motifSearch;

import de.jstacs.NonParsableException;
import de.jstacs.data.Sample;
import de.jstacs.io.XMLParser;
import de.jstacs.scoringFunctions.mix.motifSearch.CDFOfNormal;
import de.jstacs.scoringFunctions.mix.motifSearch.DurationScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.RandomNumberGenerator;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;

/**
 * Duration scoring function over the integer lengths {@code min, ..., max}.
 *
 * <p>The unnormalized log-score of a length {@code l} is
 * {@code -0.5 * prec * (l - mu)^2 + log Phi(par2 * (l - mu) / sigma)}, where {@code Phi} is the
 * standard normal CDF. The scores are normalized by log-sum-exp over the finite range, so this is a
 * discretized, truncated skew-normal-like shape.
 *
 * <p>The three internal parameters are:
 * <ul>
 * <li>{@code par0}: the location, mapped to
 *     {@code mu = min + delta * (0.01 * par0 + exp(par0) / (1 + exp(par0)))};</li>
 * <li>{@code par1}: the logarithm of the precision {@code prec = 1 / sigma^2};</li>
 * <li>{@code par2}: the skew, i.e. the factor applied inside the normal CDF.</li>
 * </ul>
 * Each parameter can be trained or kept fixed independently. If the equivalent sample size is
 * positive, the prior consists of three parts: a Gaussian prior on {@code mu}, a Gamma prior on
 * {@code prec} and a Gaussian prior on {@code par2}. The log-Jacobians of the parameter
 * transformations are included.
 */
public class SkewNormalLikeScoringFunction
extends DurationScoringFunction {
    // which of the three parameters are free (optimized) parameters
    private boolean trainMean;
    private boolean trainPrecision;
    private boolean trainSkew;
    // raw parameters (see class documentation)
    private double par0;
    private double par1;
    private double par2;
    // hyperparameters of the prior; only meaningful if ess > 0
    private double hyperMeanMean;
    private double hyperMeanStdev;
    private double hyperPrec1;
    private double hyperPrec2;
    private double hyperSkewMean;
    private double hyperSkewStdev;
    // parameter-independent part of the log prior
    private double priorC;
    // derivative of mu with respect to par0
    private double partDerMu;
    // derived quantities of the current parameters
    private double mu;
    private double sigma;
    private double prec;
    // log of the normalization constant over [min, max]
    private double logNorm;
    // derivatives of logNorm with respect to par0, par1 and par2
    private double partDerLogNormPar0;
    private double partDerLogNormPar1;
    private double partDerLogNormPar2;
    // unnormalized log-scores, indexed by (length - min)
    private double[] logScore;
    // ratio phi(par2 * z) / Phi(par2 * z) per length, indexed by (length - min)
    private double[] densDivCDF;
    private int starts;
    private static RandomNumberGenerator randNumGen = new RandomNumberGenerator();
    // width of the interval from which par0 is drawn uniformly in initializeFunctionRandomly
    private static final double V = 6.72;
    private static final double ONE_DIV_BY_SQRT_OF_2_TIMES_PI = 1.0 / Math.sqrt(Math.PI * 2);
    // relative mean positions are clamped to [LOGIT_EPS, 1 - LOGIT_EPS] before applying the logit
    private static final double LOGIT_EPS = 1e-6;

    private SkewNormalLikeScoringFunction(int min, int max, double ess, boolean trainMean, double param0, boolean trainPrecision, double param1, boolean trainSkew, double param2, int starts) {
        super(min, max, ess);
        this.setParameters(param0, param1, param2);
        this.trainMean = trainMean;
        this.trainPrecision = trainPrecision;
        this.trainSkew = trainSkew;
        if (starts < 1) {
            throw new IllegalArgumentException("The number of starts has to be positive.");
        }
        this.starts = starts;
    }

    /**
     * Creates an instance with fixed (non-trainable) parameters and no prior.
     *
     * @param min smallest length
     * @param max largest length
     * @param param0 location parameter {@code par0}
     * @param param1 log precision {@code par1}
     * @param param2 skew {@code par2}
     * @param starts recommended number of starts (must be positive)
     */
    public SkewNormalLikeScoringFunction(int min, int max, double param0, double param1, double param2, int starts) {
        this(min, max, 0.0, false, param0, false, param1, false, param2, starts);
    }

    /**
     * Creates an instance whose parameters may be trained under a prior.
     *
     * <p>The equivalent sample size is set to {@code 2 * hyperPrec1}. The hyperparameters are
     * validated and stored only if that value is positive.
     *
     * @throws IllegalArgumentException if a prior standard deviation or a Gamma hyperparameter is
     *     not positive (checked only for a positive equivalent sample size)
     */
    public SkewNormalLikeScoringFunction(int min, int max, boolean trainMean, double hyperMeanMean, double hyperMeanSigma, boolean trainPrecision, double hyperPrec1, double hyperPrec2, boolean trainSkew, double hyperSkewMean, double hyperSkewStdev, int starts) {
        this(min, max, 2.0 * hyperPrec1, trainMean, 0.0, trainPrecision, -2.0 * Math.log((double)(max - min) / 4.0), trainSkew, 0.0, starts);
        if (this.ess > 0.0) {
            if (hyperMeanSigma <= 0.0) {
                throw new IllegalArgumentException("The prior of the mean parameter is wrongly specified. (check the second parameter: " + hyperMeanSigma + ")");
            }
            this.hyperMeanMean = hyperMeanMean;
            this.hyperMeanStdev = hyperMeanSigma;
            if (hyperPrec1 <= 0.0 || hyperPrec2 <= 0.0) {
                throw new IllegalArgumentException("The prior of the precision parameter is wrongly specified. (" + hyperPrec1 + ", " + hyperPrec2 + ")");
            }
            this.hyperPrec1 = hyperPrec1;
            this.hyperPrec2 = hyperPrec2;
            if (hyperSkewStdev <= 0.0) {
                throw new IllegalArgumentException("The prior of the skew parameter is wrongly specified. (check the second parameter: " + hyperSkewStdev + ")");
            }
            this.hyperSkewMean = hyperSkewMean;
            this.hyperSkewStdev = hyperSkewStdev;
            this.precomputePriorConstants();
        }
    }

    /**
     * Restores an instance from its XML representation.
     *
     * @param source XML produced by {@link #toXML()}
     * @throws NonParsableException if the XML cannot be parsed
     */
    public SkewNormalLikeScoringFunction(StringBuffer source) throws NonParsableException {
        super(source);
    }

    /** Returns a copy whose cached score arrays are not shared with this instance. */
    public SkewNormalLikeScoringFunction clone() throws CloneNotSupportedException {
        SkewNormalLikeScoringFunction clone = (SkewNormalLikeScoringFunction)super.clone();
        if (this.logScore != null) {
            clone.logScore = (double[])this.logScore.clone();
            clone.densDivCDF = (double[])this.densDivCDF.clone();
        }
        return clone;
    }

    /**
     * Initializes the free parameters from the (optionally weighted) data {@code data[index]}.
     *
     * <p>The mean is estimated as a posterior mean under the Gaussian prior on the mean when
     * {@code ess > 0}, otherwise as the weighted average. The log precision is estimated from the
     * weighted squared deviations and the Gamma hyperparameters. The skew is reset to 0. If the
     * data use another alphabet, a warning is printed and the parameters are drawn randomly instead.
     */
    public void initializeFunction(int index, boolean freeParams, Sample[] data, double[][] weights) throws Exception {
        Sample sample = data[index];
        if (!sample.getAlphabetContainer().checkConsistency(this.alphabets)) {
            System.out.println("Warning: Try to initialize " + this.getClass().getName() + " with data over another AlphabetContainer.");
            this.initializeFunctionRandomly(freeParams);
            return;
        }
        double[] w = weights == null ? null : weights[index];
        int n = sample.getNumberOfElements();
        double all = 0.0;
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            double wi = w == null ? 1.0 : w[i];
            sum += wi * (double)sample.getElementAt(i).discreteVal(0);
            all += wi;
        }
        // without a prior (ess <= 0) the hyperparameters are unset (0), so no prior precision is used
        double priorPrecOfMean = this.ess > 0.0 ? 1.0 / (this.hyperMeanStdev * this.hyperMeanStdev) : 0.0;
        double denominator = all * this.prec + priorPrecOfMean;
        double mean = denominator > 0.0
                ? (this.prec * sum + priorPrecOfMean * this.hyperMeanMean) / denominator
                : sum / all;
        double squares = 0.0;
        for (int i = 0; i < n; ++i) {
            double wi = w == null ? 1.0 : w[i];
            double c = (double)sample.getElementAt(i).discreteVal(0) - mean;
            squares += wi * c * c;
        }
        double logPrec = Math.log((this.hyperPrec1 + 0.5 * all) / (this.hyperPrec2 + 0.5 * squares));
        // NOTE(review): unlike adjust(), min is not subtracted before dividing by delta; this is only
        // correct if discreteVal(0) yields the offset (length - min) - confirm the alphabet encoding
        this.setFreeParameters(relativePositionToPar0(mean / (double)this.delta), logPrec, 0.0);
    }

    /**
     * Draws the free parameters randomly.
     *
     * <p>{@code par0} is uniform on {@code [-V/2, V/2)}. The precision is drawn from a Gamma
     * distribution: the prior's Gamma when {@code ess > 0}, otherwise shape 1 and rate 10000. The draw
     * is rejected while it exceeds {@code 16 / delta^2}. The skew is set to 0.
     */
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        double[] init = new double[this.getNumberOfParameters()];
        int i = 0;
        if (this.trainMean) {
            init[i++] = V * r.nextDouble() - 0.5 * V;
        }
        if (this.trainPrecision) {
            // computed in double precision to avoid int overflow for large delta
            double maxPrecision = 16.0 / ((double)this.delta * (double)this.delta);
            double alpha = this.ess > 0.0 ? this.hyperPrec1 : 1.0;
            double beta = this.ess > 0.0 ? this.hyperPrec2 : 10000.0;
            double drawn;
            do {
                drawn = randNumGen.nextGamma(alpha, 1.0 / beta);
            } while (drawn > maxPrecision);
            init[i++] = Math.log(drawn);
        }
        this.setParameters(init, 0);
    }

    /** Restores the state written by {@link #toXML()}. */
    protected void fromXML(StringBuffer rep) throws NonParsableException {
        StringBuffer xml = XMLParser.extractForTag(rep, this.getInstanceName());
        super.fromXML(xml);
        this.trainMean = XMLParser.extractBooleanForTag(xml, "trainMean");
        this.trainPrecision = XMLParser.extractBooleanForTag(xml, "trainPrecision");
        this.trainSkew = XMLParser.extractBooleanForTag(xml, "trainSkew");
        this.setParameters(XMLParser.extractDoubleForTag(xml, "par0"), XMLParser.extractDoubleForTag(xml, "par1"), XMLParser.extractDoubleForTag(xml, "par2"));
        this.hyperMeanMean = XMLParser.extractDoubleForTag(xml, "hyperMeanMean");
        try {
            this.hyperMeanStdev = XMLParser.extractDoubleForTag(xml, "hyperMeanStdev");
        }
        catch (NonParsableException missingTag) {
            // deliberate fallback for XML without this tag (e.g. older files): use a default of 250
            this.hyperMeanStdev = 250.0;
        }
        this.hyperPrec1 = XMLParser.extractDoubleForTag(xml, "hyperPrec1");
        this.hyperPrec2 = XMLParser.extractDoubleForTag(xml, "hyperPrec2");
        this.hyperSkewMean = XMLParser.extractDoubleForTag(xml, "hyperSkewMean");
        this.hyperSkewStdev = XMLParser.extractDoubleForTag(xml, "hyperSkewStdev");
        this.precomputePriorConstants();
        this.starts = XMLParser.extractIntForTag(xml, "starts");
    }

    public String getInstanceName() {
        return this.getClass().getSimpleName();
    }

    /** Returns the free parameters in the order par0, par1, par2 (only those that are trained). */
    public double[] getCurrentParameterValues() throws Exception {
        double[] init = new double[this.getNumberOfParameters()];
        int i = 0;
        if (this.trainMean) {
            init[i++] = this.par0;
        }
        if (this.trainPrecision) {
            init[i++] = this.par1;
        }
        if (this.trainSkew) {
            init[i] = this.par2;
        }
        return init;
    }

    /** Returns the normalized log-score of the length {@code values[0]}, which must lie in [min, max]. */
    public double getLogScore(int ... values) {
        return this.logScore[values[0] - this.min] - this.logNorm;
    }

    /**
     * Returns the normalized log-score of {@code values[0]} and appends the partial derivatives with
     * respect to the free parameters to {@code indices}/{@code partialDer}.
     */
    public double getLogScoreAndPartialDerivation(IntList indices, DoubleList partialDer, int ... values) {
        int idx = values[0] - this.min;
        double z = ((double)values[0] - this.mu) / this.sigma;
        double ratio = this.densDivCDF[idx];
        // derivative of the unnormalized log-score with respect to z, negated
        double h = z - this.par2 * ratio;
        int i = 0;
        if (this.trainMean) {
            indices.add(i++);
            partialDer.add(-this.partDerLogNormPar0 + this.partDerMu / this.sigma * h);
        }
        if (this.trainPrecision) {
            indices.add(i++);
            partialDer.add(-this.partDerLogNormPar1 + 0.5 * -z * h);
        }
        if (this.trainSkew) {
            indices.add(i++);
            partialDer.add(-this.partDerLogNormPar2 + ratio * z);
        }
        return this.logScore[idx] - this.logNorm;
    }

    /** Returns the number of trained parameters (0 to 3). */
    public int getNumberOfParameters() {
        return (this.trainMean ? 1 : 0) + (this.trainPrecision ? 1 : 0) + (this.trainSkew ? 1 : 0);
    }

    /**
     * Sets the free parameters from {@code params}, starting at {@code start}, in the layout of
     * {@link #getCurrentParameterValues()}. Fixed parameters keep their current values.
     */
    public void setParameters(double[] params, int start) {
        int k = start;
        double p0 = this.trainMean ? params[k++] : this.par0;
        double p1 = this.trainPrecision ? params[k++] : this.par1;
        double p2 = this.trainSkew ? params[k] : this.par2;
        this.setParameters(p0, p1, p2);
    }

    /**
     * Sets all three raw parameters and recomputes the cached quantities: the log-scores, the
     * normalization constant and its gradient.
     */
    public void setParameters(double par0, double par1, double par2) {
        this.par0 = par0;
        double expPar0 = Math.exp(par0);
        this.mu = (double)this.min + (double)this.delta * (0.01 * par0 + expPar0 / (1.0 + expPar0));
        this.partDerMu = (double)this.delta * (0.01 + expPar0 / ((1.0 + expPar0) * (1.0 + expPar0)));
        this.par1 = par1;
        this.prec = Math.exp(par1);
        this.sigma = 1.0 / Math.sqrt(this.prec);
        this.par2 = par2;
        if (this.logScore == null || this.logScore.length != this.delta + 1) {
            this.logScore = new double[this.delta + 1];
            this.densDivCDF = new double[this.delta + 1];
        }
        for (int j = 0; j < this.logScore.length; ++j) {
            double diff = (double)(this.min + j) - this.mu;
            double z = diff / this.sigma;
            double zSq = this.prec * diff * diff;
            double logCDF = par2 == 0.0 ? Math.log(0.5) : CDFOfNormal.getLogCDF(par2 * z);
            this.logScore[j] = -0.5 * zSq + logCDF;
            // phi(par2 * z) / Phi(par2 * z), computed in log space
            this.densDivCDF[j] = ONE_DIV_BY_SQRT_OF_2_TIMES_PI * Math.exp(-0.5 * par2 * par2 * zSq - logCDF);
        }
        this.logNorm = Normalisation.getLogSum(this.logScore);
        // The gradient of logNorm is the expectation of the per-length gradient under the normalized
        // distribution. The weights exp(logScore - logNorm) avoid the 0/0 that arises when all
        // unnormalized scores underflow in linear space.
        double sum0 = 0.0;
        double sum1 = 0.0;
        double sum2 = 0.0;
        for (int j = 0; j < this.logScore.length; ++j) {
            double weight = Math.exp(this.logScore[j] - this.logNorm);
            if (weight == 0.0) {
                continue; // negligible term; skipping avoids 0 * infinity
            }
            double z = ((double)(this.min + j) - this.mu) / this.sigma;
            double h = z - par2 * this.densDivCDF[j];
            sum0 += weight * h;
            sum1 += weight * z * h;
            sum2 += weight * z * this.densDivCDF[j];
        }
        this.partDerLogNormPar0 = sum0 * this.partDerMu / this.sigma;
        this.partDerLogNormPar1 = -0.5 * sum1;
        this.partDerLogNormPar2 = sum2;
    }

    /** Serializes the train flags, parameters, hyperparameters and number of starts. */
    public StringBuffer toXML() {
        StringBuffer xml = super.toXML();
        XMLParser.appendBooleanWithTags(xml, this.trainMean, "trainMean");
        XMLParser.appendBooleanWithTags(xml, this.trainPrecision, "trainPrecision");
        XMLParser.appendBooleanWithTags(xml, this.trainSkew, "trainSkew");
        XMLParser.appendDoubleWithTags(xml, this.par0, "par0");
        XMLParser.appendDoubleWithTags(xml, this.par1, "par1");
        XMLParser.appendDoubleWithTags(xml, this.par2, "par2");
        XMLParser.appendDoubleWithTags(xml, this.hyperMeanMean, "hyperMeanMean");
        XMLParser.appendDoubleWithTags(xml, this.hyperMeanStdev, "hyperMeanStdev");
        XMLParser.appendDoubleWithTags(xml, this.hyperPrec1, "hyperPrec1");
        XMLParser.appendDoubleWithTags(xml, this.hyperPrec2, "hyperPrec2");
        XMLParser.appendDoubleWithTags(xml, this.hyperSkewMean, "hyperSkewMean");
        XMLParser.appendDoubleWithTags(xml, this.hyperSkewStdev, "hyperSkewStdev");
        XMLParser.appendIntWithTags(xml, this.starts, "starts");
        XMLParser.addTags(xml, this.getInstanceName());
        return xml;
    }

    /** Returns R code that evaluates the normalized distribution over l = min:max. */
    protected String getRNotation(String distributionName) {
        return "l = " + this.min + ":" + this.max + "; " + distributionName + " = exp( -0.5 * (l -" + this.mu + ")^2/" + this.sigma + "^2 - " + this.logNorm + " ) * pnorm(" + this.par2 + "*(l-" + this.mu + ")/" + this.sigma + ");";
    }

    /**
     * Returns the log prior density of the free parameters, including the log-Jacobians of the
     * parameter transformations. Returns 0 if no prior is used ({@code ess <= 0}).
     */
    public double getLogPriorTerm() {
        if (!(this.ess > 0.0)) {
            return 0.0;
        }
        double val = this.priorC;
        if (this.trainMean) {
            double h = (this.mu - this.hyperMeanMean) / this.hyperMeanStdev;
            val -= 0.5 * h * h;
            double e = Math.exp(this.par0);
            val += Math.log(0.01 + e / ((1.0 + e) * (1.0 + e)));
        }
        if (this.trainPrecision) {
            val += this.hyperPrec1 * this.par1 - this.prec * this.hyperPrec2;
        }
        if (this.trainSkew) {
            double h = (this.par2 - this.hyperSkewMean) / this.hyperSkewStdev;
            val -= 0.5 * h * h;
        }
        return val;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to the free parameters to
     * {@code grad}, starting at {@code start}. Does nothing if {@code ess <= 0}.
     */
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        if (!(this.ess > 0.0)) {
            return;
        }
        int k = start;
        if (this.trainMean) {
            double h = Math.exp(this.par0);
            double h1 = 1.0 + h;
            grad[k++] += -(this.mu - this.hyperMeanMean) / (this.hyperMeanStdev * this.hyperMeanStdev) * this.partDerMu + h * (1.0 - h) / h1 / (0.01 * h1 * h1 + h);
        }
        if (this.trainPrecision) {
            grad[k++] += this.hyperPrec1 - this.prec * this.hyperPrec2;
        }
        if (this.trainSkew) {
            grad[k] += -(this.par2 - this.hyperSkewMean) / (this.hyperSkewStdev * this.hyperSkewStdev);
        }
    }

    public boolean isInitialized() {
        return true;
    }

    public boolean isNormalized() {
        return true;
    }

    /**
     * Sets the free parameters so that the distribution is as flat as possible: zero precision
     * (par1 = -infinity), par0 = 0 and no skew. Fixed parameters are left unchanged.
     */
    public void initializeUniformly() {
        this.setFreeParameters(0.0, Double.NEGATIVE_INFINITY, 0.0);
    }

    /**
     * Re-estimates the free parameters from weighted lengths. Only entries with positive weight
     * are used.
     *
     * <p>The mean is shrunk towards {@code hyperMeanMean} with a pseudo-count of 1. The log
     * precision is the Gamma-posterior-based estimate when more than one length has positive
     * weight, and -10 otherwise. The skew is reset to 0.
     *
     * @param length observed lengths
     * @param weight weight of each length (same size as {@code length})
     */
    public void adjust(int[] length, double[] weight) {
        double mean = this.hyperMeanMean;
        double weightSum = 0.0;
        int positive = 0;
        for (int i = 0; i < length.length; ++i) {
            if (!(weight[i] > 0.0)) continue;
            ++positive;
            mean += weight[i] * (double)length[i];
            weightSum += weight[i];
        }
        mean /= weightSum + 1.0;
        double logPrecision;
        if (positive > 1) {
            double squares = 0.0;
            for (int i = 0; i < length.length; ++i) {
                // skip the same entries as in the mean computation
                if (!(weight[i] > 0.0)) continue;
                double c = (double)length[i] - mean;
                squares += weight[i] * c * c;
            }
            logPrecision = Math.log((0.5 * weightSum + this.hyperPrec1) / (0.5 * squares + this.hyperPrec2));
        } else {
            logPrecision = -10.0;
        }
        double relativePosition = (mean - (double)this.min) / (double)this.delta;
        this.setFreeParameters(relativePositionToPar0(relativePosition), logPrecision, 0.0);
    }

    /** Changes the length range via the superclass and recomputes all cached quantities. */
    public void modify(int delta) {
        super.modify(delta);
        this.precomputePriorConstants();
        this.setParameters(this.par0, this.par1, this.par2);
    }

    /**
     * Sets the parameters by name, but only those that are trained. Fixed parameters keep their
     * current values.
     */
    private void setFreeParameters(double newPar0, double newPar1, double newPar2) {
        this.setParameters(this.trainMean ? newPar0 : this.par0, this.trainPrecision ? newPar1 : this.par1, this.trainSkew ? newPar2 : this.par2);
    }

    /**
     * Approximately inverts the location transformation by applying the logit function to the
     * relative position of the mean in [min, max]; the small linear term 0.01 * par0 is ignored. The
     * position is clamped to [LOGIT_EPS, 1 - LOGIT_EPS] so that boundary or out-of-range estimates
     * do not produce infinite parameters. NaN stays NaN.
     */
    private static double relativePositionToPar0(double relativePosition) {
        double p = Math.min(Math.max(relativePosition, LOGIT_EPS), 1.0 - LOGIT_EPS);
        return Math.log(p / (1.0 - p));
    }

    /**
     * Computes the parameter-independent part of the log prior: the normalization constants and
     * log(delta) from the Jacobian of the location transformation. It is 0 if no prior is used.
     */
    private void precomputePriorConstants() {
        this.priorC = 0.0;
        if (!(this.ess > 0.0)) {
            // hyperparameters are unset without a prior; computing with them would give NaN/-inf
            return;
        }
        if (this.trainMean) {
            this.priorC += -Math.log(Math.sqrt(Math.PI * 2) * this.hyperMeanStdev) + Math.log(this.delta);
        }
        if (this.trainPrecision) {
            this.priorC += this.hyperPrec1 * Math.log(this.hyperPrec2) - Gamma.logOfGamma(this.hyperPrec1);
        }
        if (this.trainSkew) {
            this.priorC -= Math.log(Math.sqrt(Math.PI * 2) * this.hyperSkewStdev);
        }
    }

    /** Returns the number of optimization starts recommended for this function. */
    public int getNumberOfRecommendedStarts() {
        return this.starts;
    }
}

