/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.mixture;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.motifDiscovery.MotifDiscoverer;
import de.jstacs.motifDiscovery.MutableMotifDiscoverer;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.VariableLengthDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.AbstractMixtureDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.util.Arrays;
import javax.naming.OperationNotSupportedException;

/**
 * A mixture of {@link DifferentiableStatisticalModel} components.
 * <p>
 * For a sequence, each component {@code i} contributes
 * {@code logHiddenPotential[i] + (log score of component i)}. These contributions are
 * combined by a log-sum (see {@link #getLogScoreAndPartialDerivation}).
 * <p>
 * Components that are {@link MotifDiscoverer}s expose their motifs through this class
 * under consecutive global motif indices, in component order.
 */
public class MixtureDiffSM
extends AbstractMixtureDiffSM
implements MutableMotifDiscoverer {
    /**
     * Cumulative motif counts, with {@code function.length + 1} entries.
     * <p>
     * {@code motifsRef[i]} is the global index of the first motif of component {@code i}.
     * {@code motifsRef[function.length]} is the total number of motifs.
     * Components that are not {@link MotifDiscoverer}s contribute no motifs.
     */
    private int[] motifsRef;

    /**
     * Creates a mixture of the given components.
     *
     * @param starts passed on to the super-class constructor
     * @param plugIn passed on to the super-class constructor
     * @param component the mixture components. Each must have the length of the first
     *            component (a length of 0 is also accepted) and an alphabet container that is
     *            consistent with the mixture's alphabets.
     * @throws CloneNotSupportedException propagated from the super-class constructor
     * @throws IllegalArgumentException if no component is given, or a component has an
     *             unsuitable length or alphabet container
     */
    public MixtureDiffSM(int starts, boolean plugIn, DifferentiableStatisticalModel ... component) throws CloneNotSupportedException {
        super(MixtureDiffSM.lengthOfFirst(component), starts, component.length, true, plugIn, component);
        for (int i = 0; i < component.length; ++i) {
            int l = component[i].getLength();
            if (l != 0 && this.length != l) {
                throw new IllegalArgumentException("The length of component " + i + " is " + l + " but should be " + this.length + ".");
            }
            if (!this.alphabets.checkConsistency(component[i].getAlphabetContainer())) {
                throw new IllegalArgumentException("The AlphabetContainer of component " + i + " is not suitable.");
            }
        }
        this.computeLogGammaSum();
        this.init();
    }

    /**
     * Returns the length of the first component.
     * <p>
     * Rejects a {@code null} or empty component array, so the constructor fails with a
     * descriptive message instead of an index error inside the {@code super(...)} call.
     */
    private static int lengthOfFirst(DifferentiableStatisticalModel[] component) {
        if (component == null || component.length == 0) {
            throw new IllegalArgumentException("At least one component is required.");
        }
        return component[0].getLength();
    }

    /**
     * (Re)builds {@link #motifsRef} from the current components.
     */
    private void init() {
        this.motifsRef = new int[this.function.length + 1];
        for (int i = 0; i < this.function.length; ++i) {
            int motifs = this.function[i] instanceof MotifDiscoverer
                    ? ((MotifDiscoverer) this.function[i]).getNumberOfMotifs()
                    : 0;
            this.motifsRef[i + 1] = this.motifsRef[i] + motifs;
        }
    }

    /**
     * Restores a mixture from its XML representation.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the super class cannot parse {@code xml}
     */
    public MixtureDiffSM(StringBuffer xml) throws NonParsableException {
        super(xml);
        this.init();
    }

    @Override
    public MixtureDiffSM clone() throws CloneNotSupportedException {
        MixtureDiffSM clone = (MixtureDiffSM) super.clone();
        // the clone gets its own motif index table
        clone.init();
        return clone;
    }

    @Override
    protected double getLogNormalizationConstantForComponent(int i) {
        return this.function[i].getLogNormalizationConstant();
    }

    /**
     * Returns the log partial normalization constant for the given parameter.
     * <p>
     * Returns negative infinity if the model is normalized.
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        if (this.isNormalized()) {
            return Double.NEGATIVE_INFINITY;
        }
        if (Double.isNaN(this.norm)) {
            this.precomputeNorm();
        }
        int[] ind = this.getIndices(parameterIndex);
        if (ind[0] == this.function.length) {
            // the index refers to the mixture's own (hidden) parameters, not a component's
            return this.partNorm[ind[1]];
        }
        return this.logHiddenPotential[ind[0]] + this.function[ind[0]].getLogPartialNormalizationConstant(ind[1]);
    }

    /**
     * Returns the ESS of component {@code index}, which is used as the hyperparameter of the
     * corresponding hidden parameter.
     */
    @Override
    public double getHyperparameterForHiddenParameter(int index) {
        return this.function[index].getESS();
    }

    /**
     * Returns the sum of the ESSs of all components.
     */
    @Override
    public double getESS() {
        double ess = 0.0;
        for (int i = 0; i < this.function.length; ++i) {
            ess += this.function[i].getESS();
        }
        return ess;
    }

    /**
     * Draws random per-sequence component weights and sets the hidden parameters from them.
     * <p>
     * For each of the {@code len} sequences, a vector of proportions is drawn from a Dirichlet
     * distribution. Its hyperparameters are the components' ESSs, or all ones if the total ESS
     * is 0. The proportions are multiplied by the sequence's original weight (1.0 if
     * {@code originalWeights} is {@code null}). The per-component sums become
     * {@code hiddenParameter}, which is passed to {@code computeHiddenParameter(..., true)}.
     *
     * @param originalWeights the sequence weights, or {@code null} for weight 1.0 each
     * @param len the number of sequences
     * @return {@code newWeights[component][sequence]}
     */
    private double[][] getRandomWeights(double[] originalWeights, int len) {
        Arrays.fill(this.hiddenParameter, 0.0);
        double[][] newWeights = new double[this.function.length][len];
        double[] alpha = new double[this.getNumberOfComponents()];
        if (this.getESS() == 0.0) {
            Arrays.fill(alpha, 1.0);
        } else {
            for (int j = 0; j < alpha.length; ++j) {
                alpha[j] = this.getHyperparameterForHiddenParameter(j);
            }
        }
        DirichletMRGParams param = new DirichletMRGParams(alpha);
        double[] p = new double[alpha.length];
        for (int i = 0; i < len; ++i) {
            DirichletMRG.DEFAULT_INSTANCE.generate(p, 0, p.length, param);
            double w = originalWeights == null ? 1.0 : originalWeights[i];
            for (int j = 0; j < p.length; ++j) {
                newWeights[j][i] = w * p[j];
                this.hiddenParameter[j] += newWeights[j][i];
            }
        }
        this.computeHiddenParameter(this.hiddenParameter, true);
        return newWeights;
    }

    /**
     * Initializes the components by plug-in, in three rounds.
     * <ul>
     * <li>Round 0 uses random weights from {@link #getRandomWeights}.</li>
     * <li>Each later round re-weights every sequence by the normalised component scores of
     *     the current model, times the sequence's weight (1.0 if none is given).</li>
     * </ul>
     * {@code weights[index]} is temporarily replaced while the components are initialized.
     * It is always restored, even if a component throws.
     */
    @Override
    protected void initializeUsingPlugIn(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        if (weights == null) {
            weights = new double[data.length][];
        }
        final int rounds = 3;
        double[] help = weights[index];
        int numSeqs = data[index].getNumberOfElements();
        try {
            double[][] newWeights = null;
            for (int r = 0; r < rounds; ++r) {
                if (r == 0) {
                    newWeights = this.getRandomWeights(help, numSeqs);
                } else {
                    for (int n = 0; n < numSeqs; ++n) {
                        this.fillComponentScores(data[index].getElementAt(n), 0);
                        Normalisation.logSumNormalisation(this.componentScore);
                        // a null weight array means every sequence has weight 1.0 (as in getRandomWeights)
                        double w = help == null ? 1.0 : help[n];
                        for (int k = 0; k < this.function.length; ++k) {
                            newWeights[k][n] = this.componentScore[k] * w;
                        }
                    }
                }
                for (int i = 0; i < this.function.length; ++i) {
                    weights[index] = newWeights[i];
                    this.function[i].initializeFunction(index, freeParams, data, weights);
                }
            }
        } finally {
            weights[index] = help;
        }
    }

    /**
     * Draws new random component weights and passes them to every component that is a
     * {@link MutableMotifDiscoverer}.
     * <p>
     * {@code weights[index]} is temporarily replaced and always restored.
     */
    @Override
    public void adjustHiddenParameters(int index, DataSet[] data, double[][] weights) throws Exception {
        if (weights == null) {
            weights = new double[data.length][];
        }
        double[] help = weights[index];
        double[][] newWeights = this.getRandomWeights(help, data[index].getNumberOfElements());
        try {
            for (int i = 0; i < this.function.length; ++i) {
                if (!(this.function[i] instanceof MutableMotifDiscoverer)) {
                    continue;
                }
                weights[index] = newWeights[i];
                ((MutableMotifDiscoverer) this.function[i]).adjustHiddenParameters(index, data, weights);
            }
        } finally {
            weights[index] = help;
        }
    }

    /**
     * Returns {@code "mixture(" + name_0 + ", " + name_1 + ... + ")"}.
     */
    @Override
    public String getInstanceName() {
        StringBuilder erg = new StringBuilder("mixture(").append(this.function[0].getInstanceName());
        for (int i = 1; i < this.function.length; ++i) {
            erg.append(", ").append(this.function[i].getInstanceName());
        }
        return erg.append(")").toString();
    }

    /**
     * Returns the last position scored by a {@link VariableLengthDiffSM} component.
     * <p>
     * This is {@code start + length - 1} for a fixed length, or the end of {@code seq} if the
     * mixture length is 0.
     */
    private int getEndIndex(Sequence seq, int start) {
        return this.length != 0 ? start + this.length - 1 : seq.getLength() - 1;
    }

    /**
     * Sets {@code componentScore[i]} to {@code logHiddenPotential[i]} plus the log score of
     * component {@code i} for {@code seq}, starting at {@code start}.
     */
    @Override
    protected void fillComponentScores(Sequence seq, int start) {
        for (int i = 0; i < this.function.length; ++i) {
            double score;
            if (this.function[i] instanceof VariableLengthDiffSM) {
                score = ((VariableLengthDiffSM) this.function[i]).getLogScoreFor(seq, start, this.getEndIndex(seq, start));
            } else {
                score = this.function[i].getLogScoreFor(seq, start);
            }
            this.componentScore[i] = this.logHiddenPotential[i] + score;
        }
    }

    /**
     * Computes the log score and appends the partial derivatives to the given lists.
     * <p>
     * Component derivatives are weighted by the normalised component scores. The
     * hidden-parameter derivatives are appended after all component parameters.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        int last = this.paramRef.length - 1;
        int numHidden = this.paramRef[last] - this.paramRef[last - 1];
        for (int i = 0; i < this.function.length; ++i) {
            this.iList[i].clear();
            this.dList[i].clear();
            double score;
            if (this.function[i] instanceof VariableLengthDiffSM) {
                score = ((VariableLengthDiffSM) this.function[i]).getLogScoreAndPartialDerivation(seq, start, this.getEndIndex(seq, start), this.iList[i], this.dList[i]);
            } else {
                score = this.function[i].getLogScoreAndPartialDerivation(seq, start, this.iList[i], this.dList[i]);
            }
            this.componentScore[i] = this.logHiddenPotential[i] + score;
        }
        // normalises componentScore in place (used below as component weights) and returns the log-sum
        double logScore = Normalisation.logSumNormalisation(this.componentScore, 0, this.function.length, this.componentScore, 0);
        for (int i = 0; i < this.function.length; ++i) {
            for (int j = 0; j < this.iList[i].length(); ++j) {
                indices.add(this.paramRef[i] + this.iList[i].get(j));
                partialDer.add(this.componentScore[i] * this.dList[i].get(j));
            }
        }
        int hiddenOffset = this.paramRef[this.function.length];
        boolean normalized = this.isNormalized();
        for (int j = 0; j < numHidden; ++j) {
            indices.add(hiddenOffset + j);
            partialDer.add(this.componentScore[j] - (normalized ? this.hiddenPotential[j] : 0.0));
        }
        return logScore;
    }

    /**
     * Lists, for each component, its weight {@code p(i)} followed by the component's own
     * {@code toString()}.
     */
    @Override
    public String toString() {
        if (Double.isNaN(this.norm)) {
            this.precomputeNorm();
        }
        boolean normalized = this.isNormalized();
        StringBuilder erg = new StringBuilder(this.function.length * 1000);
        for (int i = 0; i < this.function.length; ++i) {
            double weight = normalized ? this.hiddenPotential[i] : Math.exp(this.partNorm[i] - this.norm);
            erg.append("p(").append(i).append(") = ").append(weight).append("\n")
               .append(this.function[i].toString()).append("\n");
        }
        return erg.toString();
    }

    /**
     * Returns the index of the component that owns the given global motif index.
     *
     * @throws IndexOutOfBoundsException if {@code motif} is negative or not less than
     *             {@link #getNumberOfMotifs()}
     */
    private int getComponentFor(int motif) {
        int total = this.motifsRef[this.function.length];
        if (motif < 0 || motif >= total) {
            throw new IndexOutOfBoundsException("Motif index " + motif + " is out of range [0, " + total + ").");
        }
        int res = 0;
        // motifsRef is non-decreasing and motif < total, so this stops at res < function.length
        while (motif >= this.motifsRef[res + 1]) {
            ++res;
        }
        return res;
    }

    /**
     * Delegates to the owning component if it is a {@link MutableMotifDiscoverer}.
     * Otherwise a warning is printed to standard output.
     */
    @Override
    public void initializeMotif(int motifIndex, DataSet data, double[] weights) throws Exception {
        int c = this.getComponentFor(motifIndex);
        if (this.function[c] instanceof MutableMotifDiscoverer) {
            ((MutableMotifDiscoverer) this.function[c]).initializeMotif(motifIndex - this.motifsRef[c], data, weights);
        } else {
            System.out.println("WARNING: Not possible!");
        }
    }

    /**
     * Delegates to the owning component if it is a {@link MutableMotifDiscoverer}.
     * Otherwise a warning is printed to standard output.
     */
    @Override
    public void initializeMotifRandomly(int motif) throws Exception {
        int c = this.getComponentFor(motif);
        if (this.function[c] instanceof MutableMotifDiscoverer) {
            ((MutableMotifDiscoverer) this.function[c]).initializeMotifRandomly(motif - this.motifsRef[c]);
        } else {
            System.out.println("WARNING: Not possible!");
        }
    }

    /**
     * Delegates to the owning component if it is a {@link MutableMotifDiscoverer}.
     * <p>
     * On success, {@code init(freeParams)} of the super class is called again.
     *
     * @return whether the component modified the motif; {@code false} if the component is
     *         not mutable
     */
    @Override
    public boolean modifyMotif(int motifIndex, int offsetLeft, int offsetRight) throws Exception {
        int c = this.getComponentFor(motifIndex);
        if (!(this.function[c] instanceof MutableMotifDiscoverer)) {
            return false;
        }
        boolean b = ((MutableMotifDiscoverer) this.function[c]).modifyMotif(motifIndex - this.motifsRef[c], offsetLeft, offsetRight);
        if (b) {
            // NOTE(review): presumably refreshes parameter bookkeeping after the component changed
            this.init(this.freeParams);
        }
        return b;
    }

    @Override
    public int getGlobalIndexOfMotifInComponent(int component, int motif) {
        int res = this.motifsRef[component] + motif;
        if (res >= this.motifsRef[component + 1]) {
            throw new IndexOutOfBoundsException("Component " + component + " has only " + (this.motifsRef[component + 1] - this.motifsRef[component]) + " motifs.");
        }
        return res;
    }

    @Override
    public int getIndexOfMaximalComponentFor(Sequence sequence) throws Exception {
        return this.getIndexOfMaximalComponentFor(sequence, 0);
    }

    @Override
    public int getMotifLength(int motif) {
        int c = this.getComponentFor(motif);
        // only MotifDiscoverer components own motifs (see init()), so this cast is safe
        return ((MotifDiscoverer) this.function[c]).getMotifLength(motif - this.motifsRef[c]);
    }

    @Override
    public int getNumberOfMotifs() {
        return this.motifsRef[this.function.length];
    }

    @Override
    public int getNumberOfMotifsInComponent(int component) {
        return this.motifsRef[component + 1] - this.motifsRef[component];
    }

    /**
     * Returns the log-sum of the profiles of {@code motif} over all sub-components of
     * component {@code component} that use this motif.
     * <p>
     * Only {@link MotifDiscoverer.KindOfProfile#UNNORMALIZED_JOINT} is supported.
     *
     * @throws OperationNotSupportedException for any other kind of profile
     * @throws IllegalArgumentException if the component is not a {@link MotifDiscoverer}, or
     *             none of its sub-components uses {@code motif}
     */
    @Override
    public double[] getProfileOfScoresFor(int component, int motif, Sequence sequence, int startpos, MotifDiscoverer.KindOfProfile kind) throws Exception {
        if (kind != MotifDiscoverer.KindOfProfile.UNNORMALIZED_JOINT) {
            throw new OperationNotSupportedException("Currently it is only allowed to use KindOfProfile.UNNORMALIZED_JOINT");
        }
        if (!(this.function[component] instanceof MotifDiscoverer)) {
            throw new IllegalArgumentException("Component " + component + " is not a MotifDiscoverer.");
        }
        // NOTE(review): unlike getMotifLength/initializeMotif, `motif` is forwarded without
        // subtracting motifsRef[component]; confirm that callers pass an index local to that component.
        MotifDiscoverer md = (MotifDiscoverer) this.function[component];
        double[] prof = null;
        int c = md.getNumberOfComponents();
        for (int i = 0; i < c; ++i) {
            int n = md.getNumberOfMotifsInComponent(i);
            int m = 0;
            while (m < n && md.getGlobalIndexOfMotifInComponent(i, m) != motif) {
                ++m;
            }
            if (m >= n) {
                continue; // sub-component i does not use this motif
            }
            double[] current = md.getProfileOfScoresFor(i, motif, sequence, startpos, kind);
            if (prof == null) {
                prof = current;
            } else {
                for (int p = 0; p < prof.length; ++p) {
                    prof[p] = Normalisation.getLogSum(prof[p], current[p]);
                }
            }
        }
        if (prof == null) {
            throw new IllegalArgumentException("Motif " + motif + " is not used in any sub-component of component " + component + ".");
        }
        return prof;
    }

    /**
     * Always returns {@code {0.5, 0.5}}, independent of the arguments.
     */
    @Override
    public double[] getStrandProbabilitiesFor(int component, int motif, Sequence sequence, int startpos) throws Exception {
        return new double[]{0.5, 0.5};
    }
}

