/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.ConstraintManager;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.MEMConstraint;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.SequenceIterator;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import java.util.ArrayList;
import java.util.Arrays;

public final class MarkovRandomFieldDiffSM
extends AbstractDifferentiableStatisticalModel {
    private MEMConstraint[] constr;
    private String name;
    private boolean freeParams;
    private int[] offset;
    private int[] help;
    private double ess;
    private double norm;
    private double[][] partNorm;
    private SequenceIterator seqIt;
    private static final String XML_TAG = "MarkovRandomFieldDiffSM";

    public MarkovRandomFieldDiffSM(AlphabetContainer alphabets, int length, String constr) {
        this(alphabets, length, 0.0, constr);
    }

    public MarkovRandomFieldDiffSM(AlphabetContainer alphabets, int length, double ess, String constr) {
        super(alphabets, length);
        if (!alphabets.isDiscrete()) {
            throw new IllegalArgumentException("The AlphabetContainer has to be discrete.");
        }
        if (ess < 0.0) {
            throw new IllegalArgumentException("The ess has to be non-negative.");
        }
        this.ess = ess;
        int[] aLength = new int[length];
        int i = 0;
        while (i < length) {
            aLength[i] = (int)alphabets.getAlphabetLengthAt(i++);
        }
        ArrayList<int[]> list = ConstraintManager.extract(length, constr);
        ConstraintManager.reduce(list);
        this.constr = ConstraintManager.createConstraints(list, aLength);
        this.name = constr;
        this.freeParams = false;
        this.getNumberOfParameters();
        this.init(Double.NaN);
    }

    public MarkovRandomFieldDiffSM(StringBuffer source) throws NonParsableException {
        super(source);
    }

    private void init(double n) {
        this.norm = n;
        if (this.partNorm == null) {
            this.partNorm = new double[this.constr.length][];
            int i = 0;
            while (i < this.partNorm.length) {
                this.partNorm[i] = new double[this.constr[i].getNumberOfSpecificConstraints()];
                ++i;
            }
            this.help = new int[2];
            int[] aLength = new int[this.length];
            int i2 = 0;
            while (i2 < this.length) {
                aLength[i2] = (int)this.alphabets.getAlphabetLengthAt(i2);
                ++i2;
            }
            this.seqIt = new SequenceIterator(this.length);
            this.seqIt.setBounds(aLength);
        } else {
            int i = 0;
            while (i < this.partNorm.length) {
                Arrays.fill(this.partNorm[i], n);
                ++i;
            }
        }
    }

    @Override
    protected void fromXML(StringBuffer representation) throws NonParsableException {
        StringBuffer xml = XMLParser.extractForTag(representation, XML_TAG);
        this.length = XMLParser.extractObjectForTags(xml, "length", Integer.TYPE);
        this.alphabets = (AlphabetContainer)XMLParser.extractObjectForTags(xml, "alphabets");
        this.ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        this.name = XMLParser.extractObjectForTags(xml, "name", String.class);
        this.constr = XMLParser.extractObjectForTags(xml, "constr", MEMConstraint[].class);
        this.freeParams = XMLParser.extractObjectForTags(xml, "freeParams", Boolean.TYPE);
        this.getNumberOfParameters();
        this.init(Double.NaN);
    }

    @Override
    public MarkovRandomFieldDiffSM clone() throws CloneNotSupportedException {
        MarkovRandomFieldDiffSM clone = (MarkovRandomFieldDiffSM)super.clone();
        clone.constr = (MEMConstraint[])ArrayHandler.clone((Cloneable[])this.constr);
        clone.partNorm = new double[this.partNorm.length][];
        int i = 0;
        while (i < this.partNorm.length) {
            clone.partNorm[i] = (double[])this.partNorm[i].clone();
            ++i;
        }
        clone.norm = this.norm;
        clone.help = (int[])this.help.clone();
        clone.offset = null;
        clone.getNumberOfParameters();
        clone.seqIt = this.seqIt.clone();
        return clone;
    }

    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        double erg = 0.0;
        int i = 0;
        while (i < this.constr.length) {
            erg += this.constr[i].getLambda(this.constr[i].satisfiesSpecificConstraint(seq, start));
            ++i;
        }
        return erg;
    }

    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        double erg = 0.0;
        int i = 0;
        while (i < this.constr.length) {
            int j = this.constr[i].satisfiesSpecificConstraint(seq, start);
            int z = this.offset[i] + j;
            if (z < this.offset[i + 1]) {
                indices.add(z);
                partialDer.add(1.0);
            }
            erg += this.constr[i].getLambda(j);
            ++i;
        }
        return erg;
    }

    @Override
    public int getNumberOfParameters() {
        if (this.offset == null) {
            int i = 0;
            int anz = 0;
            this.offset = new int[this.constr.length + 1];
            while (i < this.constr.length) {
                anz += this.constr[i++].getNumberOfSpecificConstraints();
                if (this.freeParams) {
                    // empty if block
                }
                this.offset[i] = --anz;
            }
        }
        return this.offset[this.constr.length];
    }

    @Override
    public String getInstanceName() {
        return "MRF(" + this.name + ")";
    }

    @Override
    public void setParameters(double[] params, int start) {
        this.norm = Double.NaN;
        int i = 0;
        int s = this.offset[0];
        while (i < this.constr.length) {
            int j = 0;
            while (s < this.offset[i + 1]) {
                this.constr[i].setLambda(j, params[start + s]);
                ++s;
                ++j;
            }
            ++i;
        }
    }

    public String toString() {
        int i = 0;
        StringBuffer res = new StringBuffer();
        res.append(String.valueOf(this.getInstanceName()) + "\n");
        while (i < this.constr.length) {
            res.append(this.constr[i]);
            int j = 0;
            while (j < this.constr[i].getNumberOfSpecificConstraints()) {
                res.append("\t" + this.constr[i].getLambda(j));
                ++j;
            }
            res.append("\n");
            ++i;
        }
        return res.toString();
    }

    @Override
    public StringBuffer toXML() {
        StringBuffer b = new StringBuffer(10000);
        XMLParser.appendObjectWithTags(b, this.length, "length");
        XMLParser.appendObjectWithTags(b, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(b, this.ess, "ess");
        XMLParser.appendObjectWithTags(b, this.name, "name");
        XMLParser.appendObjectWithTags(b, this.constr, "constr");
        XMLParser.appendObjectWithTags(b, this.freeParams, "freeParams");
        XMLParser.addTags(b, XML_TAG);
        return b;
    }

    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        if (this.freeParams != freeParams) {
            this.offset = null;
            this.freeParams = freeParams;
            this.getNumberOfParameters();
        }
        double d = 0.0;
        int i = 0;
        while (i < this.length) {
            d -= Math.log(this.alphabets.getAlphabetLengthAt(i));
            ++i;
        }
        d /= (double)this.constr.length;
        int i2 = 0;
        while (i2 < this.constr.length) {
            int k = this.constr[i2].getNumberOfSpecificConstraints();
            int j = 0;
            while (j < k) {
                this.constr[i2].setLambda(j, d);
                ++j;
            }
            ++i2;
        }
        this.norm = Double.NaN;
    }

    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        if (this.freeParams != freeParams) {
            this.offset = null;
            this.freeParams = freeParams;
            this.getNumberOfParameters();
        }
        int i = 0;
        while (i < this.constr.length) {
            int k = this.constr[i].getNumberOfSpecificConstraints();
            int j = 0;
            while (j < k) {
                this.constr[i].setLambda(j, r.nextGaussian() / (double)k);
                ++j;
            }
            ++i;
        }
        this.norm = Double.NaN;
    }

    @Override
    public double getLogNormalizationConstant() {
        if (Double.isNaN(this.norm)) {
            this.precompute();
        }
        return this.norm;
    }

    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        if (Double.isNaN(this.norm)) {
            this.precompute();
        }
        this.computeIndices(parameterIndex);
        return this.partNorm[this.help[0]][this.help[1]];
    }

    private void precompute() {
        this.seqIt.reset();
        int[] fulfilled = new int[this.constr.length];
        this.seqIt.reset();
        this.init(Double.NEGATIVE_INFINITY);
        do {
            double s = this.getLogScore(fulfilled, this.seqIt);
            int i = 0;
            while (i < this.constr.length) {
                this.partNorm[i][fulfilled[i]] = Normalisation.getLogSum(this.partNorm[i][fulfilled[i]], s);
                ++i;
            }
            this.norm = Normalisation.getLogSum(this.norm, s);
        } while (this.seqIt.next());
    }

    private double getLogScore(int[] fulfilled, SequenceIterator sequence) {
        double s = 0.0;
        int counter = 0;
        while (counter < this.constr.length) {
            fulfilled[counter] = this.constr[counter].satisfiesSpecificConstraint(sequence);
            s += this.constr[counter].getLambda(fulfilled[counter]);
            ++counter;
        }
        return s;
    }

    @Override
    public double getESS() {
        return this.ess;
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        this.computeIndices(index);
        return this.constr[this.help[0]].getNumberOfSpecificConstraints();
    }

    private void computeIndices(int index) {
        this.help[0] = 0;
        while (index >= this.offset[this.help[0]]) {
            this.help[0] = this.help[0] + 1;
        }
        this.help[0] = this.help[0] - 1;
        this.help[1] = index - this.offset[this.help[0]];
    }

    @Override
    public double getLogPriorTerm() {
        double logPriorTerm = 0.0;
        int i = 0;
        int s = 0;
        while (i < this.constr.length) {
            s = this.constr[i].getNumberOfSpecificConstraints();
            double d = this.ess / (double)s;
            int j = 0;
            while (j < s) {
                logPriorTerm += this.constr[i].getLambda(j) * d;
                ++j;
            }
            ++i;
        }
        return logPriorTerm;
    }

    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) {
        int i = 0;
        int s = this.offset[0];
        while (i < this.constr.length) {
            double d = this.ess / (double)this.constr[i].getNumberOfSpecificConstraints();
            ++i;
            while (s < this.offset[i]) {
                int n = start++;
                grad[n] = grad[n] + d;
                ++s;
            }
        }
    }

    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] start = new double[this.offset[this.constr.length]];
        int i = 0;
        int n = 0;
        while (i < this.constr.length) {
            int j = 0;
            while (n < this.offset[i + 1]) {
                start[n] = this.constr[i].getLambda(j);
                ++n;
                ++j;
            }
            ++i;
        }
        return start;
    }

    @Override
    public boolean isInitialized() {
        return true;
    }
}

