/*
 * Decompiled with CFR 0.152.
 */
package projects.tals.linear;

import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.SequenceAnnotation;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.text.NumberFormat;
import java.util.Arrays;
import projects.tals.RVDSequence;
import projects.tals.linear.LF0Conditional;
import projects.tals.linear.LFPosition_mixture;
import projects.tals.linear.LFSpecificity_parallel_cond9C;

/**
 * Differentiable sequence score that combines an {@link LF0Conditional} term, a per-position
 * {@link LFSpecificity_parallel_cond9C} term and an optional {@link LFPosition_mixture} weight
 * into one score that is scaled and shifted per group.
 *
 * <p>For a sequence {@code seq} evaluated from {@code start} the score is
 * <pre>
 *   exp(a[g]) * ( lf0(seq, start) + sum over i &gt; start of specificity(seq, i) * position(seq, i) ) + b[g]
 * </pre>
 * where {@code g} is the integer parsed from the identifier of the sequence's {@code "intgroup"}
 * annotation (0 if the annotation is absent) and {@code position(seq, i)} is taken as 1 when no
 * position model is set. If the sequence carries a {@code "mask"} annotation, only positions whose
 * character in the mask identifier is {@code 'O'} contribute.
 *
 * <p>The parameter vector is laid out as {@code a}, {@code b}, lf0 parameters, specificity
 * parameters, position parameters (the last block only if a position model is set).
 */
public class LFModularConditional9C extends AbstractDifferentiableSequenceScore {
    /** Type of the sequence annotation whose identifier holds the group index. */
    private static final String GROUP_ANNOTATION = "intgroup";
    /** Type of the sequence annotation whose identifier holds the per-position mask. */
    private static final String MASK_ANNOTATION = "mask";
    /** Mask character marking a position that contributes to the score. */
    private static final char UNMASKED = 'O';

    /** Optional per-position weight model; {@code null} means a constant weight of 1. */
    protected LFPosition_mixture position;
    /** Model scored at every position after {@code start}. */
    protected LFSpecificity_parallel_cond9C specificity;
    /** Model scored at position {@code start}. */
    protected LF0Conditional lf0;
    /** Per-group log-scale parameters; the applied factor is {@code exp(a[group])}. */
    protected double[] a;
    /** Per-group additive offsets. */
    protected double[] b;

    /**
     * Creates a new score from its component models.
     *
     * @param lf0 model scored at the start position
     * @param specificity model scored at each position after the start
     * @param lfPosition per-position weight model, or {@code null} for a constant weight of 1
     * @param numberOfGroups number of groups, i.e. the length of the {@code a} and {@code b} arrays
     * @throws IllegalArgumentException declared by this constructor; not thrown by the code in this class itself
     */
    public LFModularConditional9C(LF0Conditional lf0, LFSpecificity_parallel_cond9C specificity, LFPosition_mixture lfPosition, int numberOfGroups) throws IllegalArgumentException {
        super(DNAAlphabetContainer.SINGLETON, 0);
        this.lf0 = lf0;
        this.specificity = specificity;
        this.position = lfPosition;
        this.a = new double[numberOfGroups];
        this.b = new double[numberOfGroups];
    }

