package de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix;

import de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.parameters.SimpleParameter;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import java.util.Arrays;
import org.apache.batik.util.SVGConstants;
import org.apache.xmlgraphics.image.loader.spi.ImagePreloader;
import org.biojavax.bio.seq.io.RichSequenceBuilderFactory;

/* loaded from: input_file:de/jstacs/classifiers/differentiableSequenceScoreBased/gendismix/GenDisMixClassifier.class */
/**
 * Classifier whose {@link DifferentiableSequenceScore}s are trained by optimizing a
 * {@link LogGenDisMixFunction}, an objective parameterized by a weight vector {@code beta}
 * and a {@link LogPrior}.
 *
 * <p>The weight vector is validated by {@code LearningPrinciple.checkWeights} whenever it is set
 * or read back from XML. If no prior is supplied, {@link DoesNothingLogPrior#defaultInstance}
 * is used instead.
 */
public class GenDisMixClassifier extends ScoreClassifier {
    /** The prior used in the objective; never {@code null} after construction or parsing. */
    protected LogPrior prior;
    /** Objective function field; not read or written by this class itself. */
    protected LogGenDisMixFunction function;
    /** The weights of the objective, as validated by {@code LearningPrinciple.checkWeights}. */
    protected double[] beta;

    private static final String XML_TAG = "gendismix-classifier";
    /** Tag under which the class of a serialized prior is stored inside the {@code <prior>} section. */
    private static final String PRIOR_CLASS_TAG = "class";
    /** Initial capacity hint for the XML buffer of the prior. */
    private static final int PRIOR_XML_CAPACITY = 1000;
    /** Initial per-score capacity hint used by {@link #toString()}. */
    private static final int TO_STRING_CAPACITY_PER_SCORE = 1000;

    /**
     * Creates a classifier from arbitrary differentiable sequence scores.
     *
     * @param genDisMixClassifierParameterSet the classifier parameters (number of threads,
     *        normalization and free-parameter handling are read from it in {@link #getFunction})
     * @param logPrior the prior; {@code null} is replaced by {@link DoesNothingLogPrior#defaultInstance}
     * @param d a value passed unchanged to the {@link ScoreClassifier} constructor; the convenience
     *        constructors of this class pass {@link Double#NaN}
     * @param dArr the weights of the objective, checked by {@code LearningPrinciple.checkWeights}
     * @param differentiableSequenceScoreArr the scores (presumably one per class)
     * @throws CloneNotSupportedException propagated from the {@link ScoreClassifier} constructor
     * @throws IllegalArgumentException if the weights are rejected
     */
    public GenDisMixClassifier(GenDisMixClassifierParameterSet genDisMixClassifierParameterSet, LogPrior logPrior, double d, double[] dArr, DifferentiableSequenceScore... differentiableSequenceScoreArr) throws CloneNotSupportedException {
        super(genDisMixClassifierParameterSet, d, differentiableSequenceScoreArr);
        setWeights(dArr);
        setPrior(logPrior);
    }

    /**
     * Creates a classifier from differentiable statistical models.
     *
     * @param genDisMixClassifierParameterSet the classifier parameters
     * @param logPrior the prior; {@code null} means "no prior"
     * @param d a value passed unchanged to the {@link ScoreClassifier} constructor
     * @param dArr the weights of the objective
     * @param differentiableStatisticalModelArr the models (presumably one per class)
     * @throws CloneNotSupportedException propagated from the {@link ScoreClassifier} constructor
     */
    public GenDisMixClassifier(GenDisMixClassifierParameterSet genDisMixClassifierParameterSet, LogPrior logPrior, double d, double[] dArr, DifferentiableStatisticalModel... differentiableStatisticalModelArr) throws CloneNotSupportedException {
        this(genDisMixClassifierParameterSet, logPrior, d, dArr, (DifferentiableSequenceScore[]) differentiableStatisticalModelArr);
    }

    /**
     * Creates a classifier from differentiable statistical models, passing {@link Double#NaN}
     * as the {@code double} argument of the {@link ScoreClassifier} constructor.
     *
     * @param genDisMixClassifierParameterSet the classifier parameters
     * @param logPrior the prior; {@code null} means "no prior"
     * @param dArr the weights of the objective
     * @param differentiableStatisticalModelArr the models (presumably one per class)
     * @throws CloneNotSupportedException propagated from the {@link ScoreClassifier} constructor
     */
    public GenDisMixClassifier(GenDisMixClassifierParameterSet genDisMixClassifierParameterSet, LogPrior logPrior, double[] dArr, DifferentiableStatisticalModel... differentiableStatisticalModelArr) throws CloneNotSupportedException {
        this(genDisMixClassifierParameterSet, logPrior, Double.NaN, dArr, differentiableStatisticalModelArr);
    }

