/*
 * Decompiled with CFR 0.152.
 */
package projects.tals;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;

/**
 * Differentiable statistical model for a single position ({@code super(alphabetsRVD, 1)}) that
 * scores the RVD found at the corresponding position of a reference sequence.
 *
 * <p>For every RVD symbol the model keeps two logit parameters. The row (0 or 1) is selected per
 * position by {@link #ismatch(Sequence, Sequence, int)}. The column is the RVD at
 * {@code start - 1} of the reference sequence. The reference sequence is taken from the
 * sequence's {@link ReferenceSequenceAnnotation} of type {@code "reference"}. Probabilities are
 * the logistic transform of the logits. Each probability {@code p} contributes
 * {@code a*log(p) + (S-a)*log(1-p) - log B(a, S-a)} to the log prior, where
 * {@code a = hyperParams[j]} and {@code S = hyperSum[j]}.
 */
public class TAL_M_NSF2
extends AbstractDifferentiableStatisticalModel {
    /** Equivalent sample size passed to the constructor; returned by {@link #getESS()}. */
    private double ess;
    /**
     * Set to {@code true} by {@link #initializeFunctionRandomly(boolean)} and restored by
     * {@link #fromXML(StringBuffer)}. This field deliberately has NO explicit initializer. The XML
     * constructor only calls {@code super(xml)}, so {@code fromXML} is presumably invoked from
     * there. Field initializers run after the super constructor returns, so an explicit
     * {@code = false} would overwrite the deserialized value.
     */
    private boolean initialized;
    /** Logit parameters, indexed {@code [ismatchRow][rvd]}. */
    private double[][] params;
    /** Cached logistic transform of {@link #params}, same shape. */
    private double[][] probs;
    /**
     * First prior shape value per RVD. Only the first {@code params[0].length} entries are read,
     * although the array is allocated one longer.
     */
    private double[] hyperParams;
    /** Per-RVD prior "sum" {@code S}; the second shape value is {@code S - hyperParams[j]}. */
    private double[] hyperSum;
    /** Lookup {@code [referenceSymbol][sequenceSymbol]} used by the neighbour test. */
    private boolean[][] ismatch;
    /** Number of positions before {@code pos} inspected by the neighbour test. */
    private int numBack;
    /** Number of positions after {@code pos} inspected by the neighbour test. */
    private int numFwd;

    /**
     * Creates a new, uninitialized model.
     *
     * @param alphabetsRVD alphabet container; the RVD alphabet at position 0 defines the number of
     *     parameter columns
     * @param length used only to scale the prior hyperparameters (the model itself has length 1)
     * @param ess equivalent sample size
     * @param ismatch neighbour lookup table; a deep copy is stored
     * @param numBack number of preceding positions inspected by the neighbour test
     * @param numFwd number of following positions inspected by the neighbour test
     * @throws Exception if the alphabet length at position 0 is not positive
     */
    public TAL_M_NSF2(AlphabetContainer alphabetsRVD, int length, double ess, boolean[][] ismatch, int numBack, int numFwd) throws Exception {
        super(alphabetsRVD, 1);
        this.ismatch = deepCopy(ismatch);
        this.numBack = numBack;
        this.numFwd = numFwd;
        double alphabetLength = alphabetsRVD.getAlphabetLengthAt(0);
        // Negated comparison so that NaN is rejected as well.
        if (!(alphabetLength > 0.0)) {
            throw new Exception("Alphabet ist nicht OK!");
        }
        this.ess = ess;
        int numRvds = (int) alphabetLength;
        this.hyperParams = new double[numRvds + 1];
        this.hyperSum = new double[numRvds + 1];
        // a = S / 2, i.e. a symmetric prior on every probability.
        Arrays.fill(this.hyperParams, ess * length / (alphabetLength * 4.0));
        Arrays.fill(this.hyperSum, ess * length / (alphabetLength * 2.0));
        this.params = new double[2][numRvds];
        this.probs = new double[2][numRvds];
    }

    /**
     * Restores a model from its XML representation (see {@link #toXML()}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public TAL_M_NSF2(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns an independent copy of this model. All mutable arrays are copied row by row, so
     * changing the parameters of the clone does not affect this instance and vice versa.
     */
    @Override
    public TAL_M_NSF2 clone() throws CloneNotSupportedException {
        TAL_M_NSF2 clone = (TAL_M_NSF2) super.clone();
        // double[][].clone() is shallow: each row has to be copied explicitly.
        clone.params = deepCopy(this.params);
        clone.probs = deepCopy(this.probs);
        clone.hyperSum = this.hyperSum.clone();
        clone.hyperParams = this.hyperParams.clone();
        clone.ismatch = deepCopy(this.ismatch);
        return clone;
    }

    /**
     * Decides which parameter row is used at {@code pos}.
     *
     * <p>Returns 1 if {@code numBack == 0}. Otherwise it returns 1 if any neighbour {@code i} in
     * {@code [max(1, pos - numBack), pos)} or {@code (pos, min(seq.getLength(), pos + numFwd + 1))}
     * satisfies {@code ismatch[refSeq.discreteVal(i - 1)][seq.discreteVal(i)]}, and 0 otherwise.
     * Reference position {@code i - 1} is paired with sequence position {@code i}.
     *
     * @param seq the scored sequence
     * @param refSeq the reference (RVD) sequence
     * @param pos the position in {@code seq}
     * @return 0 or 1, used as the row index into the parameters
     */
    public int ismatch(Sequence seq, Sequence refSeq, int pos) {
        if (this.numBack == 0) {
            // NOTE(review): this returns 1 regardless of numFwd; confirm the short-circuit is intended.
            return 1;
        }
        if (this.anyMatch(seq, refSeq, Math.max(1, pos - this.numBack), pos)) {
            return 1;
        }
        if (this.anyMatch(seq, refSeq, pos + 1, Math.min(seq.getLength(), pos + this.numFwd + 1))) {
            return 1;
        }
        return 0;
    }

    /** Returns whether any position in {@code [from, to)} satisfies the neighbour lookup. */
    private boolean anyMatch(Sequence seq, Sequence refSeq, int from, int to) {
        for (int i = from; i < to; i++) {
            if (this.ismatch[refSeq.discreteVal(i - 1)][seq.discreteVal(i)]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to each logit, which is
     * {@code a - S * p}, to {@code grad}. Entries are written starting at {@code start}, in the
     * same row-major order as {@link #getCurrentParameterValues()}.
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        for (int i = 0; i < this.params.length; i++) {
            for (int j = 0; j < this.params[i].length; j++) {
                grad[start++] += this.hyperParams[j] - this.hyperSum[j] * this.probs[i][j];
            }
        }
    }

    /** Returns the equivalent sample size given at construction. */
    @Override
    public double getESS() {
        return this.ess;
    }

    /** Returns 0: the model is treated as already normalized. */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    /**
     * Returns negative infinity for every parameter. This is the log of a zero derivative, which
     * is consistent with the constant log normalization of 0.
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Returns the sum over all probabilities {@code p = probs[i][j]} of
     * {@code a*log(p) + (S-a)*log1p(-p) - log B(a, S-a)}, with {@code a = hyperParams[j]} and
     * {@code S = hyperSum[j]}.
     */
    @Override
    public double getLogPriorTerm() {
        int numRvds = this.params[0].length;
        // log B(a, S - a) depends only on the column, so compute it once per column.
        double[] logBeta = new double[numRvds];
        for (int j = 0; j < numRvds; j++) {
            logBeta[j] = Gamma.logOfGamma(this.hyperParams[j])
                    + Gamma.logOfGamma(this.hyperSum[j] - this.hyperParams[j])
                    - Gamma.logOfGamma(this.hyperSum[j]);
        }
        double logPrior = 0.0;
        for (int i = 0; i < this.params.length; i++) {
            for (int j = 0; j < this.params[i].length; j++) {
                double a = this.hyperParams[j];
                double b = this.hyperSum[j] - a;
                logPrior += a * Math.log(this.probs[i][j]) + b * Math.log1p(-this.probs[i][j]);
                logPrior -= logBeta[j];
            }
        }
        return logPrior;
    }

    /** Returns 0 for every parameter index. */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return 0;
    }

    /** Returns a copy of all logits, row-major (row 0 first, then row 1). */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] pars = new double[this.getNumberOfParameters()];
        int offset = 0;
        for (double[] row : this.params) {
            System.arraycopy(row, 0, pars, offset, row.length);
            offset += row.length;
        }
        return pars;
    }

    @Override
    public String getInstanceName() {
        return "TAL_M_NSF2";
    }

    /**
     * Returns {@code log(probs[row][rvd])}. Here {@code row} comes from
     * {@link #ismatch(Sequence, Sequence, int)} and {@code rvd} is the reference symbol at
     * {@code start - 1}.
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        // NOTE(review): start == 0 would read reference position -1; presumably callers never do that.
        Sequence rvdSeq = referenceOf(seq);
        int im = this.ismatch(seq, rvdSeq, start);
        return Math.log(this.probs[im][rvdSeq.discreteVal(start - 1)]);
    }

    /**
     * Same score as {@link #getLogScoreFor(Sequence, int)}. It also records the single non-zero
     * partial derivative {@code d log(p) / d logit = 1 - p} and its parameter index.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        Sequence rvdSeq = referenceOf(seq);
        int rvd = rvdSeq.discreteVal(start - 1);
        int im = this.ismatch(seq, rvdSeq, start);
        double p = this.probs[im][rvd];
        indices.add(im * this.params[0].length + rvd);
        partialDer.add(1.0 - p);
        return Math.log(p);
    }

    /** Returns the RVD sequence stored in the {@code "reference"} annotation of {@code seq}. */
    private static Sequence referenceOf(Sequence seq) {
        ReferenceSequenceAnnotation anno = (ReferenceSequenceAnnotation) seq.getSequenceAnnotationByType("reference", 0);
        return anno.getReferenceSequence();
    }

    /** Returns the total number of logits (rows times RVD symbols). */
    @Override
    public int getNumberOfParameters() {
        return this.params.length * this.params[0].length;
    }

    /** Ignores the data and delegates to {@link #initializeFunctionRandomly(boolean)}. */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.initializeFunctionRandomly(freeParams);
    }

    /**
     * Sets every logit to {@code log(u)} with {@code u} uniform in (0, 1]. The corresponding
     * probability is {@code u / (u + 1)}, which lies in (0, 0.5]. {@code freeParams} is ignored.
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        for (int i = 0; i < this.params.length; i++) {
            for (int c = 0; c < this.params[i].length; c++) {
                // Math.random() can return exactly 0, which would give a -Infinity logit and a
                // zero probability; 1 - Math.random() lies in (0, 1].
                double temp = 1.0 - Math.random();
                this.params[i][c] = Math.log(temp);
                this.probs[i][c] = temp / (temp + 1.0);
            }
        }
        this.initialized = true;
    }

    @Override
    public boolean isInitialized() {
        return this.initialized;
    }

    /**
     * Copies the logits from {@code params}, starting at {@code start}, in row-major order. The
     * cached probabilities are updated at the same time.
     */
    @Override
    public void setParameters(double[] params, int start) {
        for (int i = 0; i < this.params.length; i++) {
            for (int j = 0; j < this.params[i].length; j++, start++) {
                this.params[i][j] = params[start];
                this.probs[i][j] = logistic(params[start]);
            }
        }
    }

    /**
     * Numerically stable logistic function. The naive {@code exp(x) / (1 + exp(x))} becomes
     * {@code Inf / Inf = NaN} once {@code exp(x)} overflows (x above roughly 709).
     */
    private static double logistic(double x) {
        if (x >= 0.0) {
            return 1.0 / (1.0 + Math.exp(-x));
        }
        double e = Math.exp(x);
        return e / (1.0 + e);
    }

    /** Row-wise copy of a 2D double array; {@code null} rows and a {@code null} array are kept. */
    private static double[][] deepCopy(double[][] src) {
        if (src == null) {
            return null;
        }
        double[][] copy = new double[src.length][];
        for (int i = 0; i < src.length; i++) {
            copy[i] = src[i] == null ? null : src[i].clone();
        }
        return copy;
    }

    /** Row-wise copy of a 2D boolean array; {@code null} rows and a {@code null} array are kept. */
    private static boolean[][] deepCopy(boolean[][] src) {
        if (src == null) {
            return null;
        }
        boolean[][] copy = new boolean[src.length][];
        for (int i = 0; i < src.length; i++) {
            copy[i] = src[i] == null ? null : src[i].clone();
        }
        return copy;
    }

    /** Serializes the model state inside a {@code TALMSF} tag. */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(xml, this.length, "length");
        XMLParser.appendObjectWithTags(xml, this.ess, "ess");
        XMLParser.appendObjectWithTags(xml, this.hyperParams, "hyperParams");
        XMLParser.appendObjectWithTags(xml, this.hyperSum, "hyperSum");
        XMLParser.appendObjectWithTags(xml, this.initialized, "isInitialized");
        XMLParser.appendObjectWithTags(xml, this.params, "params");
        XMLParser.appendObjectWithTags(xml, this.probs, "probs");
        XMLParser.appendObjectWithTags(xml, this.ismatch, "ismatch");
        XMLParser.appendObjectWithTags(xml, this.numBack, "numBack");
        XMLParser.appendObjectWithTags(xml, this.numFwd, "numFwd");
        XMLParser.addTags(xml, "TALMSF");
        return xml;
    }

    /** Restores the state written by {@link #toXML()}. */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, "TALMSF");
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(xml, "alphabets");
        this.length = XMLParser.extractObjectForTags(xml, "length", Integer.TYPE);
        this.ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        this.hyperParams = (double[]) XMLParser.extractObjectForTags(xml, "hyperParams");
        this.hyperSum = (double[]) XMLParser.extractObjectForTags(xml, "hyperSum");
        this.initialized = XMLParser.extractObjectForTags(xml, "isInitialized", Boolean.TYPE);
        this.params = (double[][]) XMLParser.extractObjectForTags(xml, "params");
        this.probs = (double[][]) XMLParser.extractObjectForTags(xml, "probs");
        this.ismatch = (boolean[][]) XMLParser.extractObjectForTags(xml, "ismatch");
        this.numBack = XMLParser.extractObjectForTags(xml, "numBack", Integer.TYPE);
        this.numFwd = XMLParser.extractObjectForTags(xml, "numFwd", Integer.TYPE);
    }

    /**
     * Renders the probability table: one line per {@code ismatch} row, with values formatted by
     * {@code nf}. {@code toString()} is final in the superclass and therefore not overridden here.
     */
    @Override
    public String toString(NumberFormat nf) {
        StringBuilder sb = new StringBuilder(this.getInstanceName()).append('\n');
        for (int i = 0; i < this.probs.length; i++) {
            sb.append("ismatch=").append(i).append(':');
            for (double p : this.probs[i]) {
                sb.append('\t').append(nf.format(p));
            }
            sb.append('\n');
        }
        return sb.toString();
    }
}

