package de.jstacs.classifiers.trainSMBased;

import cern.colt.matrix.impl.AbstractFormatter;
import de.jstacs.classifiers.AbstractScoreBasedClassifier;
import de.jstacs.classifiers.ClassDimensionException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.results.ResultSet;
import de.jstacs.results.StorableResult;
import de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import java.util.LinkedList;
import javax.naming.OperationNotSupportedException;
import org.biojavax.bio.seq.Position;

/* loaded from: input_file:de/jstacs/classifiers/trainSMBased/TrainSMBasedClassifier.class */
/**
 * Classifier built from one {@link TrainableStatisticalModel} per class.
 * <p>
 * The score of class {@code i} for a sequence is the model's log score (log probability in
 * {@link #getScore(Sequence, int, boolean)}) plus the log class weight of class {@code i}.
 * {@link #classify(DataSet)} assigns each sequence to the class with the highest such score.
 * Class weights are re-estimated in {@link #train(DataSet[], double[][])} from the (weighted)
 * amount of training data per class.
 */
public class TrainSMBasedClassifier extends AbstractScoreBasedClassifier {

    /** The internal models, one per class; index {@code i} belongs to class {@code i}. */
    protected TrainableStatisticalModel[] models;

    private static final String XML_TAG = "TrainSMBasedClassifier";

    /**
     * Determines the common sequence length of the given models.
     * <p>
     * Models reporting length {@code 0} are treated as length-agnostic and are ignored. All other
     * models must agree on a single length.
     *
     * @param classModels the models to inspect
     * @return the common non-zero length, or {@code 0} if no model reports a non-zero length
     * @throws IllegalArgumentException if two models report different non-zero lengths
     */
    public static int getPossibleLength(TrainableStatisticalModel... classModels) throws IllegalArgumentException {
        int commonLength = 0;
        for (TrainableStatisticalModel model : classModels) {
            int length = model.getLength();
            if (length != 0 && length != commonLength) {
                if (commonLength != 0) {
                    throw new IllegalArgumentException("The models can't be used for one classifier. Since at least one model has length " + commonLength + ", while another has length " + length + ".");
                }
                commonLength = length;
            }
        }
        return commonLength;
    }

    /**
     * Returns the first model of the array. It fails with a descriptive exception instead of an
     * {@link ArrayIndexOutOfBoundsException} when no model is given.
     */
    private static TrainableStatisticalModel firstModel(TrainableStatisticalModel[] classModels) {
        if (classModels == null || classModels.length == 0) {
            throw new IllegalArgumentException("At least one model is required.");
        }
        return classModels[0];
    }

    /**
     * Creates a classifier from one model per class.
     * <p>
     * The alphabet is taken from the first model and the length from
     * {@link #getPossibleLength(TrainableStatisticalModel...)}. {@code -log(numberOfModels)} is
     * passed to the super constructor; presumably this is a uniform initial log class weight
     * (TODO confirm against {@code AbstractScoreBasedClassifier}).
     * <p>
     * NOTE(review): the decompiler reported this constructor as originally {@code protected}; it
     * stays {@code public} so existing callers keep compiling.
     *
     * @param cloneModels if {@code true} the models are cloned, otherwise they are stored by reference
     * @param classModels the models, one per class; must not be empty
     * @throws IllegalArgumentException if no model is given, or if the models disagree in length or alphabet
     * @throws CloneNotSupportedException if {@code cloneModels} is set and a model cannot be cloned
     * @throws ClassDimensionException declared by the super constructor
     */
    public TrainSMBasedClassifier(boolean cloneModels, TrainableStatisticalModel... classModels) throws IllegalArgumentException, CloneNotSupportedException, ClassDimensionException {
        super(firstModel(classModels).getAlphabetContainer(), getPossibleLength(classModels), classModels.length, -Math.log(classModels.length));
        int result = checkAndSetModels(classModels, cloneModels);
        if (result <= 0) {
            throw new IllegalArgumentException("Check length and AlphabetContainer of model " + (-result) + ".");
        }
    }

    /**
     * Creates a classifier from clones of the given models, one per class.
     *
     * @param classModels the models, one per class; must not be empty
     * @throws IllegalArgumentException if no model is given, or if the models disagree in length or alphabet
     * @throws CloneNotSupportedException if a model cannot be cloned
     * @throws ClassDimensionException declared by the super constructor
     */
    public TrainSMBasedClassifier(TrainableStatisticalModel... classModels) throws IllegalArgumentException, CloneNotSupportedException, ClassDimensionException {
        this(true, classModels);
    }

    /**
     * Restores a classifier from its XML representation.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public TrainSMBasedClassifier(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Deep-copies this classifier, including its models.
     * <p>
     * This is the decompiler-assigned name of {@code clone()}. It must keep this name because the
     * superclass declares it this way.
     */
    @Override
    public TrainSMBasedClassifier m57clone() throws CloneNotSupportedException {
        TrainSMBasedClassifier copy = (TrainSMBasedClassifier) super.m57clone();
        copy.models = (TrainableStatisticalModel[]) ArrayHandler.clone(this.models);
        return copy;
    }

