/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif;

import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.CDFOfNormal;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.DurationDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.RandomNumberGenerator;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/**
 * Duration model over the lengths {@code min, ..., max} with a skew-normal-like shape.
 *
 * <p>For a length {@code l} the unnormalized log-score is
 * {@code -0.5 * z^2 + log(Phi(par2 * z))} with {@code z = (l - mu) / sigma}, where {@code Phi}
 * is the standard normal CDF (via {@link CDFOfNormal#getLogCDF(double)}). The scores are
 * normalized by the log-sum over all {@code delta + 1} lengths.
 *
 * <p>The three (internal) parameters are
 * <ul>
 *   <li>{@code par0}: location, mapped to
 *       {@code mu = min + delta * (0.01 * par0 + exp(par0) / (1 + exp(par0)))},</li>
 *   <li>{@code par1}: log-precision, i.e. {@code precision = exp(par1)} and
 *       {@code sigma = 1 / sqrt(precision)},</li>
 *   <li>{@code par2}: skew.</li>
 * </ul>
 * Each parameter can be fixed or trained individually. Only trained parameters are exposed by
 * {@link #getNumberOfParameters()}, {@link #getCurrentParameterValues()} and
 * {@link #setParameters(double[], int)}, always in the order mean, precision, skew.
 */
public class SkewNormalLikeDurationDiffSM
extends DurationDiffSM {
    /** Whether {@code par0} (the location) is a free parameter. */
    private boolean trainMean;
    /** Whether {@code par1} (the log-precision) is a free parameter. */
    private boolean trainPrecision;
    /** Whether {@code par2} (the skew) is a free parameter. */
    private boolean trainSkew;
    /** Location parameter; see {@link #mu}. */
    private double par0;
    /** Log-precision parameter; see {@link #prec}. */
    private double par1;
    /** Skew parameter. */
    private double par2;
    /** Mean of the normal prior on {@link #mu} (length scale). */
    private double hyperMeanMean;
    /** Standard deviation of the normal prior on {@link #mu}. */
    private double hyperMeanStdev;
    /** Shape hyperparameter of the gamma prior on the precision. */
    private double hyperPrec1;
    /** Rate hyperparameter of the gamma prior on the precision. */
    private double hyperPrec2;
    /** Mean of the normal prior on {@link #par2}. */
    private double hyperSkewMean;
    /** Standard deviation of the normal prior on {@link #par2}. */
    private double hyperSkewStdev;
    /** Precomputed parameter-independent part of the log-prior. */
    private double priorC;
    /** Derivative of {@link #mu} with respect to {@link #par0}. */
    private double partDerMu;
    /** Location on the length scale derived from {@link #par0}. */
    private double mu;
    /** {@code 1 / sqrt(prec)}. */
    private double sigma;
    /** {@code exp(par1)}. */
    private double prec;
    /** Log of the normalization constant (log-sum of {@link #logScore}). */
    private double logNorm;
    /** Partial derivatives of {@link #logNorm} with respect to par0, par1 and par2. */
    private double partDerLogNormPar0;
    private double partDerLogNormPar1;
    private double partDerLogNormPar2;
    /** Unnormalized log-scores, indexed by {@code length - min}. */
    private double[] logScore;
    /** {@code phi(par2 * z) / Phi(par2 * z)} per length, indexed by {@code length - min}. */
    private double[] densDivCDF;
    /** Recommended number of optimization starts. */
    private int starts;
    private static final RandomNumberGenerator randNumGen = new RandomNumberGenerator();
    /** Width of the interval from which the location parameter is drawn at random initialization. */
    private static final double V = 6.72;
    private static final double ONE_DIV_BY_SQRT_OF_2_TIMES_PI = 1.0 / Math.sqrt(Math.PI * 2);

    /**
     * Creates the model, sets the initial parameters and the training flags.
     *
     * @throws IllegalArgumentException if {@code starts < 1}
     */
    private SkewNormalLikeDurationDiffSM(int min, int max, double ess, boolean trainMean, double param0, boolean trainPrecision, double param1, boolean trainSkew, double param2, int starts) {
        super(min, max, ess);
        this.setParameters(param0, param1, param2);
        this.trainMean = trainMean;
        this.trainPrecision = trainPrecision;
        this.trainSkew = trainSkew;
        if (starts < 1) {
            throw new IllegalArgumentException("The number of starts has to be positive.");
        }
        this.starts = starts;
    }

    /**
     * Creates a model with fixed (untrained) parameters, no prior ({@code ess = 0}) and one start.
     *
     * @param min the minimal length
     * @param max the maximal length
     * @param param0 the location parameter {@code par0}
     * @param param1 the log-precision parameter {@code par1}
     * @param param2 the skew parameter {@code par2}
     */
    public SkewNormalLikeDurationDiffSM(int min, int max, double param0, double param1, double param2) {
        this(min, max, 0.0, false, param0, false, param1, false, param2, 1);
    }

    /**
     * Creates a model with priors on the trained parameters.
     *
     * <p>The equivalent sample size is set to {@code 2 * hyperPrec1}. Only if it is positive are
     * the hyperparameters validated and stored; otherwise no prior is used. The initial
     * parameters are {@code par0 = 0}, {@code sigma = (max - min) / 4} and {@code par2 = 0}.
     *
     * @throws IllegalArgumentException if a prior is used and one of its scale/shape
     *         hyperparameters is not positive, or if {@code starts < 1}
     */
    public SkewNormalLikeDurationDiffSM(int min, int max, boolean trainMean, double hyperMeanMean, double hyperMeanSigma, boolean trainPrecision, double hyperPrec1, double hyperPrec2, boolean trainSkew, double hyperSkewMean, double hyperSkewStdev, int starts) {
        this(min, max, 2.0 * hyperPrec1, trainMean, 0.0, trainPrecision, -2.0 * Math.log((double)(max - min) / 4.0), trainSkew, 0.0, starts);
        if (this.ess > 0.0) {
            if (hyperMeanSigma <= 0.0) {
                throw new IllegalArgumentException("The prior of the mean parameter is wrongly specified. (check the second parameter: " + hyperMeanSigma + ")");
            }
            this.hyperMeanMean = hyperMeanMean;
            this.hyperMeanStdev = hyperMeanSigma;
            if (hyperPrec1 <= 0.0 || hyperPrec2 <= 0.0) {
                throw new IllegalArgumentException("The prior of the precision parameter is wrongly specified. (" + hyperPrec1 + ", " + hyperPrec2 + ")");
            }
            this.hyperPrec1 = hyperPrec1;
            this.hyperPrec2 = hyperPrec2;
            if (hyperSkewStdev <= 0.0) {
                throw new IllegalArgumentException("The prior of the skew parameter is wrongly specified. (check the second parameter: " + hyperSkewStdev + ")");
            }
            this.hyperSkewMean = hyperSkewMean;
            this.hyperSkewStdev = hyperSkewStdev;
            this.precomputePriorConstants();
        }
    }

    /**
     * Restores a model from its XML representation.
     *
     * @param source the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public SkewNormalLikeDurationDiffSM(StringBuffer source) throws NonParsableException {
        super(source);
    }

    /** Returns a deep copy; the precomputed score tables are copied, not shared. */
    @Override
    public SkewNormalLikeDurationDiffSM clone() throws CloneNotSupportedException {
        SkewNormalLikeDurationDiffSM clone = (SkewNormalLikeDurationDiffSM)super.clone();
        if (this.logScore != null) {
            clone.logScore = (double[])this.logScore.clone();
            clone.densDivCDF = (double[])this.densDivCDF.clone();
        }
        return clone;
    }

    /**
     * Initializes the parameters from the observed lengths in {@code data[index]}.
     *
     * <p>The weights are accumulated per distinct length and passed to
     * {@link #adjust(int[], double[])}. If the data use another alphabet container, a warning is
     * printed and the parameters are initialized randomly instead.
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        if (!data[index].getAlphabetContainer().checkConsistency(this.alphabets)) {
            System.out.println("Warning: Try to initialize " + this.getClass().getName() + " with data over another AlphabetContainer.");
            this.initializeFunctionRandomly(freeParams);
            return;
        }
        DiscreteAlphabet abc = (DiscreteAlphabet)this.alphabets.getAlphabetAt(0);
        boolean weighted = weights != null && weights[index] != null;
        // accumulated weight per observed length
        Hashtable<Integer, double[]> hash = new Hashtable<Integer, double[]>();
        double w = 1.0;
        int n = data[index].getNumberOfElements();
        for (int i = 0; i < n; ++i) {
            // the symbol of the single position encodes the length as a number
            Integer val = Integer.valueOf(abc.getSymbolAt(data[index].getElementAt(i).discreteVal(0)));
            if (weighted) {
                w = weights[index][i];
            }
            double[] acc = hash.get(val);
            if (acc == null) {
                hash.put(val, new double[]{w});
            } else {
                acc[0] += w;
            }
        }
        int[] len = new int[hash.size()];
        double[] lenWeights = new double[len.length];
        int i = 0;
        for (Map.Entry<Integer, double[]> current : hash.entrySet()) {
            len[i] = current.getKey();
            lenWeights[i++] = current.getValue()[0];
        }
        this.adjust(len, lenWeights);
    }

    /**
     * Sets the trained parameters from weighted lengths; fixed parameters are left unchanged.
     *
     * <p>The mean is estimated as a weighted average in which {@code hyperMeanMean} counts as one
     * additional pseudo-observation. The precision combines the weighted squared deviations with
     * {@code hyperPrec1}/{@code hyperPrec2}. The skew is reset to 0. Weights that are not
     * positive are ignored.
     *
     * @param length the observed lengths
     * @param weight the weight of each length
     * @throws IllegalArgumentException if a weight is NaN
     */
    @Override
    public void adjust(int[] length, double[] weight) {
        // NOTE(review): the pseudo-observation hyperMeanMean is also added when ess == 0 (then it
        // is 0), which biases the estimate towards 0 for small weight sums -- confirm intent.
        double mean = this.hyperMeanMean;
        double sum = 0.0;
        for (int i = 0; i < length.length; ++i) {
            if (weight[i] > 0.0) {
                mean += weight[i] * (double)length[i];
                sum += weight[i];
            } else if (Double.isNaN(weight[i])) {
                throw new IllegalArgumentException("Check the " + i + "-th weight (for length " + length[i] + ")");
            }
        }
        mean /= sum + 1.0;
        double squaredDeviation = 0.0;
        for (int i = 0; i < length.length; ++i) {
            // same weight filter as for the mean so both estimates use the same observations
            if (weight[i] > 0.0) {
                double c = (double)length[i] - mean;
                squaredDeviation += weight[i] * c * c;
            }
        }
        double logPrecision = Math.log((0.5 * sum + this.hyperPrec1) / (0.5 * squaredDeviation + this.hyperPrec2));
        // approximate inverse of the location mapping (ignores the 0.01 * par0 term)
        double relativeMean = (mean - (double)this.min) / (double)this.delta;
        double location = Math.log(relativeMean / (1.0 - relativeMean));
        this.setTrainedParameters(location, logPrecision, 0.0);
    }

    /**
     * Draws random values for all trained parameters.
     *
     * <ul>
     *   <li>location: uniform in {@code [-V/2, V/2)},</li>
     *   <li>precision: drawn via {@code randNumGen.nextGamma(alpha, 1 / beta)} (prior
     *       hyperparameters if {@code ess > 0}, otherwise {@code alpha = 1, beta = 10000}),
     *       rejecting precisions above {@code 16 / delta^2} (standard deviation below
     *       {@code delta / 4}),</li>
     *   <li>skew: drawn from the normal prior {@code N(hyperSkewMean, hyperSkewStdev^2)}.</li>
     * </ul>
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        double[] init = new double[this.getNumberOfParameters()];
        int i = 0;
        if (this.trainMean) {
            init[i++] = V * r.nextDouble() - 0.5 * V;
        }
        if (this.trainPrecision) {
            double alpha = this.ess > 0.0 ? this.hyperPrec1 : 1.0;
            double beta = this.ess > 0.0 ? this.hyperPrec2 : 10000.0;
            // computed in double to avoid int overflow for large delta
            double maxPrecision = 16.0 / ((double)this.delta * this.delta);
            double drawn;
            do {
                drawn = randNumGen.nextGamma(alpha, 1.0 / beta);
            } while (drawn > maxPrecision);
            init[i++] = Math.log(drawn);
        }
        if (this.trainSkew) {
            // scale by the standard deviation (not the variance) to match the prior in getLogPriorTerm
            init[i] = this.hyperSkewMean + r.nextGaussian() * this.hyperSkewStdev;
        }
        this.setParameters(init, 0);
    }

    /**
     * Restores the state from XML. A missing {@code hyperMeanStdev} tag defaults to 250.
     */
    @Override
    protected void fromXML(StringBuffer rep) throws NonParsableException {
        StringBuffer xml = XMLParser.extractForTag(rep, this.getInstanceName());
        super.fromXML(xml);
        this.trainMean = XMLParser.extractObjectForTags(xml, "trainMean", Boolean.TYPE);
        this.trainPrecision = XMLParser.extractObjectForTags(xml, "trainPrecision", Boolean.TYPE);
        this.trainSkew = XMLParser.extractObjectForTags(xml, "trainSkew", Boolean.TYPE);
        this.setParameters(XMLParser.extractObjectForTags(xml, "par0", Double.TYPE), XMLParser.extractObjectForTags(xml, "par1", Double.TYPE), XMLParser.extractObjectForTags(xml, "par2", Double.TYPE));
        this.hyperMeanMean = XMLParser.extractObjectForTags(xml, "hyperMeanMean", Double.TYPE);
        try {
            this.hyperMeanStdev = XMLParser.extractObjectForTags(xml, "hyperMeanStdev", Double.TYPE);
        }
        catch (NonParsableException n) {
            // older representations lack this tag
            this.hyperMeanStdev = 250.0;
        }
        this.hyperPrec1 = XMLParser.extractObjectForTags(xml, "hyperPrec1", Double.TYPE);
        this.hyperPrec2 = XMLParser.extractObjectForTags(xml, "hyperPrec2", Double.TYPE);
        this.hyperSkewMean = XMLParser.extractObjectForTags(xml, "hyperSkewMean", Double.TYPE);
        this.hyperSkewStdev = XMLParser.extractObjectForTags(xml, "hyperSkewStdev", Double.TYPE);
        this.precomputePriorConstants();
        this.starts = XMLParser.extractObjectForTags(xml, "starts", Integer.TYPE);
    }

    /** Returns the simple class name, used as the XML tag. */
    @Override
    public String getInstanceName() {
        return this.getClass().getSimpleName();
    }

    /** Returns the trained parameters in the order mean, precision, skew. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] current = new double[this.getNumberOfParameters()];
        int i = 0;
        if (this.trainMean) {
            current[i++] = this.par0;
        }
        if (this.trainPrecision) {
            current[i++] = this.par1;
        }
        if (this.trainSkew) {
            current[i] = this.par2;
        }
        return current;
    }

    /**
     * Returns the normalized log-score of length {@code values[0]}; {@code values[0] - min} must
     * be a valid index into the precomputed table (0 to delta).
     */
    @Override
    public double getLogScore(int ... values) {
        return this.logScore[values[0] - this.min] - this.logNorm;
    }

    /**
     * Returns the normalized log-score of length {@code values[0]} and appends the partial
     * derivatives with respect to the trained parameters (indices counted from 0 in the order
     * mean, precision, skew).
     */
    @Override
    public double getLogScoreAndPartialDerivation(IntList indices, DoubleList partialDer, int ... values) {
        int idx = values[0] - this.min;
        double ratio = this.densDivCDF[idx];
        double z = ((double)values[0] - this.mu) / this.sigma;
        // derivative of the unnormalized log-score w.r.t. z (up to the factor dz/dparam)
        double h = z + ratio * -this.par2;
        int i = 0;
        if (this.trainMean) {
            indices.add(i++);
            partialDer.add(-this.partDerLogNormPar0 + this.partDerMu / this.sigma * h);
        }
        if (this.trainPrecision) {
            indices.add(i++);
            partialDer.add(-this.partDerLogNormPar1 + 0.5 * -z * h);
        }
        if (this.trainSkew) {
            indices.add(i++);
            partialDer.add(-this.partDerLogNormPar2 + ratio * z);
        }
        return this.logScore[idx] - this.logNorm;
    }

    /** Returns the number of trained parameters (0 to 3). */
    @Override
    public int getNumberOfParameters() {
        return (this.trainMean ? 1 : 0) + (this.trainPrecision ? 1 : 0) + (this.trainSkew ? 1 : 0);
    }

    /**
     * Sets the trained parameters from {@code params}, starting at {@code start}, in the order
     * mean, precision, skew; untrained parameters keep their values.
     */
    @Override
    public void setParameters(double[] params, int start) {
        this.setParameters(this.trainMean ? params[start] : this.par0, this.trainPrecision ? params[start + (this.trainMean ? 1 : 0)] : this.par1, this.trainSkew ? params[start + (this.trainMean ? 1 : 0) + (this.trainPrecision ? 1 : 0)] : this.par2);
    }

    /**
     * Sets the values of the trained parameters and leaves fixed parameters unchanged. The
     * arguments use the fixed order mean, precision, skew, independently of which are trained.
     */
    private void setTrainedParameters(double location, double logPrecision, double skew) {
        this.setParameters(this.trainMean ? location : this.par0, this.trainPrecision ? logPrecision : this.par1, this.trainSkew ? skew : this.par2);
    }

    /**
     * Sets all three internal parameters, regardless of the training flags, and recomputes the
     * score table, the normalization constant and its partial derivatives.
     *
     * @param par0 the location parameter
     * @param par1 the log-precision parameter
     * @param par2 the skew parameter
     */
    public void setParameters(double par0, double par1, double par2) {
        this.par0 = par0;
        double expPar0 = Math.exp(par0);
        this.mu = (double)this.min + (double)this.delta * (0.01 * par0 + expPar0 / (1.0 + expPar0));
        this.partDerMu = (double)this.delta * (0.01 + expPar0 / ((1.0 + expPar0) * (1.0 + expPar0)));
        this.par1 = par1;
        this.prec = Math.exp(par1);
        this.sigma = 1.0 / Math.sqrt(this.prec);
        this.par2 = par2;
        if (this.logScore == null || this.logScore.length != this.delta + 1) {
            this.logScore = new double[this.delta + 1];
            this.densDivCDF = new double[this.delta + 1];
        }
        // unnormalized sums of exp(logScore) times the derivative of logScore (without common factors)
        double sumPar0 = 0.0;
        double sumPar1 = 0.0;
        double sumPar2 = 0.0;
        for (int j = 0; j < this.logScore.length; ++j) {
            double diff = (double)(this.min + j) - this.mu;
            double z = diff / this.sigma;
            double zSq = this.prec * diff * diff;
            // standard normal density at par2 * z (using zSq as z^2)
            double skewDensity = ONE_DIV_BY_SQRT_OF_2_TIMES_PI * Math.exp(-0.5 * par2 * par2 * zSq);
            double logCdf = par2 == 0.0 ? Math.log(0.5) : CDFOfNormal.getLogCDF(par2 * z);
            double cdf = Math.exp(logCdf);
            this.logScore[j] = -0.5 * zSq + logCdf;
            double gauss = Math.exp(-0.5 * zSq);
            sumPar2 += gauss * skewDensity * z;
            double muTerm = gauss * (z * cdf + skewDensity * -par2);
            sumPar0 += muTerm;
            sumPar1 -= muTerm * z;
            this.densDivCDF[j] = ONE_DIV_BY_SQRT_OF_2_TIMES_PI * Math.exp(-0.5 * par2 * par2 * zSq - logCdf);
        }
        this.logNorm = Normalisation.getLogSum(this.logScore);
        double norm = Math.exp(this.logNorm);
        this.partDerLogNormPar0 = sumPar0 * this.partDerMu / this.sigma / norm;
        this.partDerLogNormPar1 = sumPar1 * 0.5 / norm;
        this.partDerLogNormPar2 = sumPar2 / norm;
    }

    /** Appends training flags, parameters, hyperparameters and starts to the superclass XML. */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = super.toXML();
        XMLParser.appendObjectWithTags(xml, this.trainMean, "trainMean");
        XMLParser.appendObjectWithTags(xml, this.trainPrecision, "trainPrecision");
        XMLParser.appendObjectWithTags(xml, this.trainSkew, "trainSkew");
        XMLParser.appendObjectWithTags(xml, this.par0, "par0");
        XMLParser.appendObjectWithTags(xml, this.par1, "par1");
        XMLParser.appendObjectWithTags(xml, this.par2, "par2");
        XMLParser.appendObjectWithTags(xml, this.hyperMeanMean, "hyperMeanMean");
        XMLParser.appendObjectWithTags(xml, this.hyperMeanStdev, "hyperMeanStdev");
        XMLParser.appendObjectWithTags(xml, this.hyperPrec1, "hyperPrec1");
        XMLParser.appendObjectWithTags(xml, this.hyperPrec2, "hyperPrec2");
        XMLParser.appendObjectWithTags(xml, this.hyperSkewMean, "hyperSkewMean");
        XMLParser.appendObjectWithTags(xml, this.hyperSkewStdev, "hyperSkewStdev");
        XMLParser.appendObjectWithTags(xml, this.starts, "starts");
        XMLParser.addTags(xml, this.getInstanceName());
        return xml;
    }

    /** Returns an R expression that evaluates the distribution over {@code l = min:max}. */
    @Override
    protected String getRNotation(String distributionName) {
        return "l = " + this.min + ":" + this.max + "; " + distributionName + " = exp( -0.5 * (l -" + this.mu + ")^2/" + this.sigma + "^2 - " + this.logNorm + " ) * pnorm(" + this.par2 + "*(l-" + this.mu + ")/" + this.sigma + ");";
    }

    /**
     * Returns the log-prior of the trained parameters. If {@code ess > 0}, it consists of:
     * <ul>
     *   <li>a normal prior on {@code mu}, including the log-Jacobian of the location mapping,</li>
     *   <li>a gamma prior on the precision, including the log-Jacobian of {@code exp(par1)},</li>
     *   <li>a normal prior on the skew.</li>
     * </ul>
     * Otherwise it returns the constant {@link #priorC}.
     */
    @Override
    public double getLogPriorTerm() {
        double val = this.priorC;
        if (this.ess > 0.0) {
            double h;
            if (this.trainMean) {
                h = (this.mu - this.hyperMeanMean) / this.hyperMeanStdev;
                val -= 0.5 * h * h;
                h = Math.exp(this.par0);
                val += Math.log(0.01 + h / ((1.0 + h) * (1.0 + h)));
            }
            if (this.trainPrecision) {
                val += this.hyperPrec1 * this.par1 - this.prec * this.hyperPrec2;
            }
            if (this.trainSkew) {
                h = (this.par2 - this.hyperSkewMean) / this.hyperSkewStdev;
                val -= 0.5 * h * h;
            }
        }
        return val;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to the trained parameters to
     * {@code grad}, starting at {@code start}; does nothing if {@code ess <= 0}.
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        if (this.ess > 0.0) {
            if (this.trainMean) {
                double h = Math.exp(this.par0);
                double h1 = 1.0 + h;
                grad[start++] += -(this.mu - this.hyperMeanMean) / (this.hyperMeanStdev * this.hyperMeanStdev) * this.partDerMu + h * (1.0 - h) / h1 / (0.01 * h1 * h1 + h);
            }
            if (this.trainPrecision) {
                grad[start++] += this.hyperPrec1 - this.prec * this.hyperPrec2;
            }
            if (this.trainSkew) {
                grad[start++] += -(this.par2 - this.hyperSkewMean) / (this.hyperSkewStdev * this.hyperSkewStdev);
            }
        }
    }

    /** Always {@code true}: parameters are set on construction. */
    @Override
    public boolean isInitialized() {
        return true;
    }

    /** Always {@code true}: scores are normalized by {@link #logNorm}. */
    @Override
    public boolean isNormalized() {
        return true;
    }

    /**
     * Sets the trained parameters to location 0, log-precision -infinity and skew 0; fixed
     * parameters keep their values. If the precision is trained, it becomes 0, so every length
     * gets the same score.
     */
    @Override
    public void initializeUniformly() {
        this.setTrainedParameters(0.0, Double.NEGATIVE_INFINITY, 0.0);
    }

    /** Shifts the length range via the superclass and recomputes all derived values. */
    @Override
    public void modify(int delta) {
        super.modify(delta);
        this.precomputePriorConstants();
        this.setParameters(this.par0, this.par1, this.par2);
    }

    /**
     * Computes the parameter-independent part of the log-prior. Without a prior
     * ({@code ess <= 0}) it is 0, because the hyperparameters may be unset (0) and would
     * otherwise produce infinite or NaN constants.
     */
    private void precomputePriorConstants() {
        this.priorC = 0.0;
        if (!(this.ess > 0.0)) {
            return;
        }
        if (this.trainMean) {
            // includes log(delta), the constant factor of the location mapping's Jacobian
            this.priorC += -Math.log(Math.sqrt(Math.PI * 2) * this.hyperMeanStdev) + Math.log(this.delta);
        }
        if (this.trainPrecision) {
            this.priorC += this.hyperPrec1 * Math.log(this.hyperPrec2) - Gamma.logOfGamma(this.hyperPrec1);
        }
        if (this.trainSkew) {
            this.priorC -= Math.log(Math.sqrt(Math.PI * 2) * this.hyperSkewStdev);
        }
    }

    /** Returns the number of starts given at construction (or read from XML). */
    @Override
    public int getNumberOfRecommendedStarts() {
        return this.starts;
    }
}

