/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifier;

import de.jstacs.DataType;
import de.jstacs.NonParsableException;
import de.jstacs.NotTrainedException;
import de.jstacs.classifier.AbstractClassifier;
import de.jstacs.classifier.ClassDimensionException;
import de.jstacs.classifier.ConfusionMatrix;
import de.jstacs.classifier.MeasureParameters;
import de.jstacs.classifier.ScoreBasedPerformanceMeasureDefinitions;
import de.jstacs.classifier.utils.PValueComputation;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.XMLParser;
import de.jstacs.parameters.ParameterSet;
import de.jstacs.parameters.RangeParameter;
import de.jstacs.results.ImageResult;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.Result;
import de.jstacs.utils.REnvironment;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import javax.naming.OperationNotSupportedException;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public abstract class AbstractScoreBasedClassifier
extends AbstractClassifier {
    private double[] classWeights;

    public AbstractScoreBasedClassifier(AlphabetContainer abc, int classes) {
        this(abc, 0, classes, 0.0);
    }

    public AbstractScoreBasedClassifier(AlphabetContainer abc, int classes, double classWeight) {
        this(abc, 0, classes, classWeight);
    }

    public AbstractScoreBasedClassifier(AlphabetContainer abc, int length, int classes) {
        this(abc, length, classes, 0.0);
    }

    public AbstractScoreBasedClassifier(AlphabetContainer abc, int length, int classes, double classWeight) {
        super(abc, length);
        if (classes < 2) {
            throw new IllegalArgumentException("You should have at least 2 classes.");
        }
        this.createDefaultClassWeights(classes, classWeight);
    }

    public AbstractScoreBasedClassifier(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    @Override
    public AbstractScoreBasedClassifier clone() throws CloneNotSupportedException {
        AbstractScoreBasedClassifier erg = (AbstractScoreBasedClassifier)super.clone();
        erg.classWeights = (double[])this.classWeights.clone();
        return erg;
    }

    @Override
    public byte classify(Sequence seq) throws Exception {
        return this.classify(seq, true);
    }

    @Override
    protected LinkedList<? extends Result> getResults(Sample[] s, MeasureParameters params, boolean exceptionIfNotComputeable, boolean all) throws Exception {
        double sn;
        if (s.length != this.getNumberOfClasses()) {
            throw new ClassDimensionException();
        }
        if (s.length != 2) {
            return super.getResults(s, params, exceptionIfNotComputeable, all);
        }
        LinkedList<Result> list = new LinkedList<Result>();
        double[][] sortedScores = null;
        int numSelected = params.getNumberOfValues();
        if (numSelected > 0) {
            sortedScores = this.getSortedScore(s);
        }
        if (params.isSelected(MeasureParameters.Measure.ClassificationRate)) {
            list.add(new NumericalResult(MeasureParameters.Measure.ClassificationRate.getNameString(), MeasureParameters.Measure.ClassificationRate.getCommentString(), ScoreBasedPerformanceMeasureDefinitions.getClassificationRateFor2Classes(sortedScores[0], sortedScores[1])));
        }
        if (params.isSelected(MeasureParameters.Measure.Sensitivity)) {
            double sp = (Double)((ParameterSet)params.getValueFor(MeasureParameters.Measure.Sensitivity.getNameString())).getParameterAt(0).getValue();
            ScoreBasedPerformanceMeasureDefinitions.ThresholdMeasurePair sn2 = ScoreBasedPerformanceMeasureDefinitions.getSensitivityForSpecificity(sortedScores[0], sortedScores[1], sp);
            list.add(new NumericalResult(MeasureParameters.Measure.Sensitivity.getNameString(), MeasureParameters.Measure.Sensitivity.getCommentString() + " = " + sp, sn2.getMeasure()));
            list.add(new NumericalResult("Threshold for sensitivity", "The threshold for the sensitivity for a specificity of " + sp, sn2.getThreshold()));
        }
        if (params.isSelected(MeasureParameters.Measure.FalsePositiveRate)) {
            sn = (Double)((ParameterSet)params.getValueFor(MeasureParameters.Measure.FalsePositiveRate.getNameString())).getParameterAt(0).getValue();
            ScoreBasedPerformanceMeasureDefinitions.ThresholdMeasurePair fpr = ScoreBasedPerformanceMeasureDefinitions.getFPRForSensitivity(sortedScores[0], sortedScores[1], sn);
            list.add(new NumericalResult(MeasureParameters.Measure.FalsePositiveRate.getNameString(), MeasureParameters.Measure.FalsePositiveRate.getCommentString() + " = " + sn, fpr.getMeasure()));
            list.add(new NumericalResult("Threshold for false positive rate", "The threshold for the false positive rate for a fixed sensitivity of " + sn, fpr.getThreshold()));
        }
        if (params.isSelected(MeasureParameters.Measure.PositivePredictiveValue)) {
            sn = (Double)((ParameterSet)params.getValueFor(MeasureParameters.Measure.PositivePredictiveValue.getNameString())).getParameterAt(0).getValue();
            ScoreBasedPerformanceMeasureDefinitions.ThresholdMeasurePair ppv = ScoreBasedPerformanceMeasureDefinitions.getPPVForSensitivity(sortedScores[0], sortedScores[1], sn);
            list.add(new NumericalResult(MeasureParameters.Measure.PositivePredictiveValue.getNameString(), MeasureParameters.Measure.PositivePredictiveValue.getCommentString() + " = " + sn, ppv.getMeasure()));
            list.add(new NumericalResult("Threshold for positive predictive value", "The threshold for the positive predictive value for a fixed sensitivity of " + sn, ppv.getThreshold()));
        }
        ArrayList<Object> list2 = null;
        if (params.isSelected(MeasureParameters.Measure.AreaUnderROCCurve)) {
            if (all && params.isSelected(MeasureParameters.Measure.AreaUnderROCCurve)) {
                list2 = new ArrayList<double[]>(s[0].getNumberOfElements() / 2);
            }
            list.add(new NumericalResult(MeasureParameters.Measure.AreaUnderROCCurve.getNameString(), MeasureParameters.Measure.AreaUnderROCCurve.getCommentString(), ScoreBasedPerformanceMeasureDefinitions.getAUC_ROC(sortedScores[0], sortedScores[1], list2)));
        }
        ArrayList<Object> list3 = null;
        if (params.isSelected(MeasureParameters.Measure.AreaUnderPrecisionRecallCurve)) {
            if (all && params.isSelected(MeasureParameters.Measure.AreaUnderPrecisionRecallCurve)) {
                list3 = new ArrayList<double[]>(s[0].getNumberOfElements() / 2);
            }
            list.add(new NumericalResult(MeasureParameters.Measure.AreaUnderPrecisionRecallCurve.getNameString(), MeasureParameters.Measure.AreaUnderPrecisionRecallCurve.getCommentString(), ScoreBasedPerformanceMeasureDefinitions.getAUC_PR(sortedScores[0], sortedScores[1], list3)));
        }
        if (params.isSelected(MeasureParameters.Measure.MaximumCorrelationCoefficient)) {
            ScoreBasedPerformanceMeasureDefinitions.ThresholdMeasurePair maxCC = ScoreBasedPerformanceMeasureDefinitions.getMaxOfCC(sortedScores[0], sortedScores[1]);
            list.add(new NumericalResult(MeasureParameters.Measure.MaximumCorrelationCoefficient.getNameString(), MeasureParameters.Measure.MaximumCorrelationCoefficient.getCommentString(), maxCC.getMeasure()));
            list.add(new NumericalResult("Threshold for maximum correlation coefficient", "The threshold for the maximum correlation coefficient of all possible correlation coefficients", maxCC.getThreshold()));
        }
        if (params.isSelected(MeasureParameters.Measure.PartialROCCurve)) {
            RangeParameter specs = (RangeParameter)((ParameterSet)params.getValueFor(MeasureParameters.Measure.PartialROCCurve.getNameString())).getParameterAt(0);
            double[][] val = ScoreBasedPerformanceMeasureDefinitions.getPartialROC(sortedScores[0], sortedScores[1], specs);
            specs.resetToFirst();
            int i = 0;
            do {
                list.add(new NumericalResult("" + val[0][i], "The sensitivity for a specificity of " + val[0][i], val[1][i]));
                list.add(new NumericalResult("Threshold for a specificity of " + val[0][i], "The threshold for the sensitivity for a specificity of " + val[0][i], val[2][i]));
                ++i;
            } while (specs.next());
        }
        if (all) {
            if (params.isSelected(MeasureParameters.Measure.RecieverOperatingCharacteristicCurve)) {
                if (list2 == null) {
                    list2 = new ArrayList(s[0].getNumberOfElements() / 2);
                    ScoreBasedPerformanceMeasureDefinitions.getAUC_ROC(sortedScores[0], sortedScores[1], list2);
                }
                list.add(new DoubleTableResult(MeasureParameters.Measure.RecieverOperatingCharacteristicCurve.getNameString(), MeasureParameters.Measure.RecieverOperatingCharacteristicCurve.getCommentString(), list2));
            }
            if (params.isSelected(MeasureParameters.Measure.PrecisionRecallCurve)) {
                if (list3 == null) {
                    list3 = new ArrayList(s[0].getNumberOfElements() / 2);
                    ScoreBasedPerformanceMeasureDefinitions.getAUC_PR(sortedScores[0], sortedScores[1], list3);
                }
                list.add(new DoubleTableResult(MeasureParameters.Measure.PrecisionRecallCurve.getNameString(), MeasureParameters.Measure.PrecisionRecallCurve.getCommentString(), list3));
            }
        }
        return list;
    }

    public double[] getClassWeights() {
        return (double[])this.classWeights.clone();
    }

    @Override
    public int getNumberOfClasses() {
        return this.classWeights.length;
    }

    public double getScore(Sequence seq, int i) throws Exception {
        return this.getScore(seq, i, true);
    }

    public final void setClassWeights(boolean add, double ... weights) throws ClassDimensionException {
        int c = this.getNumberOfClasses();
        if (weights == null || c != weights.length) {
            throw new ClassDimensionException();
        }
        if (add) {
            for (int i = 0; i < this.classWeights.length; ++i) {
                int n = i;
                this.classWeights[n] = this.classWeights[n] + weights[i];
            }
        } else {
            for (int i = 0; i < this.classWeights.length; ++i) {
                this.classWeights[i] = weights[i];
            }
        }
    }

    public final void setThresholdClassWeights(boolean add, double t) throws OperationNotSupportedException {
        int c = this.getNumberOfClasses();
        if (c != 2) {
            throw new OperationNotSupportedException();
        }
        if (this.classWeights == null) {
            this.classWeights = new double[2];
        }
        double logP = -Math.log1p(Math.exp(t));
        if (add) {
            this.classWeights[0] = this.classWeights[0] + logP;
            this.classWeights[1] = this.classWeights[1] + (t + logP);
        } else {
            this.classWeights[0] = logP;
            this.classWeights[1] = t + logP;
        }
    }

    @Override
    public ConfusionMatrix test(Sample ... testData) throws Exception {
        if (testData.length != this.getNumberOfClasses()) {
            throw new ClassDimensionException();
        }
        ConfusionMatrix matrix = new ConfusionMatrix(testData.length);
        for (int i = 0; i < testData.length; ++i) {
            if (testData[i] == null) continue;
            this.check(testData[i]);
            Sample.ElementEnumerator ei = new Sample.ElementEnumerator(testData[i]);
            while (ei.hasMoreElements()) {
                matrix.add(i, this.classify(ei.nextElement(), false));
            }
        }
        return matrix;
    }

    @Override
    protected StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = new StringBuffer(300);
        XMLParser.appendDoubleArrayWithTags(xml, this.classWeights, "classWeights");
        return xml;
    }

    protected void check(Sample s) throws NotTrainedException, IllegalArgumentException {
        if (!this.isTrained()) {
            throw new NotTrainedException("The classifier is not trained yet.");
        }
        int length = this.getLength();
        if (length != 0 && s.getElementLength() != length) {
            throw new IllegalArgumentException("The sequences have not the correct length.");
        }
        if (!this.setNewAlphabetContainerInstance(s.getAlphabetContainer())) {
            throw new IllegalArgumentException("The sequences are not defined over the correct alphabets.");
        }
    }

    protected void check(Sequence seq) throws NotTrainedException, IllegalArgumentException {
        if (!this.isTrained()) {
            throw new NotTrainedException("The classifier is not trained yet.");
        }
        int length = this.getLength();
        if (length != 0 && seq.getLength() != length) {
            throw new IllegalArgumentException("The sequence has not the correct length.");
        }
        if (!this.getAlphabetContainer().checkConsistency(seq.getAlphabetContainer())) {
            throw new IllegalArgumentException("The sequence is not defined over the correct alphabets.");
        }
    }

    protected byte classify(Sequence seq, boolean check) throws Exception {
        if (check) {
            this.check(seq);
        }
        int clazz = 0;
        double max = this.getScore(seq, clazz, false);
        for (int i = 1; i < this.getNumberOfClasses(); i = (int)((byte)(i + 1))) {
            double current = this.getScore(seq, i, false);
            if (!(current > max)) continue;
            max = current;
            clazz = i;
        }
        return (byte)clazz;
    }

    protected void createDefaultClassWeights(int classes, double val) throws IllegalArgumentException {
        if (classes < 2) {
            throw new IllegalArgumentException();
        }
        this.classWeights = new double[classes];
        Arrays.fill(this.classWeights, val);
    }

    @Override
    protected void extractFurtherClassifierInfosFromXML(StringBuffer xml) throws NonParsableException {
        this.classWeights = XMLParser.extractDoubleArrayForTag(xml, "classWeights");
    }

    protected double getClassWeight(int index) {
        return this.classWeights[index];
    }

    protected abstract double getScore(Sequence var1, int var2, boolean var3) throws IllegalArgumentException, NotTrainedException, Exception;

    public double[] getScores(Sample s) throws Exception {
        if (this.classWeights.length != 2) {
            throw new OperationNotSupportedException("This method is only for 2-class-classifiers.");
        }
        if (s == null) {
            return new double[0];
        }
        this.check(s);
        double[] score = new double[s.getNumberOfElements()];
        Sample.ElementEnumerator ei = new Sample.ElementEnumerator(s);
        for (int i = 0; i < score.length; ++i) {
            Sequence seq = ei.nextElement();
            score[i] = this.getScore(seq, 0, false) - this.getScore(seq, 1, false);
            if (!Double.isNaN(score[i])) continue;
            throw new IllegalArgumentException("Could not classify sequence " + i + ": " + seq + "\nfg: " + this.getScore(seq, 0, false) + "\nbg: " + this.getScore(seq, 1, false));
        }
        return score;
    }

    public double getPValue(Sequence candidate, Sample bg) throws Exception {
        double[] scores = this.createStatistic(bg);
        return PValueComputation.getPValue(scores, this.getScore(candidate, 0) - this.getScore(candidate, 1));
    }

    public double[] getPValue(Sample candidates, Sample bg) throws Exception {
        double[] scores = this.createStatistic(bg);
        double[] pVal = new double[candidates.getNumberOfElements()];
        for (int i = 0; i < pVal.length; ++i) {
            Sequence candidate = candidates.getElementAt(i);
            pVal[i] = PValueComputation.getPValue(scores, this.getScore(candidate, 0) - this.getScore(candidate, 1));
        }
        return pVal;
    }

    private double[] createStatistic(Sample bg) throws Exception {
        double[] scores = this.getScores(bg);
        Arrays.sort(scores);
        return scores;
    }

    private double[][] getSortedScore(Sample[] s) throws Exception {
        double[][] scores = new double[][]{this.getScores(s[0]), this.getScores(s[1])};
        Arrays.sort(scores[0]);
        Arrays.sort(scores[1]);
        return scores;
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    public static class DoubleTableResult
    extends Result {
        private double[][] content;
        private static final String XML_TAG = "DoubleTableResult";

        private DoubleTableResult(String name, String comment, ArrayList<double[]> list) {
            super(name, comment, DataType.LIST);
            this.content = (double[][])list.toArray((T[])new double[0][0]);
        }

        public DoubleTableResult(StringBuffer representation) throws NonParsableException {
            super(representation);
        }

        @Override
        protected void fromXML(StringBuffer representation) throws NonParsableException {
            StringBuffer xml = XMLParser.extractForTag(representation, XML_TAG);
            this.extractMainInfo(xml);
            this.content = XMLParser.extractDouble2ArrayForTag(xml, "content");
        }

        public double[] getLine(int index) {
            return (double[])this.content[index].clone();
        }

        public int getNumberOfLines() {
            return this.content.length;
        }

        public String toString() {
            return "[table] \t " + this.name + " \t(" + this.comment + ")";
        }

        public double[][] getResult() {
            double[][] res = new double[this.content.length][];
            for (int i = 0; i < res.length; ++i) {
                res[i] = (double[])this.content[i].clone();
            }
            return res;
        }

        @Override
        public StringBuffer toXML() {
            StringBuffer xml = new StringBuffer(500 + this.content.length * this.content[0].length * 10);
            this.appendMainInfo(xml);
            XMLParser.appendDouble2ArrayWithTags(xml, this.content, "content");
            XMLParser.addTags(xml, XML_TAG);
            return xml;
        }

        public static final ImageResult plot(REnvironment e, DoubleTableResult ... dtr) throws Exception {
            int i;
            String opt = dtr[0].name;
            for (i = 1; i < dtr.length && dtr[i].name.equalsIgnoreCase(opt); ++i) {
            }
            if (i != dtr.length) {
                opt = null;
            }
            return new ImageResult(opt, "This plot shows the " + opt + ".", e.plot(DoubleTableResult.getPlotCommands(e, opt, dtr).toString()));
        }

        public static final StringBuffer getPlotCommands(REnvironment e, String plotOptions, DoubleTableResult ... dtr) throws Exception {
            int i = 0;
            while (i < dtr.length) {
                e.createMatrix("dtr" + i, dtr[i++].content);
            }
            if (plotOptions == null) {
                String string = plotOptions = dtr[0].name == null ? "" : dtr[0].name;
            }
            if (plotOptions.equals(MeasureParameters.Measure.RecieverOperatingCharacteristicCurve.getNameString())) {
                plotOptions = ", xlim=c(0, 1), ylim=c(0, 1), xlab=\"false positive rate\", ylab=\"true positive rate\", main=\"ROC\"";
            } else if (plotOptions.equals(MeasureParameters.Measure.PrecisionRecallCurve.getNameString())) {
                plotOptions = ", xlim=c(0, 1), ylim=c(0, 1), xlab=\"recall\", ylab=\"precision\", main=\"PR\"";
            } else if ((plotOptions = plotOptions.trim()).charAt(0) != ',') {
                plotOptions = ", " + plotOptions;
            }
            StringBuffer p = new StringBuffer(dtr.length * 200);
            p.append("plot( dtr0[,1], dtr0[,2], col=1, type=\"l\"" + plotOptions + " );");
            i = 1;
            while (i < dtr.length) {
                p.append("\nlines( dtr" + i + "[,1], dtr" + i + "[,2], col=" + ++i + " );");
            }
            return p;
        }
    }
}