    /**
     * Collects the characteristics of all class models.
     * <p>
     * Each non-empty group is preceded by a "class index" result. The XML-storable classifier itself
     * is appended last.
     */
    @Override
    public ResultSet getCharacteristics() throws Exception {
        LinkedList list = new LinkedList();
        for (int classIndex = 0; classIndex < this.models.length; classIndex++) {
            ResultSet modelResults = this.models[classIndex].getCharacteristics();
            if (modelResults == null || modelResults.getNumberOfResults() == 0) {
                continue;
            }
            list.add(new NumericalResult("class index", "the index of the class that produces the following results", classIndex));
            for (int r = 0; r < modelResults.getNumberOfResults(); r++) {
                list.add(modelResults.getResultAt(r));
            }
        }
        list.add(new StorableResult("classifer", "the xml representation of the classifier", this));
        return new ResultSet(list);
    }

    @Override
    public String getInstanceName() {
        return "model-based classifier";
    }

    /**
     * Returns a clone of the model of one class, so callers cannot mutate the internal model.
     *
     * @param i the class index
     * @return a clone of the model for class {@code i}
     * @throws CloneNotSupportedException if the model cannot be cloned
     */
    public TrainableStatisticalModel getModel(int i) throws CloneNotSupportedException {
        return this.models[i].mo116clone();
    }