    /**
     * Creates a classifier whose three objective weights are given individually.
     *
     * <p>Note the order: the weight vector is assembled as {@code {d2, d, d3}}, so the second
     * {@code double} argument becomes the first entry. The third entry ({@code d3}) is the one
     * that {@link #extractFurtherClassifierInfosFromXML} tests before initializing the prior.
     *
     * @param genDisMixClassifierParameterSet the classifier parameters
     * @param logPrior the prior; {@code null} means "no prior"
     * @param d becomes {@code beta[1]}
     * @param d2 becomes {@code beta[0]}
     * @param d3 becomes {@code beta[2]}
     * @param differentiableStatisticalModelArr the models (presumably one per class)
     * @throws CloneNotSupportedException propagated from the {@link ScoreClassifier} constructor
     */
    public GenDisMixClassifier(GenDisMixClassifierParameterSet genDisMixClassifierParameterSet, LogPrior logPrior, double d, double d2, double d3, DifferentiableStatisticalModel... differentiableStatisticalModelArr) throws CloneNotSupportedException {
        this(genDisMixClassifierParameterSet, logPrior, Double.NaN, new double[]{d2, d, d3}, differentiableStatisticalModelArr);
    }

    /**
     * Creates a classifier whose weights are derived from a {@link LearningPrinciple}
     * via {@code LearningPrinciple.getBeta}.
     *
     * @param genDisMixClassifierParameterSet the classifier parameters
     * @param logPrior the prior; {@code null} means "no prior"
     * @param learningPrinciple the learning principle defining the weights
     * @param differentiableStatisticalModelArr the models (presumably one per class)
     * @throws CloneNotSupportedException propagated from the {@link ScoreClassifier} constructor
     */
    public GenDisMixClassifier(GenDisMixClassifierParameterSet genDisMixClassifierParameterSet, LogPrior logPrior, LearningPrinciple learningPrinciple, DifferentiableStatisticalModel... differentiableStatisticalModelArr) throws CloneNotSupportedException {
        this(genDisMixClassifierParameterSet, logPrior, Double.NaN, LearningPrinciple.getBeta(learningPrinciple), differentiableStatisticalModelArr);
    }

    /**
     * Restores a classifier from its XML representation.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public GenDisMixClassifier(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Returns a copy that has its own prior instance (via {@link LogPrior#getNewInstance()})
     * and its own weight array.
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier, de.jstacs.classifiers.AbstractScoreBasedClassifier, de.jstacs.classifiers.AbstractClassifier
    public GenDisMixClassifier clone() throws CloneNotSupportedException {
        GenDisMixClassifier copy = (GenDisMixClassifier) super.clone();
        copy.prior = this.prior.getNewInstance();
        copy.beta = this.beta.clone();
        return copy;
    }

    /**
     * Creates the objective function for the given data.
     *
     * <p>With more than one data set a {@link LogGenDisMixFunction} is built. Otherwise an
     * {@link OneDataSetLogGenDisMixFunction} is built from {@code dataSetArr[0]}, so an empty
     * array fails with an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param dataSetArr the data sets
     * @param dArr the sequence weights (presumably one array per data set; TODO confirm)
     * @return the objective function
     * @throws Exception if the function cannot be created
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier
    public LogGenDisMixFunction getFunction(DataSet[] dataSetArr, double[][] dArr) throws Exception {
        GenDisMixClassifierParameterSet ps = (GenDisMixClassifierParameterSet) this.params;
        if (dataSetArr.length > 1) {
            return new LogGenDisMixFunction(ps.getNumberOfThreads(), this.score, dataSetArr, dArr, this.prior, this.beta, ps.shouldBeNormalized(), ps.useOnlyFreeParameter());
        }
        return new OneDataSetLogGenDisMixFunction(ps.getNumberOfThreads(), this.score, dataSetArr[0], dArr, this.prior, this.beta, ps.shouldBeNormalized(), ps.useOnlyFreeParameter());
    }

    /**
     * Sets the prior and marks the classifier as not optimized.
     *
     * @param logPrior the new prior; {@code null} selects {@link DoesNothingLogPrior#defaultInstance}
     */
    public void setPrior(LogPrior logPrior) {
        this.prior = logPrior != null ? logPrior : DoesNothingLogPrior.defaultInstance;
        this.hasBeenOptimized = false;
    }

