/*
 * Decompiled with CFR 0.152.
 */
package projects.tals.linear;

import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.text.NumberFormat;
import java.util.Arrays;

/**
 * Position-dependent score over the relative start position within a sequence,
 * modelled as a three-component mixture.
 * <p>
 * For a sequence of length {@code L} and a start position {@code s}, with
 * {@code x = s / L} and {@code y = (L - s) / L}, the score is
 * <pre>
 *   p1 * f1(x) + p2 * f2(y) + p3
 * </pre>
 * where {@code f1(x) = 1 / (1 + exp(-a1 * (x + b1)))},
 * {@code f2(y) = 1 / (1 + exp(-a2 * (y + b2)))}, and {@code (p1, p2, p3)} is the
 * softmax of {@code (beta1, beta2, beta3)}.
 * <p>
 * The parameter vector is ordered {@code beta1, beta2, beta3, a1, a2, b1, b2}.
 * The value returned by the {@code getLogScore*} methods is the mixture value itself;
 * no logarithm is taken. NOTE(review): presumably intentional — confirm with callers.
 * <p>
 * Scores are cached per sequence length in an unsynchronized field that is mutated
 * lazily, so an instance should not be used from several threads without external locking.
 */
public class LFPosition_mixture extends AbstractDifferentiableSequenceScore {
    /** Number of sequence lengths the score cache holds before it has to grow. */
    private static final int INITIAL_CACHE_SIZE = 50;

    private double beta1;
    private double beta2;
    private double beta3;
    private double a1;
    private double a2;
    private double b1;
    private double b2;
    /**
     * Scores indexed first by sequence length, then by start position. A {@code null}
     * row means "not yet computed for the current parameters". The cache grows on
     * demand and is discarded whenever the parameters change.
     */
    private double[][] precomputed = new double[INITIAL_CACHE_SIZE][];

    /**
     * Creates an instance for the DNA alphabet with length 0. All parameters start at zero.
     *
     * @throws IllegalArgumentException propagated from the superclass constructor
     */
    public LFPosition_mixture() throws IllegalArgumentException {
        super(DNAAlphabetContainer.SINGLETON, 0);
    }

    /**
     * Restores an instance from XML, presumably via {@link #fromXML(StringBuffer)}
     * called by the superclass constructor.
     *
     * @param xml the XML representation, as written by {@link #toXML()}
     * @throws NonParsableException propagated from the superclass constructor
     */
    public LFPosition_mixture(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Initializes the parameters randomly. {@code index}, {@code data} and
     * {@code weights} are ignored.
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.initializeFunctionRandomly(freeParams);
    }

    /**
     * Draws random starting values and clears the score cache. The values are:
     * the mixture logits uniformly from [-1, 1), {@code a1} from (-3, 0],
     * {@code a2} from [0, 3), and both offsets from (-0.75, -0.25].
     * {@code freeParams} is ignored.
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        this.beta1 = Math.random() * 2.0 - 1.0;
        this.beta2 = Math.random() * 2.0 - 1.0;
        this.beta3 = Math.random() * 2.0 - 1.0;
        this.a1 = -Math.random() * 3.0;
        this.a2 = Math.random() * 3.0;
        this.b1 = -(Math.random() * 0.5 + 0.25);
        this.b2 = -(Math.random() * 0.5 + 0.25);
        this.invalidateCache();
    }

    /**
     * Returns the mixture score at {@code start}. It also appends the partial derivatives
     * of that score with respect to all seven parameters, at indices 0..6 in parameter order.
     *
     * @param seq the sequence; only its length is used
     * @param start the start position
     * @param indices receives the parameter indices
     * @param partialDer receives the partial derivatives, aligned with {@code indices}
     * @return the (non-logarithmic) mixture score
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        int len = seq.getLength();
        double x = (double) start / (double) len;
        double y = (double) (len - start) / (double) len;
        double[] w = this.mixtureWeights();
        double f1 = logistic(this.a1, this.b1, x);
        double f2 = logistic(this.a2, this.b2, y);
        double score = w[0] * f1 + w[1] * f2 + w[2];
        // Softmax gradient: d(sum_j p_j * g_j) / d beta_i = p_i * (g_i - score), with g = (f1, f2, 1).
        indices.add(0);
        partialDer.add(w[0] * (f1 - score));
        indices.add(1);
        partialDer.add(w[1] * (f2 - score));
        indices.add(2);
        partialDer.add(w[2] * (1.0 - score));
        // Logistic derivative f' = f * (1 - f), multiplied by the inner derivative.
        double d1 = w[0] * f1 * (1.0 - f1);
        double d2 = w[1] * f2 * (1.0 - f2);
        indices.add(3);
        partialDer.add(d1 * (x + this.b1));
        indices.add(4);
        partialDer.add(d2 * (y + this.b2));
        indices.add(5);
        partialDer.add(d1 * this.a1);
        indices.add(6);
        partialDer.add(d2 * this.a2);
        return score;
    }

    /** Returns 7: three mixture logits, two slopes and two offsets. */
    @Override
    public int getNumberOfParameters() {
        return 7;
    }