    /**
     * Restores an instance from its XML representation.
     *
     * @param xml the XML representation as produced by {@link #toXML()}
     * @throws NonParsableException if the XML could not be parsed
     */
    public LFModularConditional9C(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns the specificity model held by this instance (the internal reference, not a copy).
     *
     * @return the specificity model
     */
    public DifferentiableSequenceScore getSpecModel() {
        return this.specificity;
    }

    /**
     * Creates a copy with cloned sub-models and cloned {@code a}/{@code b} arrays.
     *
     * @return the copy
     * @throws CloneNotSupportedException if the superclass or a sub-model cannot be cloned
     */
    @Override
    public LFModularConditional9C clone() throws CloneNotSupportedException {
        LFModularConditional9C copy = (LFModularConditional9C) super.clone();
        copy.lf0 = this.lf0.clone();
        copy.specificity = this.specificity.clone();
        copy.a = this.a.clone();
        copy.b = this.b.clone();
        if (this.position != null) {
            copy.position = (LFPosition_mixture) this.position.clone();
        }
        return copy;
    }

    /**
     * Initializes the parameters randomly; {@code index}, {@code data} and {@code weights} are ignored.
     *
     * @param index ignored
     * @param freeParams forwarded to {@link #initializeFunctionRandomly(boolean)}
     * @param data ignored
     * @param weights ignored
     * @throws Exception if the random initialization fails
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.initializeFunctionRandomly(freeParams);
    }

    /**
     * Randomly initializes all sub-models and draws fresh values for every {@code a[i]}
     * ({@code log(u + 0.5)}) and {@code b[i]} ({@code u - 0.5}), with {@code u} from {@link Math#random()}.
     *
     * @param freeParams forwarded to the sub-models
     * @throws Exception if a sub-model fails to initialize
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        this.lf0.initializeFunctionRandomly(freeParams);
        this.specificity.initializeFunctionRandomly(freeParams);
        if (this.position != null) {
            this.position.initializeFunctionRandomly(freeParams);
        }
        for (int i = 0; i < this.a.length; i++) {
            this.a[i] = Math.log(Math.random() + 0.5);
            this.b[i] = Math.random() - 0.5;
        }
    }

    /**
     * Computes the score of {@code seq} from {@code start} and appends the partial derivatives
     * with respect to the parameters of this instance.
     *
     * <p>Sub-model derivatives are scaled by the chain-rule factor of the combined score and their
     * indices are shifted to this instance's parameter layout. The derivatives for {@code a[group]}
     * and {@code b[group]} are appended last. A missing {@code "intgroup"} annotation selects
     * group 0, matching {@link #getLogScoreFor(Sequence, int)}.
     *
     * @param seq the sequence to score
     * @param start the start position within {@code seq}
     * @param indices receives the parameter indices of the appended derivatives
     * @param partialDer receives the partial derivatives; assumed to be filled in parallel with {@code indices}
     * @return the score
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        String mask = maskOf(seq);
        int group = groupOf(seq);
        double ag = Math.exp(this.a[group]);

        // Index offsets of the sub-model parameter blocks within this instance's parameter vector.
        int lf0Offset = this.a.length + this.b.length;

        int lf0From = indices.length();
        double score = 0.0;
        if (isUnmasked(mask, start)) {
            score += this.lf0.getLogScoreAndPartialDerivation(seq, start, indices, partialDer);
        }
        partialDer.multiply(lf0From, partialDer.length(), ag);
        indices.addToValues(lf0From, indices.length(), lf0Offset);

        int specOffset = lf0Offset + this.lf0.getNumberOfParameters();
        int posOffset = specOffset + this.specificity.getNumberOfParameters();

        for (int i = start + 1; i < seq.getLength(); i++) {
            if (isUnmasked(mask, i)) {
                int specFrom = partialDer.length();
                double spec = this.specificity.getLogScoreAndPartialDerivation(seq, i, indices, partialDer);
                int posFrom = partialDer.length();
                double pos = 1.0;
                if (this.position != null) {
                    pos = this.position.getLogScoreAndPartialDerivation(seq, i, indices, partialDer);
                }
                // d(ag * spec * pos)/d(spec params) = ag * pos * d(spec)
                partialDer.multiply(specFrom, posFrom, pos * ag);
                indices.addToValues(specFrom, posFrom, specOffset);
                // d(ag * spec * pos)/d(pos params) = ag * spec * d(pos)
                partialDer.multiply(posFrom, partialDer.length(), spec * ag);
                indices.addToValues(posFrom, indices.length(), posOffset);
                score += spec * pos;
            }
        }

        // d/d a[group]: the factor is exp(a[group]), so the derivative is the scaled sum itself.
        indices.add(group);
        partialDer.add(score * ag);
        // d/d b[group]: the offset enters additively.
        indices.add(this.a.length + group);
        partialDer.add(1.0);

        return score * ag + this.b[group];
    }

    /**
     * Returns the total number of parameters: both group arrays plus the parameters of all
     * sub-models that are set.
     *
     * @return the number of parameters
     */
    @Override
    public int getNumberOfParameters() {
        int positionParams = this.position == null ? 0 : this.position.getNumberOfParameters();
        return this.a.length + this.b.length + this.lf0.getNumberOfParameters() + this.specificity.getNumberOfParameters() + positionParams;
    }

    /**
     * Returns the current parameters in the order {@code a}, {@code b}, lf0, specificity, position.
     *
     * @return a newly allocated parameter vector
     * @throws Exception if a sub-model cannot provide its parameters
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] params = new double[this.getNumberOfParameters()];
        int off = 0;
        System.arraycopy(this.a, 0, params, off, this.a.length);
        off += this.a.length;
        System.arraycopy(this.b, 0, params, off, this.b.length);
        off += this.b.length;
        System.arraycopy(this.lf0.getCurrentParameterValues(), 0, params, off, this.lf0.getNumberOfParameters());
        off += this.lf0.getNumberOfParameters();
        System.arraycopy(this.specificity.getCurrentParameterValues(), 0, params, off, this.specificity.getNumberOfParameters());
        off += this.specificity.getNumberOfParameters();
        if (this.position != null) {
            System.arraycopy(this.position.getCurrentParameterValues(), 0, params, off, this.position.getNumberOfParameters());
        }
        return params;
    }

    /**
     * Reads the parameters from {@code params}, beginning at {@code start}, in the order
     * {@code a}, {@code b}, lf0, specificity, position.
     *
     * @param params the parameter vector
     * @param start index of the first parameter belonging to this instance
     */
    @Override
    public void setParameters(double[] params, int start) {
        int off = start;
        System.arraycopy(params, off, this.a, 0, this.a.length);
        off += this.a.length;
        System.arraycopy(params, off, this.b, 0, this.b.length);
        off += this.b.length;
        this.lf0.setParameters(params, off);
        off += this.lf0.getNumberOfParameters();
        this.specificity.setParameters(params, off);
        off += this.specificity.getNumberOfParameters();
        if (this.position != null) {
            this.position.setParameters(params, off);
        }
    }

    /**
     * Returns the instance name.
     *
     * @return the fixed string {@code "LFModular"}
     */
    @Override
    public String getInstanceName() {
        return "LFModular";
    }

    /**
     * Computes the score of {@code seq} from {@code start}; see the class description for the formula.
     *
     * @param seq the sequence to score
     * @param start the start position within {@code seq}
     * @return the score
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        String mask = maskOf(seq);
        int group = groupOf(seq);
        double score = 0.0;
        if (isUnmasked(mask, start)) {
            score += this.lf0.getLogScoreFor(seq, start);
        }
        for (int i = start + 1; i < seq.getLength(); i++) {
            if (isUnmasked(mask, i)) {
                double spec = this.specificity.getLogScoreFor(seq, i);
                double pos = this.position == null ? 1.0 : this.position.getLogScoreFor(seq, i);
                score += spec * pos;
            }
        }
        return score * Math.exp(this.a[group]) + this.b[group];
    }

    /**
     * Builds a textual listing of the per-position sub-scores of {@code seq}.
     *
     * <p>The listing contains {@code "lf0: <value>"} if the start position is unmasked, followed by
     * {@code ", spec_pos<i>: <value>"} for every unmasked position {@code i} after {@code start},
     * where the value is the specificity score at that position. Position weights and the group
     * scale/offset are not part of the listing.
     *
     * @param seq the sequence to score
     * @param start the start position within {@code seq}
     * @return the listing
     */
    public String getLogScorePoswiseFor(Sequence seq, int start) {
        String mask = maskOf(seq);
        StringBuilder report = new StringBuilder();
        if (isUnmasked(mask, start)) {
            // Added to 0.0 as the original running total did, so the printed value is unchanged.
            double lf0Score = 0.0 + this.lf0.getLogScoreFor(seq, start);
            report.append("lf0: ").append(lf0Score);
        }
        for (int i = start + 1; i < seq.getLength(); i++) {
            if (isUnmasked(mask, i)) {
                double spec = this.specificity.getLogScoreFor(seq, i);
                report.append(", spec_pos").append(i).append(": ").append(spec);
            }
        }
        return report.toString();
    }

    /**
     * Reports whether this instance is initialized.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isInitialized() {
        return true;
    }

    /**
     * Returns a textual representation listing {@code a}, {@code b} and the sub-models.
     *
     * @param nf number format; not used by this implementation
     * @return the textual representation
     */
    @Override
    public String toString(NumberFormat nf) {
        return Arrays.toString(this.a) + "\n" + Arrays.toString(this.b) + "\n\n" + this.lf0.toString() + "\n\n" + this.specificity.toString() + "\n\n" + this.position;
    }

    /**
     * Builds a matrix with {@code rvds.getLength() + 1} rows from the sub-models.
     *
     * <p>Row 0 comes from {@code lf0.getSpecs(rvds)}; row {@code i >= 1} comes from
     * {@code specificity.getSpecs(rvds, i)} multiplied by the position weight for {@code i}
     * (1 if no position model is set). Finally every entry is divided by the number of rows.
     *
     * @param rvds the RVD sequence
     * @return the matrix
     */
    public double[][] toPWM(RVDSequence rvds) {
        double[][] pwm = new double[rvds.getLength() + 1][];
        pwm[0] = this.lf0.getSpecs(rvds);
        for (int i = 1; i < pwm.length; i++) {
            pwm[i] = this.specificity.getSpecs(rvds, i);
            double pos = this.position == null ? 1.0 : this.position.getLogScoreFor(rvds.getLength() + 1, i);
            for (int j = 0; j < pwm[i].length; j++) {
                pwm[i][j] *= pos;
            }
        }
        for (int i = 0; i < pwm.length; i++) {
            for (int j = 0; j < pwm[i].length; j++) {
                pwm[i][j] /= (double) pwm.length;
            }
        }
        return pwm;
    }

    /**
     * Serializes this instance to XML inside an {@code LFMod} element.
     *
     * @return the XML representation
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.specificity, "specificity");
        XMLParser.appendObjectWithTags(xml, this.position, "position");
        XMLParser.appendObjectWithTags(xml, this.a, "a");
        XMLParser.appendObjectWithTags(xml, this.b, "b");
        XMLParser.appendObjectWithTags(xml, this.lf0, "lf0");
        XMLParser.addTags(xml, "LFMod");
        return xml;
    }

    /**
     * Restores all fields from the XML written by {@link #toXML()} and resets the inherited
     * alphabet container and length to the values used by the regular constructor.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML could not be parsed
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, "LFMod");
        this.specificity = (LFSpecificity_parallel_cond9C) XMLParser.extractObjectForTags(xml, "specificity");
        this.position = (LFPosition_mixture) XMLParser.extractObjectForTags(xml, "position");
        this.a = (double[]) XMLParser.extractObjectForTags(xml, "a");
        this.b = (double[]) XMLParser.extractObjectForTags(xml, "b");
        this.lf0 = (LF0Conditional) XMLParser.extractObjectForTags(xml, "lf0");
        this.alphabets = DNAAlphabetContainer.SINGLETON;
        this.length = 0;
    }

    /** Returns the identifier of the sequence's mask annotation, or {@code null} if there is none. */
    private static String maskOf(Sequence seq) {
        SequenceAnnotation maskAnnotation = seq.getSequenceAnnotationByType(MASK_ANNOTATION, 0);
        return maskAnnotation == null ? null : maskAnnotation.getIdentifier();
    }

    /** Returns the group index parsed from the sequence's group annotation, or 0 if there is none. */
    private static int groupOf(Sequence seq) {
        SequenceAnnotation groupAnnotation = seq.getSequenceAnnotationByType(GROUP_ANNOTATION, 0);
        return groupAnnotation == null ? 0 : Integer.parseInt(groupAnnotation.getIdentifier());
    }

    /** Returns whether position {@code pos} contributes to the score under {@code mask} ({@code null} = no masking). */
    private static boolean isUnmasked(String mask, int pos) {
        return mask == null || mask.charAt(pos) == UNMASKED;
    }
}