    /**
     * Sets the objective weights and marks the classifier as not optimized.
     *
     * @param dArr the new weights
     * @throws IllegalArgumentException if {@code LearningPrinciple.checkWeights} rejects them
     */
    public void setWeights(double... dArr) throws IllegalArgumentException {
        this.beta = LearningPrinciple.checkWeights(dArr);
        this.hasBeenOptimized = false;
    }

    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier, de.jstacs.classifiers.AbstractClassifier
    protected String getXMLTag() {
        return XML_TAG;
    }

    /**
     * Appends the weights (tag {@code beta}) and, unless the prior is a
     * {@link DoesNothingLogPrior}, a {@code <prior>} section holding the prior's class and its
     * own XML to the information written by the super class.
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier, de.jstacs.classifiers.AbstractScoreBasedClassifier, de.jstacs.classifiers.AbstractClassifier
    public StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = super.getFurtherClassifierInfos();
        XMLParser.appendObjectWithTags(xml, this.beta, "beta");
        if (!(this.prior instanceof DoesNothingLogPrior)) {
            StringBuffer priorXml = new StringBuffer(PRIOR_XML_CAPACITY);
            priorXml.append("<prior>\n");
            XMLParser.appendObjectWithTags(priorXml, this.prior.getClass(), PRIOR_CLASS_TAG);
            priorXml.append(this.prior.toXML());
            priorXml.append("\t</prior>\n");
            xml.append(priorXml);
        }
        return xml;
    }

    /**
     * Reads the weights and the optional prior written by {@link #getFurtherClassifierInfos()}.
     * If the third weight is positive, the prior is initialized with the scores via
     * {@link LogPrior#set}.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the weights or the prior cannot be restored
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier, de.jstacs.classifiers.AbstractScoreBasedClassifier, de.jstacs.classifiers.AbstractClassifier
    public void extractFurtherClassifierInfosFromXML(StringBuffer stringBuffer) throws NonParsableException {
        super.extractFurtherClassifierInfosFromXML(stringBuffer);
        this.beta = LearningPrinciple.checkWeights((double[]) XMLParser.extractObjectForTags(stringBuffer, "beta", double[].class));
        StringBuffer priorXml = XMLParser.extractForTag(stringBuffer, "prior");
        this.prior = priorXml == null ? DoesNothingLogPrior.defaultInstance : parsePrior(priorXml);
        // NOTE(review): index 2 is presumably the prior weight; the prior only needs to know
        // the scores when it actually contributes to the objective.
        if (this.beta[2] > 0.0d) {
            try {
                this.prior.set(((GenDisMixClassifierParameterSet) this.params).useOnlyFreeParameter(), this.score);
            } catch (Exception e) {
                throw newNonParsableException("problem when setting the kind of parameter: " + e.getMessage(), e);
            }
        }
    }

    /** Instantiates the prior stored in a {@code <prior>} section via its (StringBuffer) constructor. */
    private static LogPrior parsePrior(StringBuffer priorXml) throws NonParsableException {
        Class<?> cls = (Class<?>) XMLParser.extractObjectForTags(priorXml, PRIOR_CLASS_TAG, Class.class);
        // Refuse to run the constructor of a class that is not a prior at all.
        if (cls == null || !LogPrior.class.isAssignableFrom(cls)) {
            throw new NonParsableException("The class stored for the prior is not a LogPrior: " + cls);
        }
        try {
            return (LogPrior) cls.getConstructor(StringBuffer.class).newInstance(priorXml);
        } catch (NoSuchMethodException e) {
            throw newNonParsableException("You must provide a constructor " + cls.getSimpleName() + "(StringBuffer).", e);
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Reflection wraps the constructor's own exception; report that one, not the wrapper
            // (whose message is typically null).
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            throw newNonParsableException("problem at " + cls.getSimpleName() + ": " + cause.getMessage(), cause);
        } catch (Exception e) {
            throw newNonParsableException("problem at " + cls.getSimpleName() + ": " + e.getMessage(), e);
        }
    }