    /** Returns a fresh array {@code {beta1, beta2, beta3, a1, a2, b1, b2}}. */
    @Override
    public double[] getCurrentParameterValues() {
        return new double[]{this.beta1, this.beta2, this.beta3, this.a1, this.a2, this.b1, this.b2};
    }

    /**
     * Sets all seven parameters from {@code params[start..start+6]} and clears the score cache.
     *
     * @param params the source array, in the order of {@link #getCurrentParameterValues()}
     * @param start the offset of the first parameter in {@code params}
     */
    @Override
    public void setParameters(double[] params, int start) {
        this.beta1 = params[start];
        this.beta2 = params[start + 1];
        this.beta3 = params[start + 2];
        this.a1 = params[start + 3];
        this.a2 = params[start + 4];
        this.b1 = params[start + 5];
        this.b2 = params[start + 6];
        this.invalidateCache();
    }

    /** Returns the simple class name. */
    @Override
    public String getInstanceName() {
        return this.getClass().getSimpleName();
    }

    /**
     * Returns the cached mixture score for position {@code start} in a sequence of {@code seq}'s length.
     *
     * @param seq the sequence; only its length is used
     * @param start the start position, in {@code [0, seq.getLength())}
     * @return the (non-logarithmic) mixture score
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        return this.getLogScoreFor(seq.getLength(), start);
    }

    /**
     * Returns the cached mixture score for position {@code start} in a sequence of length {@code len}.
     * On first use for a given length, the scores for every position of that length are computed.
     *
     * @param len the sequence length; any non-negative value is supported (the cache grows as needed)
     * @param start the start position, in {@code [0, len)}
     * @return the (non-logarithmic) mixture score
     * @throws ArrayIndexOutOfBoundsException if {@code start} is outside {@code [0, len)}
     */
    public double getLogScoreFor(int len, int start) {
        return this.scoresForLength(len)[start];
    }

    /** Always returns true; the zero-initialized parameters already define a valid score. */
    @Override
    public boolean isInitialized() {
        return true;
    }

    /**
     * Returns the mixture formula with the current weights and parameters filled in.
     *
     * @param nf format for the numbers; if {@code null}, {@link String#valueOf(double)} is used
     * @return the textual formula
     */
    @Override
    public String toString(NumberFormat nf) {
        double[] w = this.mixtureWeights();
        return fmt(nf, w[0]) + "*1/(1+exp(" + fmt(nf, -this.a1) + "*(x + " + fmt(nf, this.b1) + "))) + "
                + fmt(nf, w[1]) + "*1/(1+exp(" + fmt(nf, -this.a2) + "*(y + " + fmt(nf, this.b2) + "))) + "
                + fmt(nf, w[2]);
    }

    /** Serializes the parameter vector under a {@code params} tag, wrapped in a {@code mixture} tag. */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.getCurrentParameterValues(), "params");
        XMLParser.addTags(xml, "mixture");
        return xml;
    }

    /**
     * Restores the parameters written by {@link #toXML()}. It also sets the fixed
     * DNA alphabet and length 0.
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, "mixture");
        this.setParameters((double[]) XMLParser.extractObjectForTags(xml, "params"), 0);
        this.alphabets = DNAAlphabetContainer.SINGLETON;
        this.length = 0;
    }

    /** Discards all cached scores; must be called whenever a parameter changes. */
    private void invalidateCache() {
        this.precomputed = new double[INITIAL_CACHE_SIZE][];
    }

    /** Returns the (lazily computed) scores for all start positions of a sequence of length {@code len}. */
    private double[] scoresForLength(int len) {
        if (len >= this.precomputed.length) {
            // Grow instead of failing for lengths beyond the initial capacity.
            this.precomputed = Arrays.copyOf(this.precomputed, Math.max(len + 1, 2 * this.precomputed.length));
        }
        double[] row = this.precomputed[len];
        if (row == null) {
            row = new double[len];
            double[] w = this.mixtureWeights();
            for (int s = 0; s < len; s++) {
                double x = (double) s / (double) len;
                double y = (double) (len - s) / (double) len;
                row[s] = w[0] * logistic(this.a1, this.b1, x) + w[1] * logistic(this.a2, this.b2, y) + w[2];
            }
            this.precomputed[len] = row;
        }
        return row;
    }

    /**
     * Returns the softmax of {@code (beta1, beta2, beta3)} as {@code {p1, p2, p3}}.
     * The maximum is subtracted before exponentiating so that large logits do not overflow to NaN.
     */
    private double[] mixtureWeights() {
        double max = Math.max(this.beta1, Math.max(this.beta2, this.beta3));
        double e1 = Math.exp(this.beta1 - max);
        double e2 = Math.exp(this.beta2 - max);
        double e3 = Math.exp(this.beta3 - max);
        double sum = e1 + e2 + e3;
        return new double[]{e1 / sum, e2 / sum, e3 / sum};
    }

    /** Logistic function {@code 1 / (1 + exp(-slope * (pos + offset)))}. */
    private static double logistic(double slope, double offset, double pos) {
        return 1.0 / (1.0 + Math.exp(-slope * (pos + offset)));
    }

    /** Formats {@code value} with {@code nf}, or with {@link String#valueOf(double)} when {@code nf} is null. */
    private static String fmt(NumberFormat nf, double value) {
        return nf == null ? String.valueOf(value) : nf.format(value);
    }
}