    /**
     * Collects the numerical characteristics of all class models. Each non-empty group is preceded
     * by a "class index" result.
     */
    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        LinkedList list = new LinkedList();
        for (int classIndex = 0; classIndex < this.models.length; classIndex++) {
            NumericalResultSet modelResults = this.models[classIndex].getNumericalCharacteristics();
            if (modelResults == null || modelResults.getNumberOfResults() == 0) {
                continue;
            }
            list.add(new NumericalResult("class index", "the index of the class that produces the following results", classIndex));
            for (int r = 0; r < modelResults.getNumberOfResults(); r++) {
                list.add(modelResults.getResultAt(r));
            }
        }
        return new NumericalResultSet((LinkedList<? extends NumericalResult>) list);
    }

    /** Returns {@code true} iff every class model reports itself as initialized. */
    @Override
    public boolean isInitialized() {
        for (TrainableStatisticalModel model : this.models) {
            if (!model.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Trains each class model on its own data set and re-estimates the class weights.
     * <p>
     * The log class weight of class {@code i} is the log of the (weighted) amount of training data
     * for that class, normalised over all classes. If the classifier has a fixed length
     * ({@code getLength() > 0}), the amount comes from
     * {@code getNumberOfElementsWithLength(getLength(), weights)}. Otherwise it is the number of
     * elements, or the sum of the weights when weights are given.
     *
     * @param dArr per-class sequence weights; may be {@code null}, and individual rows may be {@code null}
     * @throws IllegalArgumentException if {@code dArr} is non-null and its length differs from {@code dataSetArr}
     * @throws ClassDimensionException if the number of data sets differs from the number of classes
     */
    @Override
    public void train(DataSet[] dataSetArr, double[][] dArr) throws Exception {
        if (dArr != null && dataSetArr.length != dArr.length) {
            throw new IllegalArgumentException("data and weights do not match");
        }
        if (this.models.length != dataSetArr.length) {
            throw new ClassDimensionException();
        }
        double[] logClassWeights = new double[this.models.length];
        for (int i = 0; i < this.models.length; i++) {
            double[] weights = dArr == null ? null : dArr[i];
            if (weights == null) {
                this.models[i].train(dataSetArr[i]);
            } else {
                this.models[i].train(dataSetArr[i], weights);
            }
            if (getLength() > 0) {
                logClassWeights[i] = Math.log(dataSetArr[i].getNumberOfElementsWithLength(getLength(), weights));
            } else if (weights == null) {
                logClassWeights[i] = Math.log(dataSetArr[i].getNumberOfElements());
            } else {
                logClassWeights[i] = Math.log(ToolBox.sum(weights));
            }
        }
        // normalise in log space so the class weights sum to one
        double logSum = Normalisation.getLogSum(logClassWeights);
        for (int i = 0; i < logClassWeights.length; i++) {
            logClassWeights[i] -= logSum;
        }
        setClassWeights(false, logClassWeights);
    }

    /** Appends the class models, under the tag {@code models}, to the superclass's XML infos. */
    @Override
    public StringBuffer getFurtherClassifierInfos() {
        StringBuffer infos = super.getFurtherClassifierInfos();
        XMLParser.appendObjectWithTags(infos, this.models, "models");
        return infos;
    }

    /**
     * Returns the log probability of {@code sequence} under the model of class {@code i}, plus that
     * class's log weight.
     *
     * @param z if {@code true}, the sequence is validated with {@code check} first
     */
    @Override
    public double getScore(Sequence sequence, int i, boolean z) throws Exception {
        if (z) {
            check(sequence);
        }
        return this.models[i].getLogProbFor(sequence) + getClassWeight(i);
    }

    /**
     * For a two-class classifier, returns per sequence (score of class 0) minus (score of class 1),
     * where each score is the model's log score plus the class's log weight.
     *
     * @return the score differences, or an empty array if {@code dataSet} is {@code null}
     * @throws OperationNotSupportedException if the classifier does not have exactly two classes
     */
    @Override
    public double[] getScores(DataSet dataSet) throws Exception {
        if (getNumberOfClasses() != 2) {
            throw new OperationNotSupportedException("This method is only for 2-class-classifiers.");
        }
        if (dataSet == null) {
            return new double[0];
        }
        check(dataSet);
        double[] scores = this.models[0].getLogScoreFor(dataSet);
        double[] otherScores = this.models[1].getLogScoreFor(dataSet);
        double weight0 = getClassWeight(0);
        double weight1 = getClassWeight(1);
        for (int i = 0; i < scores.length; i++) {
            scores[i] += weight0 - (otherScores[i] + weight1);
        }
        return scores;
    }

    /**
     * Assigns each sequence to the class with the highest log score plus log class weight. Ties go
     * to the lower class index because a later class must be strictly better to win.
     * <p>
     * Labels are stored as {@code byte}, so the class index is iterated as a {@code byte} like in the
     * original implementation.
     */
    @Override
    public byte[] classify(DataSet dataSet) throws Exception {
        check(dataSet);
        double[] bestScores = this.models[0].getLogScoreFor(dataSet);
        double[] currentScores = new double[bestScores.length];
        byte[] labels = new byte[bestScores.length];
        double weight0 = getClassWeight(0);
        for (int i = 0; i < bestScores.length; i++) {
            bestScores[i] += weight0;
        }
        for (byte c = 1; c < getNumberOfClasses(); c++) {
            double classWeight = getClassWeight(c);
            // reuse one buffer for the scores of each further class
            this.models[c].getLogScoreFor(dataSet, currentScores);
            for (int i = 0; i < bestScores.length; i++) {
                double score = currentScores[i] + classWeight;
                if (score > bestScores[i]) {
                    bestScores[i] = score;
                    labels[i] = c;
                }
            }
        }
        return labels;
    }

    @Override
    protected String getXMLTag() {
        return XML_TAG;
    }

    /**
     * Restores the class models from the tag {@code models}. They are stored by reference, not
     * cloned, and validated against the restored length and alphabet.
     *
     * @throws NonParsableException if the models cannot be parsed or fail validation
     */
    @Override
    public void extractFurtherClassifierInfosFromXML(StringBuffer stringBuffer) throws NonParsableException {
        super.extractFurtherClassifierInfosFromXML(stringBuffer);
        try {
            int result = checkAndSetModels((TrainableStatisticalModel[]) XMLParser.extractObjectForTags(stringBuffer, "models", TrainableStatisticalModel[].class), false);
            if (result <= 0) {
                throw new NonParsableException("Check length and AlphabetContainer of model " + (-result) + ".");
            }
        } catch (CloneNotSupportedException e) {
            NonParsableException wrapped = new NonParsableException("Clone not supported: " + e.getMessage());
            wrapped.setStackTrace(e.getStackTrace());
            // keep the original exception reachable for diagnosis
            wrapped.initCause(e);
            throw wrapped;
        }
    }

    /**
     * Validates the models against this classifier's length and alphabet. If all models pass, they
     * are stored in {@link #models}, either cloned or by reference.
     * <p>
     * A model passes if its length is {@code 0} or equals {@code getLength()}, and its alphabet is
     * consistent with {@code getAlphabetContainer()}. On failure {@link #models} is left unchanged.
     *
     * @return the number of models on success; {@code -i} if model {@code i} fails. A failure at
     *         index 0 and an empty array both yield {@code 0}, so callers treat {@code <= 0} as failure.
     */
    private int checkAndSetModels(TrainableStatisticalModel[] classModels, boolean cloneModels) throws CloneNotSupportedException {
        int length = getLength();
        AlphabetContainer alphabetContainer = getAlphabetContainer();
        TrainableStatisticalModel[] accepted = new TrainableStatisticalModel[classModels.length];
        for (int i = 0; i < classModels.length; i++) {
            int modelLength = classModels[i].getLength();
            if ((modelLength != 0 && modelLength != length) || !classModels[i].getAlphabetContainer().checkConsistency(alphabetContainer)) {
                return -i;
            }
            accepted[i] = cloneModels ? classModels[i].mo116clone() : classModels[i];
        }
        this.models = accepted;
        return classModels.length;
    }

    /**
     * Returns a short description of the classifier, followed by one entry per class that holds the
     * class model's instance name.
     */
    @Override
    public CategoricalResult[] getClassifierAnnotation() {
        CategoricalResult[] annotation = new CategoricalResult[this.models.length + 1];
        annotation[0] = new CategoricalResult("classifier", "a <b>short</b> description of the classifier", getInstanceName());
        for (int i = 0; i < this.models.length; i++) {
            annotation[i + 1] = new CategoricalResult("class info " + i, "some information about the class", this.models[i].getInstanceName());
        }
        return annotation;
    }

    /** Returns the class weights on the first line and the models on the second line. */
    @Override
    public String toString() {
        return Arrays.toString(getClassWeights()) + "\n" + Arrays.toString(this.models);
    }
}
