/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels;

import de.jstacs.InstantiableFromParameterSet;
import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.graphs.TopSort;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.IntSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.ParameterSetParser;
import de.jstacs.io.XMLParser;
import de.jstacs.parameters.InstanceParameterSet;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.BNDiffSMParameter;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.BNDiffSMParameterTree;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.BayesianNetworkDiffSMParameterSet;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.structureLearning.measures.InhomogeneousMarkov;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.structureLearning.measures.Measure;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;

/**
 * Differentiable statistical model whose dependency structure is a Bayesian network over
 * the positions of fixed-length sequences from a discrete alphabet.
 *
 * <p>The parents of each position are obtained from a {@link Measure} when the function is
 * initialized ({@link #createTrees(DataSet[], double[][])}). Each position is represented by
 * one {@link BNDiffSMParameterTree} holding the parameters of that position conditional on
 * its parents. The log score of a sequence is the sum of the values of the parameters the
 * trees select for it ({@link #getLogScoreFor(Sequence, int)}), i.e. parameter values live in
 * log space.
 */
public class BayesianNetworkDiffSM
extends AbstractDifferentiableStatisticalModel
implements InstantiableFromParameterSet {
    /** All parameters of all trees, collected tree by tree. */
    protected BNDiffSMParameter[] parameters;
    /** One parameter tree per sequence position. */
    protected BNDiffSMParameterTree[] trees;
    /** Set by {@link #initializeFunction}; restored by {@link #fromXML}. */
    protected boolean isTrained;
    /** Equivalent sample size; scaled into the pseudo count of every parameter. */
    protected double ess;
    /** Number of parameters exposed to the optimizer (all parameters unless {@link #freeParams}). */
    protected Integer numFreePars;
    /** {@code nums[k]} is the index into {@link #parameters} of the k-th exposed parameter. */
    protected int[] nums;
    /** Measure used to learn the network structure. */
    protected Measure structureMeasure;
    /** If true, parameters are initialized from weighted data counts instead of uniformly. */
    protected boolean plugInParameters;
    /** Topological order of the positions as returned by {@link TopSort}. */
    protected int[][] order;
    /** Cached log normalization constant; {@code null} means it must be recomputed. */
    protected Double logNormalizationConstant;
    /** {@code roots[i]} is the position reached from {@code i} by following first-parent links. */
    private int[] roots;
    /** If true, the parameter of the last symbol in each context is not exposed (not free). */
    private boolean freeParams;
    /** Cached constant part of the log prior; {@code null} means it must be recomputed. */
    private Double gammaNorm;
    /** Parameter set this instance was created from, or {@code null}. */
    private BayesianNetworkDiffSMParameterSet parameterSet;

    /**
     * Creates a new, uninitialized model.
     *
     * @param alphabet the alphabet; must be discrete
     * @param length the sequence length; must be positive
     * @param ess the equivalent sample size
     * @param plugInParameters whether to initialize parameters from data counts
     * @param structureMeasure the measure used to learn the network structure
     * @throws Exception if the alphabet is not discrete or the length is not positive
     */
    public BayesianNetworkDiffSM(AlphabetContainer alphabet, int length, double ess, boolean plugInParameters, Measure structureMeasure) throws Exception {
        super(alphabet, length);
        if (!alphabet.isDiscrete()) {
            throw new Exception("Only defined on discrete alphabets.");
        }
        if (length <= 0) {
            throw new Exception("Inconsistent length (" + length + ").");
        }
        this.isTrained = false;
        this.ess = ess;
        this.plugInParameters = plugInParameters;
        this.structureMeasure = structureMeasure;
        this.logNormalizationConstant = null;
    }

    /**
     * Creates a new, uninitialized model from a parameter set and remembers that set, so that
     * {@link #getCurrentParameterSet()} returns it.
     *
     * @param parameters the parameter set
     * @throws ParameterSetParser.NotInstantiableException declared for the parameter-set contract
     * @throws Exception if the alphabet is not discrete or the length is not positive
     */
    public BayesianNetworkDiffSM(BayesianNetworkDiffSMParameterSet parameters) throws ParameterSetParser.NotInstantiableException, Exception {
        this(parameters.getAlphabetContainer(), parameters.getLength(), parameters.getEss(), parameters.getPlugInParameters(), parameters.getMeasure());
        this.parameterSet = parameters;
    }

    /**
     * Restores a model from its XML representation (see {@link #toXML()}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public BayesianNetworkDiffSM(StringBuffer xml) throws NonParsableException {
        super(xml);
        // the caches are not stored in XML and must be recomputed on demand
        this.logNormalizationConstant = null;
        this.gammaNorm = null;
    }

    /**
     * Returns a deep copy of the trees and of the parameters they hold. {@code order} and
     * {@code roots} are shared; they are only ever replaced, never modified in place.
     */
    @Override
    public BayesianNetworkDiffSM clone() throws CloneNotSupportedException {
        BayesianNetworkDiffSM clone = (BayesianNetworkDiffSM)super.clone();
        if (this.trees != null) {
            clone.trees = new BNDiffSMParameterTree[this.trees.length];
            for (int i = 0; i < this.trees.length; ++i) {
                clone.trees[i] = this.trees[i].clone();
            }
            // Re-collect the parameters from the cloned trees (instead of copying
            // this.parameters) so that clone.parameters refers to the clone's tree entries.
            clone.parameters = BayesianNetworkDiffSM.collectParameters(clone.trees);
            clone.nums = this.nums == null ? null : (int[])this.nums.clone();
        } else {
            clone.trees = null;
            clone.nums = null;
            clone.parameters = null;
        }
        clone.structureMeasure = this.structureMeasure == null ? null : this.structureMeasure.clone();
        clone.logNormalizationConstant = null;
        return clone;
    }

    /**
     * Collects the parameters of all given trees into one array: tree by tree, and within a
     * tree in the order returned by {@code linearizeParameters()}.
     *
     * @param trees the trees to collect parameters from
     * @return the parameters of all trees
     */
    private static BNDiffSMParameter[] collectParameters(BNDiffSMParameterTree[] trees) {
        LinkedList[] perTree = new LinkedList[trees.length];
        int total = 0;
        for (int i = 0; i < trees.length; ++i) {
            perTree[i] = trees[i].linearizeParameters();
            total += perTree[i].size();
        }
        BNDiffSMParameter[] result = new BNDiffSMParameter[total];
        int next = 0;
        for (LinkedList list : perTree) {
            for (Object par : list) {
                result[next++] = (BNDiffSMParameter)par;
            }
        }
        return result;
    }

    /**
     * Returns the log partial normalization constant of the exposed parameter
     * {@code parameterIndex}. It is the parameter's own partial normalizer plus the forward
     * normalizers of all root trees other than the root of the parameter's position.
     *
     * @param parameterIndex index in {@code [0, getNumberOfParameters())}
     * @return the log partial normalization constant
     * @throws Exception if the index is out of range
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        if (this.logNormalizationConstant == null) {
            this.precomputeNormalization();
        }
        if (parameterIndex < 0 || parameterIndex >= this.nums.length) {
            throw new Exception("BNDiffSMParameter index out of bounds");
        }
        BNDiffSMParameter p = this.parameters[this.nums[parameterIndex]];
        int ownRoot = this.roots[p.getPosition()];
        double val = p.getLogPartialNormalizer();
        for (int i = 0; i < this.trees.length; ++i) {
            // only trees without parents are roots; the parameter's own root is already
            // accounted for by its partial normalizer
            if (this.trees[i].getNumberOfParents() > 0 || i == ownRoot) {
                continue;
            }
            val += this.trees[i].forward(this.trees);
        }
        return val;
    }

    /**
     * Computes, for a parent structure, the "first children" of every position and the
     * "first parent" of every position.
     *
     * <p>Parent {@code p} of position {@code i} qualifies when all other parents of {@code i}
     * are contained in {@code parents[p]} (see {@link #testInclude}). The last entry of
     * {@code parents[i]} is not treated as a parent (presumably it is {@code i} itself — TODO
     * confirm against the {@link Measure} contract).
     *
     * @param parents {@code parents[i]} lists the parents of position {@code i}, followed by one
     *        trailing entry
     * @return an array of length {@code parents.length + 1}; entry {@code i < parents.length}
     *         lists the first children of {@code i}, and the last entry maps each position to its
     *         first parent, or -1 if it has none
     * @throws Exception if a position with parents has no qualifying first parent
     */
    @SuppressWarnings("unchecked")
    private int[][] getFirstChildrenAndFirstParents(int[][] parents) throws Exception {
        int n = parents.length;
        LinkedList<Integer>[] firstChildren = new LinkedList[n];
        for (int i = 0; i < n; ++i) {
            firstChildren[i] = new LinkedList<Integer>();
        }
        int[][] erg = new int[n + 1][];
        erg[n] = new int[n];
        Arrays.fill(erg[n], -1);
        for (int i = 0; i < n; ++i) {
            // positions without parents trivially pass the check
            boolean found = parents[i].length < 2;
            for (int j = 0; j < parents[i].length - 1; ++j) {
                int parent = parents[i][j];
                if (this.testInclude(parents[i], parents[parent])) {
                    firstChildren[parent].add(i);
                    erg[n][i] = parent;
                    found = true;
                }
            }
            if (!found) {
                throw new Exception("Structure is no moral graph!");
            }
        }
        for (int i = 0; i < n; ++i) {
            erg[i] = new int[firstChildren[i].size()];
            for (int j = 0; j < erg[i].length; ++j) {
                erg[i][j] = firstChildren[i].poll();
            }
        }
        return erg;
    }

    /**
     * Tests whether every entry of {@code parentsOfChild} except its last one occurs somewhere
     * in {@code parentsOfParent}.
     *
     * @param parentsOfChild the parent list of the child (the last entry is ignored)
     * @param parentsOfParent the complete parent list of the candidate parent
     * @return {@code true} if all checked entries are contained
     */
    private boolean testInclude(int[] parentsOfChild, int[] parentsOfParent) {
        for (int i = 0; i < parentsOfChild.length - 1; ++i) {
            boolean found = false;
            for (int j = 0; j < parentsOfParent.length; ++j) {
                if (parentsOfChild[i] == parentsOfParent[j]) {
                    found = true;
                    break;
                }
            }
            if (!found) {
                return false;
            }
        }
        return true;
    }

    /**
     * Learns the network structure and initializes all parameters.
     *
     * <p>{@code data[index]} is used as foreground. All other data sets together form the
     * background passed to the structure measure. Parameters are then set either from counts
     * on {@code data[index]} ({@link #plugInParameters}) or uniformly: 0 if
     * {@code freeParams}, else {@code -log(|alphabet at the parameter's position|)}.
     *
     * @param index index of the data set this model is initialized for
     * @param freeParams whether only free parameters are exposed
     * @param data the data sets; individual entries may be {@code null}
     * @param weights the sequence weights per data set, or {@code null} for weight 1 everywhere
     * @throws Exception if {@code data[index]} has the wrong length, or structure learning fails
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        if (data[index] != null && data[index].getElementLength() != this.length) {
            throw new Exception("Data has wrong length.");
        }
        this.freeParams = freeParams;
        if (weights == null) {
            weights = new double[data.length][];
            for (int i = 0; i < data.length; ++i) {
                // null data sets (e.g. from initializeFunctionRandomly) simply get no weights
                if (data[i] != null) {
                    weights[i] = new double[data[i].getNumberOfElements()];
                    Arrays.fill(weights[i], 1.0);
                }
            }
        }
        DataSet[] data2;
        double[][] weights2;
        if (data.length == 2) {
            // foreground is always data[index]; the other data set is the background
            int other = 1 - index;
            data2 = new DataSet[]{data[index], data[other]};
            weights2 = new double[][]{weights[index], weights[other]};
        } else {
            data2 = new DataSet[2];
            weights2 = new double[2][];
            data2[0] = data[index];
            weights2[0] = weights[index];
            boolean[] in = new boolean[data.length];
            Arrays.fill(in, true);
            in[index] = false;
            data2[1] = DataSet.union(data, in);
            if (data2[1] != null) {
                weights2[1] = new double[data2[1].getNumberOfElements()];
                int off = 0;
                for (int i = 0; i < weights.length; ++i) {
                    // assumes DataSet.union contributes nothing for null data sets — TODO confirm
                    if (!in[i] || data[i] == null) {
                        continue;
                    }
                    System.arraycopy(weights[i], 0, weights2[1], off, weights[i].length);
                    off += weights[i].length;
                }
            } else {
                weights2[1] = null;
            }
        }
        this.createTrees(data2, weights2);
        if (this.plugInParameters) {
            this.setPlugInParameters(index, freeParams, data, weights);
        } else {
            for (BNDiffSMParameter par : this.parameters) {
                par.setValue(freeParams ? 0.0 : -Math.log(this.alphabets.getAlphabetLengthAt(par.position)));
            }
        }
        this.isTrained = true;
        this.logNormalizationConstant = null;
    }

    /**
     * Learns the parent structure with {@link #structureMeasure} and builds one parameter tree
     * per position. It also builds the parameter array, the free-parameter index
     * {@link #nums} and the {@link #roots} array.
     *
     * @param data2 {@code data2[0]} foreground, {@code data2[1]} background (either may be null)
     * @param weights2 weights matching {@code data2}
     * @throws Exception if structure learning or tree construction fails
     */
    protected void createTrees(DataSet[] data2, double[][] weights2) throws Exception {
        int[][] parents = this.structureMeasure.getParents(data2[0], data2[1], weights2[0], weights2[1], this.getLength());
        this.order = TopSort.getTopologicalOrder(parents);
        int[][] firstChildrenAndFirstParents = this.getFirstChildrenAndFirstParents(parents);
        int freeCount = 0;
        int numPars = 0;
        int[] numContextsPos = new int[this.getLength()];
        this.trees = new BNDiffSMParameterTree[this.getLength()];
        for (int pos = 0; pos < parents.length; ++pos) {
            // the trailing entry of parents[pos] is not a parent
            int numParents = parents[pos].length - 1;
            int numContexts = 1;
            for (int j = 0; j < numParents; ++j) {
                numContexts = (int)((double)numContexts * this.alphabets.getAlphabetLengthAt(parents[pos][j]));
            }
            numContextsPos[pos] = numContexts;
            double alphabetLength = this.getAlphabetContainer().getAlphabetLengthAt(pos);
            freeCount += numContexts * ((int)alphabetLength - 1);
            numPars = (int)((double)numPars + (double)numContexts * alphabetLength);
            // context positions are the parents in reverse order
            int[] contextPoss = new int[numParents];
            for (int j = 0; j < numParents; ++j) {
                contextPoss[j] = parents[pos][numParents - 1 - j];
            }
            this.trees[pos] = new BNDiffSMParameterTree(pos, contextPoss, this.getAlphabetContainer(), firstChildrenAndFirstParents[parents.length][pos], firstChildrenAndFirstParents[pos]);
        }
        this.parameters = new BNDiffSMParameter[numPars];
        this.numFreePars = this.freeParams ? freeCount : numPars;
        this.nums = new int[this.numFreePars];
        int curr = 0;
        int free = 0;
        int[][][][] contexts = new int[this.getLength()][][][];
        for (int i = 0; i < parents.length; ++i) {
            contexts[i] = new int[numContextsPos[i]][parents[i].length - 1][];
            this.fillContexts(0, contexts[i], 0, parents[i]);
            double alphabetLength = this.getAlphabetContainer().getAlphabetLengthAt(i);
            for (int j = 0; j < contexts[i].length; ++j) {
                // all: number of full assignments of the context positions;
                // act: product over context entries of (entry length - 1)
                int all = 1;
                int act = 1;
                for (int k = 0; k < contexts[i][j].length; ++k) {
                    all = (int)((double)all * this.getAlphabetContainer().getAlphabetLengthAt(contexts[i][j][k][0]));
                    act *= contexts[i][j][k].length - 1;
                }
                double pseudoCount = (double)act * this.ess / ((double)all * alphabetLength);
                for (byte a = 0; (double)a < alphabetLength; ++a) {
                    // with freeParams, the last symbol of every context is not a free parameter
                    boolean isFree = (double)a < alphabetLength - 1.0 || !this.freeParams;
                    this.parameters[curr] = isFree
                            ? new BNDiffSMParameter(free, a, i, contexts[i][j], pseudoCount, true)
                            : new BNDiffSMParameter(-1, a, i, contexts[i][j], pseudoCount, false);
                    this.trees[i].setParameterFor(a, contexts[i][j], this.parameters[curr]);
                    if (this.parameters[curr].isFree()) {
                        this.nums[free++] = curr;
                    }
                    ++curr;
                }
            }
        }
        this.roots = new int[this.trees.length];
        for (int i = 0; i < this.trees.length; ++i) {
            int fp = i;
            while (this.trees[fp].getFirstParent() != -1) {
                fp = this.trees[fp].getFirstParent();
            }
            this.roots[i] = fp;
        }
        this.logNormalizationConstant = null;
        this.gammaNorm = null;
    }

    /**
     * Initializes the parameters from weighted counts on {@code data[index]}. Each tree adds
     * the counts and normalizes its plug-in parameters. If {@code freeParameters} is set,
     * each tree also calls {@code divideByUnfree()}. Does nothing if {@code data[index]} is
     * {@code null}.
     *
     * @param index index of the data set to count on
     * @param freeParameters whether only free parameters are exposed
     * @param data the data sets
     * @param weights the sequence weights per data set
     */
    protected void setPlugInParameters(int index, boolean freeParameters, DataSet[] data, double[][] weights) {
        if (data[index] == null) {
            return;
        }
        for (int i = 0; i < data[index].getNumberOfElements(); ++i) {
            for (int j = 0; j < this.trees.length; ++j) {
                this.trees[j].addCount(data[index].getElementAt(i), 0, weights[index][i]);
            }
        }
        for (int i = 0; i < this.trees.length; ++i) {
            this.trees[i].normalizePlugInParameters();
            if (freeParameters) {
                this.trees[i].divideByUnfree();
            }
        }
    }

    /**
     * Recursively enumerates all assignments of the parent positions into {@code contexts}.
     * Depth {@code d} corresponds to parent {@code parents[parents.length - d - 2]}, and each
     * filled entry is {@code {parentPosition, symbol}}.
     *
     * @param offset first row of {@code contexts} to fill
     * @param contexts the context table to fill
     * @param depth current recursion depth
     * @param parents the parent list (last entry excluded)
     * @return the row after the last row filled
     */
    private int fillContexts(int offset, int[][][] contexts, int depth, int[] parents) {
        int tempOffset = offset;
        if (depth < parents.length - 1) {
            int parentPos = parents[parents.length - depth - 2];
            int i = 0;
            while ((double)i < this.alphabets.getAlphabetLengthAt(parentPos)) {
                offset = tempOffset;
                tempOffset = this.fillContexts(offset, contexts, depth + 1, parents);
                for (int j = offset; j < tempOffset; ++j) {
                    contexts[j][depth] = new int[]{parentPos, i};
                }
                ++i;
            }
            return tempOffset;
        }
        return offset + 1;
    }

    /**
     * Restores the model from the XML produced by {@link #toXML()}. Empty arrays in the XML
     * stand for {@code null} fields.
     */
    @Override
    protected void fromXML(StringBuffer source) throws NonParsableException {
        source = XMLParser.extractForTag(source, "bayesianNetworkSF");
        this.alphabets = (AlphabetContainer)XMLParser.extractObjectForTags(source, "alphabets");
        this.length = XMLParser.extractObjectForTags(source, "length", Integer.TYPE);
        this.trees = XMLParser.extractObjectForTags(source, "trees", BNDiffSMParameterTree[].class);
        if (this.trees.length == 0) {
            this.trees = null;
            this.parameters = null;
        } else {
            for (int i = 0; i < this.trees.length; ++i) {
                this.trees[i].setAlphabet(this.alphabets);
            }
            this.parameters = BayesianNetworkDiffSM.collectParameters(this.trees);
        }
        this.isTrained = XMLParser.extractObjectForTags(source, "isTrained", Boolean.TYPE);
        this.ess = XMLParser.extractObjectForTags(source, "ess", Double.TYPE);
        this.numFreePars = XMLParser.extractObjectForTags(source, "numFreePars", Integer.class);
        this.nums = XMLParser.extractObjectForTags(source, "nums", int[].class);
        if (this.nums.length == 0) {
            this.nums = null;
        }
        this.structureMeasure = XMLParser.extractObjectForTags(source, "structureMeasure", Measure.class);
        this.order = XMLParser.extractObjectForTags(source, "order", int[][].class);
        if (this.order.length == 0) {
            this.order = null;
        }
        this.plugInParameters = XMLParser.extractObjectForTags(source, "plugInParameters", Boolean.TYPE);
        this.roots = XMLParser.extractObjectForTags(source, "roots", int[].class);
        if (this.roots.length == 0) {
            this.roots = null;
        }
        this.freeParams = XMLParser.extractObjectForTags(source, "freeParams", Boolean.TYPE);
    }

    /**
     * Returns the string form of every tree, one per line, or a short notice if the model is
     * not initialized. Computes the normalization first if it is not cached.
     */
    @Override
    public String toString() {
        if (this.trees == null) {
            return this.getClass().getSimpleName() + " of length " + this.length + ": not initialized";
        }
        if (this.logNormalizationConstant == null) {
            // NOTE(review): presumably the trees print values that depend on their normalizers
            this.precomputeNormalization();
        }
        StringBuffer buf = new StringBuffer();
        for (int i = 0; i < this.trees.length; ++i) {
            buf.append(this.trees[i].toString());
            buf.append("\n");
        }
        return buf.toString();
    }

    /** Returns the instance name of the structure measure. */
    @Override
    public String getInstanceName() {
        return this.structureMeasure.getInstanceName();
    }

    /** Returns the sum of the values of the parameters selected by each tree for {@code seq}. */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        double prob = 0.0;
        for (int i = 0; i < this.trees.length; ++i) {
            prob += this.trees[i].getParameterFor(seq, start).getValue();
        }
        return prob;
    }

    /**
     * Returns the log score and records a partial derivative of 1 for every free parameter
     * used, because the score is a plain sum of parameter values.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        double logScore = 0.0;
        for (int i = 0; i < this.trees.length; ++i) {
            BNDiffSMParameter par = this.trees[i].getParameterFor(seq, start);
            if (par.isFree()) {
                indices.add(par.getIndex());
                partialDer.add(1.0);
            }
            logScore += par.getValue();
        }
        return logScore;
    }

    /** Returns the (cached) log normalization constant, computing it if necessary. */
    @Override
    public double getLogNormalizationConstant() throws RuntimeException {
        if (this.logNormalizationConstant == null) {
            this.precomputeNormalization();
        }
        return this.logNormalizationConstant;
    }

    /** Returns the number of exposed parameters, or -1 if the model is not initialized. */
    @Override
    public int getNumberOfParameters() {
        if (this.nums == null) {
            return -1;
        }
        return this.nums.length;
    }

    /** Sets the exposed parameters from {@code params[start ..]} and invalidates the cache. */
    @Override
    public void setParameters(double[] params, int start) {
        for (int i = 0; i < this.nums.length; ++i) {
            this.parameters[this.nums[i]].setValue(params[i + start]);
        }
        this.logNormalizationConstant = null;
    }

    /**
     * Recomputes the log normalization constant.
     *
     * <p>All tree normalizers are invalidated. Backward passes are then started from every
     * leaf tree, and the forward values of all root trees (trees without parents) are summed.
     */
    protected void precomputeNormalization() {
        for (int i = 0; i < this.trees.length; ++i) {
            this.trees[i].invalidateNormalizers();
        }
        boolean[] notRoot = new boolean[this.trees.length];
        for (int i = 0; i < this.trees.length; ++i) {
            if (this.trees[i].getNumberOfParents() > 0) {
                notRoot[i] = true;
            }
            if (this.trees[i].isLeaf()) {
                this.trees[i].backward(this.trees, this.order);
            }
        }
        double val = 0.0;
        for (int i = 0; i < notRoot.length; ++i) {
            if (!notRoot[i]) {
                val += this.trees[i].forward(this.trees);
            }
        }
        this.logNormalizationConstant = val;
    }

    /** Returns a copy of the values of the exposed parameters, in index order. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] pars = new double[this.nums.length];
        for (int i = 0; i < pars.length; ++i) {
            pars[i] = this.parameters[this.nums[i]].getValue();
        }
        return pars;
    }

    /**
     * Serializes the model. {@code null} arrays are written as empty arrays; {@link #fromXML}
     * maps them back to {@code null}.
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer source = new StringBuffer();
        XMLParser.appendObjectWithTags(source, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(source, this.length, "length");
        XMLParser.appendObjectWithTags(source, this.trees, "trees");
        XMLParser.appendObjectWithTags(source, this.isTrained, "isTrained");
        XMLParser.appendObjectWithTags(source, this.ess, "ess");
        XMLParser.appendObjectWithTags(source, this.numFreePars, "numFreePars");
        XMLParser.appendObjectWithTags(source, this.nums == null ? new int[0] : this.nums, "nums");
        XMLParser.appendObjectWithTags(source, this.structureMeasure, "structureMeasure");
        XMLParser.appendObjectWithTags(source, this.plugInParameters, "plugInParameters");
        XMLParser.appendObjectWithTags(source, this.order == null ? new int[0][0] : this.order, "order");
        XMLParser.appendObjectWithTags(source, this.roots == null ? new int[0] : this.roots, "roots");
        XMLParser.appendObjectWithTags(source, this.freeParams, "freeParams");
        XMLParser.addTags(source, "bayesianNetworkSF");
        return source;
    }

    /**
     * Returns the log prior term: the sum over exposed parameters of value times pseudo
     * count, plus the cached constant {@link #gammaNorm}.
     */
    @Override
    public double getLogPriorTerm() {
        if (this.gammaNorm == null) {
            this.computeGammaNorm();
        }
        double val = 0.0;
        for (int i = 0; i < this.nums.length; ++i) {
            BNDiffSMParameter par = this.parameters[this.nums[i]];
            val += par.getValue() * par.getPseudoCount();
        }
        return val + this.gammaNorm;
    }

    /** Recomputes {@link #gammaNorm} as the sum of the trees' {@code computeGammaNorm()}. */
    private void computeGammaNorm() {
        double sum = 0.0;
        for (int i = 0; i < this.trees.length; ++i) {
            sum += this.trees[i].computeGammaNorm();
        }
        this.gammaNorm = sum;
    }

    /**
     * Adds the gradient of the log prior term to {@code grad[start ..]}. The prior is linear
     * in the parameter values, so each component is the parameter's pseudo count.
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) {
        for (int i = 0; i < this.nums.length; ++i) {
            grad[i + start] += this.parameters[this.nums[i]].getPseudoCount();
        }
    }

    /** Returns the equivalent sample size. */
    @Override
    public double getESS() {
        return this.ess;
    }

    /**
     * Returns the sequence position of the exposed parameter {@code index}.
     *
     * @param index index of an exposed parameter
     * @return its position
     */
    public int getPositionForParameter(int index) {
        return this.parameters[this.nums[index]].getPosition();
    }

    /**
     * For an inhomogeneous Markov structure, returns one value per possible start position of
     * {@code kmer}. The value for start {@code i} is {@code getProbFor(kmer)} of the tree at
     * position {@code i + kmer.getLength() - 1}.
     *
     * <p>Side effect: the parameters of every tree are normalized in place
     * ({@code normalizeParameters()}), and the normalization is recomputed.
     *
     * @param kmer the k-mer; must not be longer than the model
     * @return the per-start-position values
     * @throws Exception if the structure measure is not {@link InhomogeneousMarkov}
     * @throws IllegalArgumentException if the k-mer is longer than the model
     */
    public double[] getPositionDependentKMerProb(Sequence kmer) throws Exception {
        if (!(this.structureMeasure instanceof InhomogeneousMarkov)) {
            throw new Exception("Only implemented for IMMs");
        }
        int k = kmer.getLength();
        if (k > this.trees.length) {
            throw new IllegalArgumentException("k-mer of length " + k + " is longer than the model (" + this.trees.length + ").");
        }
        this.precomputeNormalization();
        for (int i = 0; i < this.trees.length; ++i) {
            this.trees[i].normalizeParameters();
        }
        this.precomputeNormalization();
        double[] prof = new double[this.trees.length - k + 1];
        for (int i = 0; i < prof.length; ++i) {
            prof[i] = this.trees[i + k - 1].getProbFor(kmer);
        }
        return prof;
    }

    /**
     * Returns the product of the alphabet sizes at the parameter's position and at each of
     * its context positions.
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        BNDiffSMParameter par = this.parameters[this.nums[index]];
        int size = (int)this.alphabets.getAlphabetLengthAt(par.getPosition());
        int[][] cont = par.context;
        for (int i = 0; i < cont.length; ++i) {
            size = (int)((double)size * this.alphabets.getAlphabetLengthAt(cont[i][0]));
        }
        return size;
    }

    /**
     * Builds the structure without data (uniform start values) and then lets each tree draw
     * random parameters using {@link #ess}. Only supported for {@link InhomogeneousMarkov}.
     *
     * @param freeParams whether only free parameters are exposed
     * @throws Exception if the structure measure is not {@link InhomogeneousMarkov}
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        if (!(this.structureMeasure instanceof InhomogeneousMarkov)) {
            throw new Exception("Not implemented");
        }
        this.freeParams = freeParams;
        boolean temp = this.plugInParameters;
        this.plugInParameters = false;
        this.initializeFunction(0, freeParams, new DataSet[]{null, null}, new double[][]{null, null});
        this.plugInParameters = temp;
        for (int i = 0; i < this.trees.length; ++i) {
            this.trees[i].initializeRandomly(this.ess);
        }
    }

    /** Returns whether the trees have been built. */
    @Override
    public boolean isInitialized() {
        return this.trees != null;
    }

    /**
     * For an inhomogeneous Markov structure, returns one row per position, filled by the
     * tree's {@code insertProbs}.
     *
     * @return the position weight matrix
     * @throws Exception if the structure measure is not {@link InhomogeneousMarkov}
     */
    public double[][] getPWM() throws Exception {
        if (!(this.structureMeasure instanceof InhomogeneousMarkov)) {
            throw new Exception("Only implemented for IMMs");
        }
        if (this.logNormalizationConstant == null) {
            this.precomputeNormalization();
        }
        double[][] pwm = new double[this.trees.length][];
        for (int i = 0; i < this.trees.length; ++i) {
            pwm[i] = new double[(int)this.alphabets.getAlphabetLengthAt(i)];
            this.trees[i].insertProbs(pwm[i]);
        }
        return pwm;
    }

    /**
     * Returns the parameter set this model was created from, or a new one describing the
     * current alphabet, length, ESS, plug-in flag and structure measure.
     *
     * @return the parameter set
     * @throws Exception if the parameter set cannot be created
     */
    public InstanceParameterSet getCurrentParameterSet() throws Exception {
        if (this.parameterSet != null) {
            return this.parameterSet;
        }
        return new BayesianNetworkDiffSMParameterSet(this.getAlphabetContainer(), this.getLength(), this.ess, this.plugInParameters, this.structureMeasure);
    }

    /**
     * Samples {@code numberOfSequences} sequences of the model's length. For each sequence,
     * {@code emitSymbol} is called on the trees in the order given by {@code order[j][0]}.
     *
     * @param numberOfSequences number of sequences to sample
     * @param seqLength must be empty; the model has a fixed length
     * @return the sampled data set
     * @throws NotTrainedException if the model is not initialized
     * @throws IllegalArgumentException if sequence lengths are given
     */
    @Override
    public DataSet emitDataSet(int numberOfSequences, int ... seqLength) throws NotTrainedException, Exception {
        if (seqLength != null && seqLength.length > 0) {
            throw new IllegalArgumentException("You cannot set sequence lengths for a model of a fixed length of " + this.length + ".");
        }
        if (!this.isInitialized()) {
            throw new NotTrainedException();
        }
        this.precomputeNormalization();
        Sequence[] seqs = new IntSequence[numberOfSequences];
        for (int i = 0; i < numberOfSequences; ++i) {
            int[] content = new int[this.length];
            for (int j = 0; j < this.order.length; ++j) {
                this.trees[this.order[j][0]].emitSymbol(content);
            }
            seqs[i] = new IntSequence(this.alphabets, content);
        }
        return new DataSet("", seqs);
    }
}

