/*
 * Decompiled with CFR 0.152.
 */
package projects.mspd;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModelFactory;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.LinkedList;

public class MSPDModel
extends AbstractDifferentiableStatisticalModel {
    private MSPDNode root;
    private DifferentiableStatisticalModel[] leafModels;
    private double ess;
    private double t;
    private int maxDepth;

    /**
     * Creates an untrained model; the decision tree and its leaf models are
     * only built later by {@link #initializeFunction}.
     *
     * @param alphabets alphabet container handed to the super class
     * @param length    sequence length handed to the super class
     * @param ess       equivalent sample size; returned by {@link #getESS()} and
     *                  passed to the tree growing step
     * @param t         threshold a split score must exceed for a node to be split
     * @param maxDepth  depth limit checked while growing the tree
     */
    public MSPDModel(AlphabetContainer alphabets, int length, double ess, double t, int maxDepth) {
        super(alphabets, length);
        this.maxDepth = maxDepth;
        this.t = t;
        this.ess = ess;
    }

    /**
     * Returns a copy of this model with its own decision tree.
     *
     * <p>The copy starts from {@code super.clone()}. If a tree has been built,
     * the tree is cloned recursively and the flat leaf array of the copy is
     * rebuilt from the cloned tree.
     *
     * @return the cloned model
     * @throws CloneNotSupportedException if the super class or a tree node
     *         cannot be cloned
     */
    @Override
    public MSPDModel clone() throws CloneNotSupportedException {
        MSPDModel clone = (MSPDModel)super.clone();
        if (this.root != null) {
            clone.root = this.root.clone();
            // Collect the leaves from the CLONED tree. Taking them from this.root
            // would make the copy's leaf array alias the original's leaf models,
            // so that e.g. setParameters on the copy would modify the original.
            clone.leafModels = clone.root.getAllLeafModels();
        }
        return clone;
    }

    /**
     * Always returns 0, independent of the given parameter index.
     *
     * @param index parameter index (ignored)
     * @return 0
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return 0;
    }

    /**
     * Adds up the log normalization constants reported by all leaf models.
     *
     * @return the sum over all leaf models
     */
    @Override
    public double getLogNormalizationConstant() {
        double total = 0.0;
        for (DifferentiableStatisticalModel leaf : this.leafModels) {
            total += leaf.getLogNormalizationConstant();
        }
        return total;
    }

    /**
     * Delegates to the leaf model that owns the given global parameter index.
     *
     * <p>Parameters of the leaf models are laid out one after another in leaf
     * order; the global index is translated into the owning leaf's local index
     * by subtracting the parameter counts of the preceding leaves.
     *
     * @param parameterIndex global parameter index
     * @return the value returned by the owning leaf for its local index
     * @throws Exception if the leaf model throws
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        int local = parameterIndex;
        int leaf = 0;
        while (true) {
            int count = this.leafModels[leaf].getNumberOfParameters();
            if (local < count) {
                break;
            }
            local -= count;
            leaf++;
        }
        return this.leafModels[leaf].getLogPartialNormalizationConstant(local);
    }

    /**
     * Adds up the log prior terms of all leaf models.
     *
     * @return the sum over all leaf models
     */
    @Override
    public double getLogPriorTerm() {
        double total = 0.0;
        for (DifferentiableStatisticalModel leaf : this.leafModels) {
            total += leaf.getLogPriorTerm();
        }
        return total;
    }

    /**
     * Lets every leaf model add its prior gradient into {@code grad}.
     *
     * <p>Each leaf writes at its own offset: the first leaf at {@code start},
     * every further leaf directly behind the parameters of its predecessor.
     *
     * @param grad  gradient array the leaves add to
     * @param start offset of the first leaf's parameters within {@code grad}
     * @throws Exception if a leaf model throws
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        int offset = start;
        for (DifferentiableStatisticalModel leaf : this.leafModels) {
            leaf.addGradientOfLogPriorTerm(grad, offset);
            offset += leaf.getNumberOfParameters();
        }
    }

    /**
     * Returns the equivalent sample size given to the constructor.
     *
     * @return the stored ESS value
     */
    @Override
    public double getESS() {
        return this.ess;
    }

    /**
     * Builds the decision tree from the given data and collects its leaf models.
     *
     * <p>All sequences of all data sets are marked as used. Data sets without a
     * weight row get a weight of 1.0 per sequence. The tree is then grown from a
     * fresh root node and the leaf models are cached in tree order.
     *
     * @param index      index forwarded to the tree growing step (and from there
     *                   to the leaf models' initialisation)
     * @param freeParams not used by this method
     * @param data       one data set per class
     * @param weights    per-class sequence weights; may be {@code null}, and
     *                   individual rows may be {@code null}. The caller's array is
     *                   never modified.
     * @throws Exception if growing the tree or initialising a leaf model fails
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        // Work on a private (shallow) copy of the outer array so that filling in
        // default rows below never writes into the array owned by the caller.
        double[][] effectiveWeights = weights == null ? new double[data.length][] : weights.clone();
        boolean[][] use = new boolean[data.length][];
        for (int i = 0; i < use.length; i++) {
            int numberOfSequences = data[i].getNumberOfElements();
            use[i] = new boolean[numberOfSequences];
            Arrays.fill(use[i], true);
            if (effectiveWeights[i] == null) {
                effectiveWeights[i] = new double[numberOfSequences];
                Arrays.fill(effectiveWeights[i], 1.0);
            }
        }
        this.root = new MSPDNode();
        // The ESS is passed twice: once for split scoring, once as the ESS that
        // is distributed over the leaf models.
        this.root.growTree(data, effectiveWeights, use, this.ess, this.ess, this.t, index, 0, this.maxDepth);
        this.leafModels = this.root.getAllLeafModels();
    }

    /**
     * Random initialisation is not supported by this model.
     *
     * @param freeParams ignored
     * @throws UnsupportedOperationException always (a {@link RuntimeException},
     *         so existing handlers for {@code RuntimeException} still apply)
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Scores the sequence with the leaf model selected by the decision tree and
     * reports that leaf's partial derivatives under global parameter indices.
     *
     * <p>The leaf writes its derivatives directly into {@code partialDer}; its
     * local indices are collected separately and shifted by the number of
     * parameters of all preceding leaves before being added to {@code indices}.
     *
     * @param seq        sequence to score
     * @param start      start position within the sequence
     * @param indices    receives the global parameter indices
     * @param partialDer receives the partial derivatives
     * @return the log score of the selected leaf model
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        final int leaf = this.root.getLeafModel(seq, start);
        int offset = 0;
        for (int m = 0; m < leaf; m++) {
            offset += this.leafModels[m].getNumberOfParameters();
        }
        final IntList localIndices = new IntList();
        final double score = this.leafModels[leaf].getLogScoreAndPartialDerivation(seq, start, localIndices, partialDer);
        for (int p = 0; p < localIndices.length(); p++) {
            indices.add(localIndices.get(p) + offset);
        }
        return score;
    }

    /**
     * Counts the parameters of all leaf models.
     *
     * @return the total number of parameters over all leaves
     */
    @Override
    public int getNumberOfParameters() {
        int total = 0;
        for (DifferentiableStatisticalModel leaf : this.leafModels) {
            total += leaf.getNumberOfParameters();
        }
        return total;
    }

    /**
     * Concatenates the current parameter values of all leaf models in leaf order.
     *
     * @return all parameter values, leaf after leaf
     * @throws Exception if a leaf model throws
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        DoubleList collected = new DoubleList();
        for (DifferentiableStatisticalModel leaf : this.leafModels) {
            for (double value : leaf.getCurrentParameterValues()) {
                collected.add(value);
            }
        }
        return collected.toArray();
    }

    /**
     * Distributes the given parameter vector over the leaf models.
     *
     * <p>The first leaf reads from {@code start}; each following leaf reads
     * directly behind the parameters of its predecessor.
     *
     * @param params parameter vector
     * @param start  offset of the first leaf's parameters within {@code params}
     */
    @Override
    public void setParameters(double[] params, int start) {
        int offset = start;
        for (DifferentiableStatisticalModel leaf : this.leafModels) {
            leaf.setParameters(params, offset);
            offset += leaf.getNumberOfParameters();
        }
    }

    /**
     * Returns the simple name of the runtime class of this object.
     *
     * @return the simple class name
     */
    @Override
    public String getInstanceName() {
        return this.getClass().getSimpleName();
    }

    /**
     * Scores the sequence with the leaf model the decision tree selects for it.
     *
     * @param seq   sequence to score
     * @param start start position within the sequence
     * @return the log score of the selected leaf model
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        return this.leafModels[this.root.getLeafModel(seq, start)].getLogScoreFor(seq, start);
    }

    /**
     * Reports whether the tree has been built and every leaf model is initialised.
     *
     * @return {@code false} if no leaf models exist (none collected yet, or an
     *         empty array) or if any leaf reports that it is not initialised;
     *         {@code true} otherwise
     */
    @Override
    public boolean isInitialized() {
        // Guard the empty array as well: indexing leaf 0 would throw otherwise.
        if (this.leafModels == null || this.leafModels.length == 0) {
            return false;
        }
        // The model is only usable if all leaves are, not just the first one.
        for (DifferentiableStatisticalModel leaf : this.leafModels) {
            if (!leaf.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Renders the instance name, the number of leaves and the decision tree.
     *
     * @param nf number format; not used by this implementation
     * @return a textual description of the model
     */
    @Override
    public String toString(NumberFormat nf) {
        StringBuffer out = new StringBuffer();
        out.append(this.getInstanceName()).append(" with ").append(this.leafModels.length).append(" leafs\n");
        this.root.append(out);
        return out.toString();
    }

    /**
     * XML export is not implemented: this method always returns {@code null}.
     *
     * @return {@code null}
     */
    // NOTE(review): callers that expect an XML representation will receive null — confirm this stub is intended.
    @Override
    public StringBuffer toXML() {
        return null;
    }

    /**
     * XML import is not implemented: the method body is empty, so the given
     * XML is ignored and no state is restored.
     *
     * @param xml XML representation (ignored)
     * @throws NonParsableException declared but never thrown by this implementation
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
    }

    private class MSPDNode
    implements Cloneable {
        private MSPDNode left;
        private MSPDNode right;
        private boolean[] nodeDecision;
        private int decisionPosition;
        private DifferentiableStatisticalModel leafModel;

        /**
         * Copies this node together with its subtree.
         *
         * <p>Children, the leaf model and the decision array are cloned when
         * present; absent parts stay {@code null} in the copy.
         *
         * @return the cloned node
         * @throws CloneNotSupportedException if a child or the leaf model
         *         cannot be cloned
         */
        public MSPDNode clone() throws CloneNotSupportedException {
            final MSPDNode copy = (MSPDNode)super.clone();
            copy.left = this.left == null ? null : this.left.clone();
            copy.right = this.right == null ? null : this.right.clone();
            copy.leafModel = this.leafModel == null ? null : (DifferentiableStatisticalModel)this.leafModel.clone();
            copy.nodeDecision = this.nodeDecision == null ? null : (boolean[])this.nodeDecision.clone();
            return copy;
        }

        /**
         * Returns the left child of this node.
         *
         * @return the left child, or {@code null} if none is set
         */
        private MSPDNode getLeftNode() {
            return this.left;
        }

        /**
         * Returns the right child of this node.
         *
         * @return the right child, or {@code null} if none is set
         */
        private MSPDNode getRightNode() {
            return this.right;
        }

        /**
         * Returns the leaf model stored at this node.
         *
         * @return the leaf model, or {@code null} if this node has none
         */
        public DifferentiableStatisticalModel getLeafModel() {
            return this.leafModel;
        }

        /**
         * Determines the index of the leaf that the given sequence falls into,
         * counted within the subtree below this node (left leaves first).
         *
         * @param seq   sequence to classify
         * @param start start position; the decision position is added to it
         * @return 0 for a leaf node, otherwise the index found in the left
         *         subtree, or the number of left leaves plus the index found in
         *         the right subtree
         */
        public int getLeafModel(Sequence seq, int start) {
            if (this.leafModel != null) {
                return 0;
            }
            boolean goLeft = this.nodeDecision[seq.discreteVal(start + this.decisionPosition)];
            if (goLeft) {
                return this.left.getLeafModel(seq, start);
            }
            return this.left.getNumberOfLeafModels() + this.right.getLeafModel(seq, start);
        }

        /**
         * Counts the leaves in the subtree rooted at this node.
         *
         * @return 1 for a leaf node, otherwise the sum over both children
         */
        public int getNumberOfLeafModels() {
            return this.leafModel != null
                    ? 1
                    : this.left.getNumberOfLeafModels() + this.right.getNumberOfLeafModels();
        }

        /**
         * Decodes a decision index into one flag per alphabet symbol.
         *
         * <p>Starting with the highest position, a flag is set whenever the
         * remaining index is at least 2 to the power of that position, and that
         * power is then subtracted. For indices below 2^alen this is the binary
         * representation of the index.
         *
         * @param index decision index
         * @param alen  number of flags to produce
         * @return array of length {@code alen}; entry i is the flag for symbol i
         */
        private boolean[] getDecision(int index, int alen) {
            boolean[] flags = new boolean[alen];
            int remaining = index;
            for (int position = alen - 1; position >= 0; position--) {
                int weight = (int)Math.pow(2.0, position);
                if (remaining >= weight) {
                    flags[position] = true;
                    remaining -= weight;
                }
            }
            return flags;
        }

        /**
         * Replaces every entry of the array by its natural logarithm, in place.
         *
         * @param vals values to transform
         */
        private void log(double[] vals) {
            for (int pos = 0; pos < vals.length; pos++) {
                vals[pos] = Math.log(vals[pos]);
            }
        }

        /**
         * Scores a candidate split of the currently used sequences into a "left"
         * and a "right" group and returns the gain over not splitting.
         *
         * <p>For every class (data set) position-wise symbol counts are collected
         * three times: over the left group, over the right group and over all
         * used sequences. After normalisation and taking logs, every used
         * sequence is scored under each class, once with the counts of its own
         * group ({@code lls}) and once with the unsplit counts ({@code lla}).
         * The accumulated quantity per sequence is the score of its own class
         * minus {@code Normalisation.getLogSum} over all classes, scaled by
         * {@code wt / sc[class]}.
         *
         * <p>NOTE(review): {@code Normalisation.sumNormalisation} presumably
         * rescales an array in place to sum 1 and {@code Normalisation.getLogSum}
         * presumably returns the log of the sum of exponentials — confirm
         * against the Jstacs API.
         *
         * @param tempDecision per-symbol flag; {@code true} sends a sequence to the left group
         * @param tempPos      sequence position whose symbol is looked up in {@code tempDecision}
         *                     (used as an absolute position, no start offset is added)
         * @param data         one data set per class
         * @param weights      per-class sequence weights, used while counting
         * @param use          per-class mask of the sequences to consider
         * @param ess          pseudo count mass added to the counts
         * @return accumulated score with the split minus accumulated score without it
         */
        public double getCl(boolean[] tempDecision, int tempPos, DataSet[] data, double[][] weights, boolean[][] use, double ess) {
            // plr[0] / plr[1]: number of symbols sent left / right, normalised below.
            double[] plr = new double[2];
            int i = 0;
            while (i < tempDecision.length) {
                if (tempDecision[i]) {
                    plr[0] = plr[0] + 1.0;
                } else {
                    plr[1] = plr[1] + 1.0;
                }
                ++i;
            }
            Normalisation.sumNormalisation(plr);
            // Count tables indexed [class][position][symbol]; the symbol dimension
            // is taken from the alphabet length at position 0 for all positions.
            double[][][] leftCounts = new double[data.length][data[0].getElementLength()][(int)data[0].getAlphabetContainer().getAlphabetLengthAt(0)];
            double[][][] rightCounts = new double[data.length][data[0].getElementLength()][(int)data[0].getAlphabetContainer().getAlphabetLengthAt(0)];
            double[][][] allCounts = new double[data.length][data[0].getElementLength()][(int)data[0].getAlphabetContainer().getAlphabetLengthAt(0)];
            // Pseudo counts: left/right get the ESS share of their side divided by
            // the alphabet length at position j; allCounts gets the undivided ESS
            // in every cell.
            // NOTE(review): the asymmetric pseudo count for allCounts may or may not be intended — confirm.
            int i2 = 0;
            while (i2 < leftCounts.length) {
                int j = 0;
                while (j < leftCounts[i2].length) {
                    Arrays.fill(leftCounts[i2][j], ess * plr[0] / data[0].getAlphabetContainer().getAlphabetLengthAt(j));
                    Arrays.fill(rightCounts[i2][j], ess * plr[1] / data[0].getAlphabetContainer().getAlphabetLengthAt(j));
                    Arrays.fill(allCounts[i2][j], ess);
                    ++j;
                }
                ++i2;
            }
            // Per-class total weights: pc over all used sequences (starts at ess),
            // pcl / pcr over the left / right group (start at 0, no pseudo count).
            double[] pc = new double[data.length];
            double[] pcl = new double[data.length];
            double[] pcr = new double[data.length];
            Arrays.fill(pc, ess);
            // First pass: accumulate weighted symbol counts and class weights.
            int k = 0;
            while (k < data.length) {
                int i3 = 0;
                while (i3 < data[k].getNumberOfElements()) {
                    if (use[k][i3]) {
                        int l;
                        Sequence seq = data[k].getElementAt(i3);
                        if (tempDecision[seq.discreteVal(tempPos)]) {
                            l = 0;
                            while (l < seq.getLength()) {
                                double[] dArray = leftCounts[k][l];
                                int n = seq.discreteVal(l);
                                dArray[n] = dArray[n] + weights[k][i3];
                                ++l;
                            }
                            int n = k;
                            pcl[n] = pcl[n] + weights[k][i3];
                        } else {
                            l = 0;
                            while (l < seq.getLength()) {
                                double[] dArray = rightCounts[k][l];
                                int n = seq.discreteVal(l);
                                dArray[n] = dArray[n] + weights[k][i3];
                                ++l;
                            }
                            int n = k;
                            pcr[n] = pcr[n] + weights[k][i3];
                        }
                        l = 0;
                        while (l < seq.getLength()) {
                            double[] dArray = allCounts[k][l];
                            int n = seq.discreteVal(l);
                            dArray[n] = dArray[n] + weights[k][i3];
                            ++l;
                        }
                        int n = k;
                        pc[n] = pc[n] + weights[k][i3];
                    }
                    ++i3;
                }
                ++k;
            }
            // Keep the un-normalised class weights (sc) and their sum (wt) for the
            // per-sequence scaling factor used in the second pass.
            double[] sc = (double[])pc.clone();
            double wt = ToolBox.sum(sc);
            // Turn class weights and count tables into log values.
            Normalisation.sumNormalisation(pc);
            this.log(pc);
            Normalisation.sumNormalisation(pcl);
            this.log(pcl);
            Normalisation.sumNormalisation(pcr);
            this.log(pcr);
            int i4 = 0;
            while (i4 < leftCounts.length) {
                int l = 0;
                while (l < leftCounts[i4].length) {
                    Normalisation.sumNormalisation(leftCounts[i4][l]);
                    this.log(leftCounts[i4][l]);
                    Normalisation.sumNormalisation(rightCounts[i4][l]);
                    this.log(rightCounts[i4][l]);
                    Normalisation.sumNormalisation(allCounts[i4][l]);
                    this.log(allCounts[i4][l]);
                    ++l;
                }
                ++i4;
            }
            double oldCl = 0.0;
            double newCl = 0.0;
            // Per-class log scores of the current sequence: lls with the split
            // tables of the sequence's own group, lla with the unsplit tables.
            double[] lls = new double[leftCounts.length];
            double[] lla = new double[allCounts.length];
            // Second pass: score every used sequence under all classes. The
            // sequence weights are not referenced in this pass.
            int k2 = 0;
            while (k2 < data.length) {
                int i5 = 0;
                while (i5 < data[k2].getNumberOfElements()) {
                    if (use[k2][i5]) {
                        int l;
                        int m;
                        Sequence seq = data[k2].getElementAt(i5);
                        Arrays.fill(lls, 0.0);
                        Arrays.fill(lla, 0.0);
                        if (tempDecision[seq.discreteVal(tempPos)]) {
                            m = 0;
                            while (m < leftCounts.length) {
                                int n = m;
                                lls[n] = lls[n] + pcl[m];
                                l = 0;
                                while (l < seq.getLength()) {
                                    int n2 = m;
                                    lls[n2] = lls[n2] + leftCounts[m][l][seq.discreteVal(l)];
                                    ++l;
                                }
                                ++m;
                            }
                        } else {
                            m = 0;
                            while (m < rightCounts.length) {
                                int n = m;
                                lls[n] = lls[n] + pcr[m];
                                l = 0;
                                while (l < seq.getLength()) {
                                    int n3 = m;
                                    lls[n3] = lls[n3] + rightCounts[m][l][seq.discreteVal(l)];
                                    ++l;
                                }
                                ++m;
                            }
                        }
                        m = 0;
                        while (m < allCounts.length) {
                            int n = m;
                            lla[n] = lla[n] + pc[m];
                            l = 0;
                            while (l < seq.getLength()) {
                                int n4 = m;
                                lla[n4] = lla[n4] + allCounts[m][l][seq.discreteVal(l)];
                                ++l;
                            }
                            ++m;
                        }
                        // Score of the sequence's own class k2 relative to all
                        // classes, scaled by total weight over class weight.
                        newCl += (lls[k2] - Normalisation.getLogSum(lls)) / sc[k2] * wt;
                        oldCl += (lla[k2] - Normalisation.getLogSum(lla)) / sc[k2] * wt;
                    }
                    ++i5;
                }
                ++k2;
            }
            return newCl - oldCl;
        }

        /**
         * Alternative split score based on mutual-information terms between the
         * side of the split (left/right) and the symbol at each other position.
         *
         * <p>For every position other than {@code tempPos} the method computes
         * one such term per class from the class-specific counts
         * ({@code conMI}), weights these by the normalised class weights
         * ({@code pc}), subtracts the same term computed from the counts pooled
         * over all classes ({@code nonconMI}), clamps the difference at 0 and
         * adds it to the result.
         *
         * <p>This method is not called from {@code growTree}, which scores
         * candidate splits with {@code getCl} instead.
         *
         * <p>NOTE(review): {@code Normalisation.sumNormalisation} presumably
         * rescales an array in place to sum 1 — confirm against the Jstacs API.
         *
         * @param tempDecision per-symbol flag; {@code true} sends a sequence to the left group
         * @param tempPos      sequence position whose symbol is looked up in {@code tempDecision}
         * @param data         one data set per class
         * @param weights      per-class sequence weights
         * @param use          per-class mask of the sequences to consider
         * @param ess          pseudo count mass added to the counts
         * @return the sum of the clamped per-position differences
         */
        public double getEar(boolean[] tempDecision, int tempPos, DataSet[] data, double[][] weights, boolean[][] use, double ess) {
            // plr[0] / plr[1]: number of symbols sent left / right, normalised below.
            double[] plr = new double[2];
            int i = 0;
            while (i < tempDecision.length) {
                if (tempDecision[i]) {
                    plr[0] = plr[0] + 1.0;
                } else {
                    plr[1] = plr[1] + 1.0;
                }
                ++i;
            }
            Normalisation.sumNormalisation(plr);
            // Count tables indexed [class][position][symbol]; the symbol dimension
            // is taken from the alphabet length at position 0 for all positions.
            double[][][] leftCounts = new double[data.length][data[0].getElementLength()][(int)data[0].getAlphabetContainer().getAlphabetLengthAt(0)];
            double[][][] rightCounts = new double[data.length][data[0].getElementLength()][(int)data[0].getAlphabetContainer().getAlphabetLengthAt(0)];
            // Pseudo counts: each side gets its share of the ESS, divided by the
            // alphabet length at position j.
            int i2 = 0;
            while (i2 < leftCounts.length) {
                int j = 0;
                while (j < leftCounts[i2].length) {
                    Arrays.fill(leftCounts[i2][j], ess * plr[0] / data[0].getAlphabetContainer().getAlphabetLengthAt(j));
                    Arrays.fill(rightCounts[i2][j], ess * plr[1] / data[0].getAlphabetContainer().getAlphabetLengthAt(j));
                    ++j;
                }
                ++i2;
            }
            // Per-class total weight of the used sequences, starting at ess.
            double[] pc = new double[data.length];
            Arrays.fill(pc, ess);
            // Accumulate weighted symbol counts on the side each sequence falls on.
            int k = 0;
            while (k < data.length) {
                int i3 = 0;
                while (i3 < data[k].getNumberOfElements()) {
                    if (use[k][i3]) {
                        int l;
                        Sequence seq = data[k].getElementAt(i3);
                        if (tempDecision[seq.discreteVal(tempPos)]) {
                            l = 0;
                            while (l < seq.getLength()) {
                                double[] dArray = leftCounts[k][l];
                                int n = seq.discreteVal(l);
                                dArray[n] = dArray[n] + weights[k][i3];
                                ++l;
                            }
                        } else {
                            l = 0;
                            while (l < seq.getLength()) {
                                double[] dArray = rightCounts[k][l];
                                int n = seq.discreteVal(l);
                                dArray[n] = dArray[n] + weights[k][i3];
                                ++l;
                            }
                        }
                        int n = k;
                        pc[n] = pc[n] + weights[k][i3];
                    }
                    ++i3;
                }
                ++k;
            }
            Normalisation.sumNormalisation(pc);
            double ear = 0.0;
            int l = 0;
            while (l < leftCounts[0].length) {
                // The split position itself is skipped.
                if (l != tempPos) {
                    // conMI[k]: term computed from the counts of class k only;
                    // nonconMI: same term from the counts pooled over all classes
                    // (lc / rc collect the pooled left / right counts).
                    double[] conMI = new double[data.length];
                    double nonconMI = 0.0;
                    double[] lc = new double[leftCounts[0][l].length];
                    double[] rc = new double[rightCounts[0][l].length];
                    int k2 = 0;
                    while (k2 < data.length) {
                        double norml = ToolBox.sum(leftCounts[k2][l]);
                        double normr = ToolBox.sum(rightCounts[k2][l]);
                        double norm = norml + normr;
                        int a = 0;
                        while (a < leftCounts[k2][l].length) {
                            // Each summand has the form p(a,side) * log( p(a,side) / (p(a) * p(side)) ),
                            // with all probabilities estimated as count / norm.
                            if (leftCounts[k2][l][a] > 0.0) {
                                int n = k2;
                                conMI[n] = conMI[n] + leftCounts[k2][l][a] / norm * Math.log(leftCounts[k2][l][a] / norm / ((leftCounts[k2][l][a] + rightCounts[k2][l][a]) / norm * norml / norm));
                            }
                            if (rightCounts[k2][l][a] > 0.0) {
                                int n = k2;
                                conMI[n] = conMI[n] + rightCounts[k2][l][a] / norm * Math.log(rightCounts[k2][l][a] / norm / ((leftCounts[k2][l][a] + rightCounts[k2][l][a]) / norm * normr / norm));
                            }
                            int n = a;
                            lc[n] = lc[n] + leftCounts[k2][l][a];
                            int n2 = a;
                            rc[n2] = rc[n2] + rightCounts[k2][l][a];
                            ++a;
                        }
                        ++k2;
                    }
                    // Same computation on the pooled counts.
                    double norml = ToolBox.sum(lc);
                    double normr = ToolBox.sum(rc);
                    double norm = norml + normr;
                    int a = 0;
                    while (a < lc.length) {
                        if (lc[a] > 0.0) {
                            nonconMI += lc[a] / norm * Math.log(lc[a] / norm / ((lc[a] + rc[a]) / norm * norml / norm));
                        }
                        if (rc[a] > 0.0) {
                            nonconMI += rc[a] / norm * Math.log(rc[a] / norm / ((lc[a] + rc[a]) / norm * normr / norm));
                        }
                        ++a;
                    }
                    // Class-weighted average of the per-class terms minus the
                    // pooled term; negative differences are clamped to 0.
                    double myEar = 0.0;
                    int k3 = 0;
                    while (k3 < data.length) {
                        myEar += pc[k3] * conMI[k3];
                        ++k3;
                    }
                    if ((myEar -= nonconMI) < 0.0) {
                        myEar = 0.0;
                    }
                    ear += myEar;
                }
                ++l;
            }
            return ear;
        }

        /**
         * Recursively grows the subtree below this node on the used sequences.
         *
         * <p>Every position and every decision index from 1 up to (but excluding)
         * 2^alphabetLength - 1 is tried — i.e. all symbol subsets except the two
         * that would send every symbol to the same side — and scored with
         * {@code getCl}. If the best score exceeds {@code t}, every class keeps
         * positive weight on both sides and {@code depth < maxDepth}, the node
         * becomes an inner node and both children are grown. Otherwise the node
         * becomes a leaf: a PWM is created via
         * {@code DifferentiableStatisticalModelFactory.createPWM} and initialised
         * on the used sequences.
         *
         * <p>The method prints progress information to {@code System.out}.
         *
         * @param data         one data set per class
         * @param weights      per-class sequence weights
         * @param use          per-class mask of the sequences that reached this node
         * @param ess          ESS used for scoring splits; passed unchanged to the children
         * @param essLeafModel ESS for a leaf created at this node; children receive
         *                     it scaled by the fraction of symbols on their side
         * @param t            threshold the best split score must exceed
         * @param index        forwarded to the leaf model's {@code initializeFunction}
         * @param depth        depth of this node (the root is called with 0)
         * @param maxDepth     a node is only split while {@code depth < maxDepth}
         * @throws Exception if creating or initialising a leaf model fails
         */
        public void growTree(DataSet[] data, double[][] weights, boolean[][] use, double ess, double essLeafModel, double t, int index, int depth, int maxDepth) throws Exception {
            int i;
            // Best split found so far.
            double max = Double.NEGATIVE_INFINITY;
            boolean[] decision = null;
            int position = -1;
            int l = 0;
            while (l < data[0].getElementLength()) {
                int i2 = 1;
                while ((double)i2 < Math.pow(2.0, MSPDModel.this.getAlphabetContainer().getAlphabetLengthAt(0)) - 1.0) {
                    boolean[] tempDecision = this.getDecision(i2, (int)MSPDModel.this.getAlphabetContainer().getAlphabetLengthAt(0));
                    // Despite its name, this local holds the getCl score.
                    double ear = this.getCl(tempDecision, l, data, weights, use, ess);
                    if (ear > max) {
                        max = ear;
                        decision = tempDecision;
                        position = l;
                    }
                    ++i2;
                }
                ++l;
            }
            System.out.println("max: " + max + " " + Arrays.toString(decision) + " " + position);
            if (max > t) {
                // Tentatively turn this node into an inner node.
                this.nodeDecision = decision;
                this.decisionPosition = position;
                this.left = new MSPDNode();
                this.right = new MSPDNode();
                // Per-class total weight that ends up on each side.
                double[] numLeft = new double[data.length];
                double[] numRight = new double[data.length];
                // Mask for the left child: drop used sequences whose symbol is sent right.
                boolean[][] useLeft = (boolean[][])ArrayHandler.clone((Cloneable[])use);
                int k = 0;
                while (k < data.length) {
                    int i3 = 0;
                    while (i3 < data[k].getNumberOfElements()) {
                        if (useLeft[k][i3]) {
                            if (!this.nodeDecision[data[k].getElementAt(i3).discreteVal(this.decisionPosition)]) {
                                useLeft[k][i3] = false;
                            } else {
                                int n = k;
                                numLeft[n] = numLeft[n] + weights[k][i3];
                            }
                        }
                        ++i3;
                    }
                    ++k;
                }
                // Mask for the right child: drop used sequences whose symbol is sent left.
                boolean[][] useRight = (boolean[][])ArrayHandler.clone((Cloneable[])use);
                int k2 = 0;
                while (k2 < data.length) {
                    i = 0;
                    while (i < data[k2].getNumberOfElements()) {
                        if (useRight[k2][i]) {
                            if (this.nodeDecision[data[k2].getElementAt(i).discreteVal(this.decisionPosition)]) {
                                useRight[k2][i] = false;
                            } else {
                                int n = k2;
                                numRight[n] = numRight[n] + weights[k2][i];
                            }
                        }
                        ++i;
                    }
                    ++k2;
                }
                // Fraction of alphabet symbols sent left / right; used to split
                // the leaf ESS between the children.
                double[] plr = new double[2];
                i = 0;
                while (i < this.nodeDecision.length) {
                    if (this.nodeDecision[i]) {
                        plr[0] = plr[0] + 1.0;
                    } else {
                        plr[1] = plr[1] + 1.0;
                    }
                    ++i;
                }
                Normalisation.sumNormalisation(plr);
                // Only split if every class has positive weight on both sides and
                // the depth limit is not reached.
                if (ToolBox.min(numLeft) > 0.0 && ToolBox.min(numRight) > 0.0 && depth < maxDepth) {
                    System.out.println(String.valueOf(Arrays.toString(numLeft)) + " " + Arrays.toString(numRight));
                    this.left.growTree(data, weights, useLeft, ess, essLeafModel * plr[0], t, index, depth + 1, maxDepth);
                    this.right.growTree(data, weights, useRight, ess, essLeafModel * plr[1], t, index, depth + 1, maxDepth);
                    return;
                }
                // Split rejected: discard the children and fall through to the
                // leaf construction. nodeDecision and decisionPosition keep the
                // rejected values.
                this.left = null;
                this.right = null;
            }
            // Leaf node: create a PWM and initialise it on the used sequences only.
            this.leafModel = DifferentiableStatisticalModelFactory.createPWM(MSPDModel.this.getAlphabetContainer(), MSPDModel.this.length, essLeafModel);
            DataSet[] initData = new DataSet[data.length];
            double[][] initW = new double[weights.length][];
            int k = 0;
            while (k < data.length) {
                LinkedList<Sequence> list = new LinkedList<Sequence>();
                DoubleList w = new DoubleList();
                i = 0;
                while (i < data[k].getNumberOfElements()) {
                    if (use[k][i]) {
                        list.add(data[k].getElementAt(i));
                        w.add(weights[k][i]);
                    }
                    ++i;
                }
                initData[k] = new DataSet("", list);
                initW[k] = w.toArray();
                ++k;
            }
            // The leaf is always initialised with freeParams == false.
            this.leafModel.initializeFunction(index, false, initData, initW);
        }

        /**
         * Collects the leaf models of the subtree rooted at this node, left
         * subtree first.
         *
         * @return a one-element array for a leaf node, the concatenation of both
         *         children's results for an inner node, or {@code null} if the
         *         node has neither a leaf model nor two children
         */
        public DifferentiableStatisticalModel[] getAllLeafModels() {
            if (this.leafModel != null) {
                return new DifferentiableStatisticalModel[]{this.leafModel};
            }
            if (this.left == null || this.right == null) {
                return null;
            }
            DifferentiableStatisticalModel[] fromLeft = this.left.getAllLeafModels();
            DifferentiableStatisticalModel[] fromRight = this.right.getAllLeafModels();
            DifferentiableStatisticalModel[] combined = new DifferentiableStatisticalModel[fromLeft.length + fromRight.length];
            int filled = 0;
            for (DifferentiableStatisticalModel model : fromLeft) {
                combined[filled++] = model;
            }
            for (DifferentiableStatisticalModel model : fromRight) {
                combined[filled++] = model;
            }
            return combined;
        }

        /**
         * Appends a textual description of the subtree rooted at this node.
         *
         * <p>A leaf node contributes its leaf model's string representation; an
         * inner node contributes its decision position and decision array,
         * followed by the descriptions of the left and the right child.
         *
         * @param sb buffer to append to
         */
        public void append(StringBuffer sb) {
            if (this.leafModel != null) {
                sb.append("Leaf:\n");
                sb.append(this.leafModel.toString());
                return;
            }
            sb.append("deciding for position ").append(this.decisionPosition).append("\n");
            sb.append("decision: ").append(Arrays.toString(this.nodeDecision)).append("\n");
            sb.append("Left:\n");
            this.left.append(sb);
            sb.append("\n").append("Right:\n");
            this.right.append(sb);
            sb.append("\n");
        }
    }
}