    /** Builds a {@link NonParsableException} that carries the cause and its stack trace. */
    private static NonParsableException newNonParsableException(String message, Throwable cause) {
        NonParsableException ex = new NonParsableException(message);
        ex.setStackTrace(cause.getStackTrace());
        try {
            ex.initCause(cause);
        } catch (IllegalStateException ignored) {
            // The cause was already fixed by the exception's constructor; the copied stack
            // trace still points at the origin of the problem.
        }
        return ex;
    }

    /**
     * Creates one classifier for every combination of candidate models, choosing one candidate
     * per class. The candidate index of the first class varies fastest.
     *
     * <p>NOTE(review): all created classifiers receive the same {@code logPrior} instance;
     * confirm this is intended, since {@link #clone()} deliberately gives copies their own prior.
     *
     * @param genDisMixClassifierParameterSet the classifier parameters
     * @param logPrior the prior handed to every classifier; {@code null} means "no prior"
     * @param dArr the objective weights handed to every classifier
     * @param differentiableStatisticalModelArr for each class, the candidate models
     * @return all combinations; empty if any class has no candidates
     * @throws CloneNotSupportedException propagated from the constructor
     * @throws IllegalArgumentException if no class (no candidate array) is given
     */
    public static GenDisMixClassifier[] create(GenDisMixClassifierParameterSet genDisMixClassifierParameterSet, LogPrior logPrior, double[] dArr, DifferentiableStatisticalModel[]... differentiableStatisticalModelArr) throws CloneNotSupportedException {
        int numClasses = differentiableStatisticalModelArr.length;
        if (numClasses == 0) {
            throw new IllegalArgumentException("At least one array of candidate models is required.");
        }
        int total = 1;
        for (DifferentiableStatisticalModel[] candidates : differentiableStatisticalModelArr) {
            total *= candidates.length;
        }
        GenDisMixClassifier[] result = new GenDisMixClassifier[total];
        int[] choice = new int[numClasses];
        for (int k = 0; k < total; k++) {
            // Fresh array per classifier so no two classifiers can share a mutated array.
            DifferentiableStatisticalModel[] current = new DifferentiableStatisticalModel[numClasses];
            for (int c = 0; c < numClasses; c++) {
                current[c] = differentiableStatisticalModelArr[c][choice[c]];
            }
            result[k] = new GenDisMixClassifier(genDisMixClassifierParameterSet, logPrior, dArr, current);
            // Advance the mixed-radix counter: reset exhausted positions and carry to the next one.
            int pos = 0;
            while (pos < numClasses - 1 && choice[pos] == differentiableStatisticalModelArr[pos].length - 1) {
                choice[pos] = 0;
                pos++;
            }
            choice[pos]++;
        }
        return result;
    }

    /**
     * Returns the super class's instance name extended by the weights and, for any prior
     * other than a {@link DoesNothingLogPrior}, the prior's instance name.
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier, de.jstacs.classifiers.AbstractClassifier
    public String getInstanceName() {
        String name = super.getInstanceName() + " with weights=" + Arrays.toString(this.beta);
        if (this.prior == null || this.prior instanceof DoesNothingLogPrior) {
            return name;
        }
        return name + " and with " + this.prior.getInstanceName();
    }

    /**
     * @return the number of threads configured in the parameter set
     */
    public int getNumberOfThreads() {
        return ((GenDisMixClassifierParameterSet) this.params).getNumberOfThreads();
    }

    /**
     * Lists every score ("function i" followed by its string form) and then the class weights.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(this.score.length * TO_STRING_CAPACITY_PER_SCORE);
        for (int i = 0; i < this.score.length; i++) {
            sb.append("function ").append(i);
            sb.append('\n').append(this.score[i].toString()).append('\n');
        }
        sb.append("class weights: ");
        for (int c = 0; c < getNumberOfClasses(); c++) {
            sb.append(getClassWeight(c)).append(' ');
        }
        sb.append('\n');
        return sb.toString();
    }

    /**
     * Sets the number of threads in the parameter set.
     *
     * @param i the number of threads
     * @throws SimpleParameter.IllegalValueException if the parameter set rejects the value
     */
    public void setNumberOfThreads(int i) throws SimpleParameter.IllegalValueException {
        ((GenDisMixClassifierParameterSet) this.params).setNumberOfThreads(i);
    }
}
