/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.trainSMBased;

import de.jstacs.NotTrainedException;
import de.jstacs.classifiers.AbstractScoreBasedClassifier;
import de.jstacs.classifiers.ClassDimensionException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.results.StorableResult;
import de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import java.util.LinkedList;
import javax.naming.OperationNotSupportedException;

public class TrainSMBasedClassifier
extends AbstractScoreBasedClassifier {
    protected TrainableStatisticalModel[] models;
    private static final String XML_TAG = "TrainSMBasedClassifier";

    public static int getPossibleLength(TrainableStatisticalModel ... models) throws IllegalArgumentException {
        int length = 0;
        int i = 0;
        while (i < models.length) {
            int l;
            if ((l = models[i++].getLength()) == 0 || l == length) continue;
            if (length == 0) {
                length = l;
                continue;
            }
            throw new IllegalArgumentException("The models can't be used for one classifier. Since at least one model has length " + length + ", while another has length " + l + ".");
        }
        return length;
    }

    protected TrainSMBasedClassifier(boolean cloneModels, TrainableStatisticalModel ... models) throws IllegalArgumentException, CloneNotSupportedException, ClassDimensionException {
        super(models[0].getAlphabetContainer(), TrainSMBasedClassifier.getPossibleLength(models), models.length, -Math.log(models.length));
        int i = this.checkAndSetModels(models, cloneModels);
        if (i <= 0) {
            throw new IllegalArgumentException("Check length and AlphabetContainer of model " + -1 * i + ".");
        }
    }

    public TrainSMBasedClassifier(TrainableStatisticalModel ... models) throws IllegalArgumentException, CloneNotSupportedException, ClassDimensionException {
        this(true, models);
    }

    public TrainSMBasedClassifier(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    @Override
    public TrainSMBasedClassifier clone() throws CloneNotSupportedException {
        TrainSMBasedClassifier clone = (TrainSMBasedClassifier)super.clone();
        clone.models = (TrainableStatisticalModel[])ArrayHandler.clone((Cloneable[])this.models);
        return clone;
    }

    @Override
    public ResultSet getCharacteristics() throws Exception {
        LinkedList<Result> list = new LinkedList<Result>();
        int i = 0;
        while (i < this.models.length) {
            ResultSet set = this.models[i].getCharacteristics();
            if (set != null && set.getNumberOfResults() > 0) {
                list.add(new NumericalResult("class index", "the index of the class that produces the following results", i));
                int j = 0;
                while (j < set.getNumberOfResults()) {
                    list.add(set.getResultAt(j));
                    ++j;
                }
            }
            ++i;
        }
        list.add(new StorableResult("classifer", "the xml representation of the classifier", this));
        return new ResultSet(list);
    }

    @Override
    public String getInstanceName() {
        return "model-based classifier";
    }

    public TrainableStatisticalModel getModel(int classIndex) throws CloneNotSupportedException {
        return this.models[classIndex].clone();
    }

    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        LinkedList<NumericalResult> list = new LinkedList<NumericalResult>();
        int i = 0;
        while (i < this.models.length) {
            NumericalResultSet set = this.models[i].getNumericalCharacteristics();
            if (set != null && set.getNumberOfResults() > 0) {
                list.add(new NumericalResult("class index", "the index of the class that produces the following results", i));
                int j = 0;
                while (j < set.getNumberOfResults()) {
                    list.add(set.getResultAt(j));
                    ++j;
                }
            }
            ++i;
        }
        return new NumericalResultSet(list);
    }

    @Override
    public boolean isInitialized() {
        int i = 0;
        while (i < this.models.length && this.models[i].isInitialized()) {
            ++i;
        }
        return i == this.models.length;
    }

    @Override
    public void train(DataSet[] s, double[][] weights) throws Exception {
        if (weights != null && s.length != weights.length) {
            throw new IllegalArgumentException("data and weights do not match");
        }
        if (this.models.length != s.length) {
            throw new ClassDimensionException();
        }
        double[] c = new double[this.models.length];
        int i = 0;
        while (i < this.models.length) {
            if (weights == null || weights[i] == null) {
                this.models[i].train(s[i]);
            } else {
                this.models[i].train(s[i], weights[i]);
            }
            c[i] = this.getLength() > 0 ? Math.log(s[i].getNumberOfElementsWithLength(this.getLength(), weights == null ? null : weights[i])) : (weights == null || weights[i] == null ? Math.log(s[i].getNumberOfElements()) : Math.log(ToolBox.sum(weights[i])));
            ++i;
        }
        double norm = Normalisation.getLogSum(c);
        int i2 = 0;
        while (i2 < c.length) {
            int n = i2++;
            c[n] = c[n] - norm;
        }
        this.setClassWeights(false, c);
    }

    @Override
    protected StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = super.getFurtherClassifierInfos();
        XMLParser.appendObjectWithTags(xml, this.models, "models");
        return xml;
    }

    @Override
    protected double getScore(Sequence seq, int i, boolean check) throws Exception {
        if (check) {
            this.check(seq);
        }
        return this.models[i].getLogProbFor(seq) + this.getClassWeight(i);
    }

    public double[] getLogLikelihoodRatio(Sequence seq) throws Exception {
        if (!this.isInitialized()) {
            throw new NotTrainedException("The classifier is not trained yet.");
        }
        if (this.getNumberOfClasses() != 2) {
            throw new IllegalArgumentException("This method can only be used for binary classifiers.");
        }
        if (!this.getAlphabetContainer().checkConsistency(seq.getAlphabetContainer())) {
            throw new IllegalArgumentException("The sequence is not defined over the correct alphabets.");
        }
        double[] score = new double[seq.getLength() - this.getLength() + 1];
        double constant = this.getClassWeight(0) - this.getClassWeight(1);
        int p = 0;
        while (p < score.length) {
            score[p] = this.models[0].getLogProbFor(seq, p) - this.models[1].getLogProbFor(seq, p) + constant;
            ++p;
        }
        return score;
    }

    @Override
    public double[] getScores(DataSet s) throws Exception {
        if (this.getNumberOfClasses() != 2) {
            throw new OperationNotSupportedException("This method is only for 2-class-classifiers.");
        }
        if (s == null) {
            return new double[0];
        }
        this.check(s);
        double[] score0 = this.models[0].getLogScoreFor(s);
        double[] score1 = this.models[1].getLogScoreFor(s);
        double c0 = this.getClassWeight(0);
        double c1 = this.getClassWeight(1);
        int i = 0;
        while (i < score0.length) {
            int n = i;
            score0[n] = score0[n] + (c0 - (score1[i] + c1));
            ++i;
        }
        return score0;
    }

    @Override
    public byte[] classify(DataSet s) throws Exception {
        this.check(s);
        double[] best = this.models[0].getLogScoreFor(s);
        double[] current = new double[best.length];
        byte[] clazz = new byte[best.length];
        double cw = this.getClassWeight(0);
        int i = 0;
        while (i < best.length) {
            int n = i++;
            best[n] = best[n] + cw;
        }
        int j = 1;
        while (j < this.getNumberOfClasses()) {
            cw = this.getClassWeight(j);
            this.models[j].getLogScoreFor(s, current);
            i = 0;
            while (i < best.length) {
                if (current[i] + cw > best[i]) {
                    best[i] = current[i] + cw;
                    clazz[i] = j;
                }
                ++i;
            }
            j = (byte)(j + 1);
        }
        return clazz;
    }

    @Override
    protected String getXMLTag() {
        return XML_TAG;
    }

    @Override
    protected void extractFurtherClassifierInfosFromXML(StringBuffer xml) throws NonParsableException {
        int i;
        super.extractFurtherClassifierInfosFromXML(xml);
        try {
            i = this.checkAndSetModels(XMLParser.extractObjectForTags(xml, "models", TrainableStatisticalModel[].class), false);
        }
        catch (CloneNotSupportedException e) {
            NonParsableException n = new NonParsableException("Clone not supported: " + e.getMessage());
            n.setStackTrace(e.getStackTrace());
            throw n;
        }
        if (i <= 0) {
            throw new NonParsableException("Check length and AlphabetContainer of model " + -1 * i + ".");
        }
    }

    private int checkAndSetModels(TrainableStatisticalModel[] models, boolean clone) throws CloneNotSupportedException {
        int i = 0;
        int length = this.getLength();
        AlphabetContainer abc = this.getAlphabetContainer();
        this.models = new TrainableStatisticalModel[models.length];
        while (i < models.length) {
            int l = models[i].getLength();
            if (l != 0 && length != l || !models[i].getAlphabetContainer().checkConsistency(abc)) {
                return -i;
            }
            this.models[i] = clone ? models[i++].clone() : models[i++];
        }
        return i;
    }

    @Override
    public CategoricalResult[] getClassifierAnnotation() {
        CategoricalResult[] res = new CategoricalResult[this.models.length + 1];
        res[0] = new CategoricalResult("classifier", "a <b>short</b> description of the classifier", this.getInstanceName());
        int i = 0;
        while (i < this.models.length) {
            res[i + 1] = new CategoricalResult("class info " + i, "some information about the class", this.models[i++].getInstanceName());
        }
        return res;
    }

    public String toString() {
        StringBuffer sb = new StringBuffer(this.models.length * 5000);
        String heading = "model ";
        int i = 0;
        while (i < this.models.length) {
            sb.append(String.valueOf(heading) + i);
            sb.append("\n" + this.models[i].toString() + "\n");
            ++i;
        }
        sb.append("class weights: ");
        i = 0;
        while (i < this.getNumberOfClasses()) {
            sb.append(String.valueOf(this.getClassWeight(i)) + " ");
            ++i;
        }
        sb.append("\n");
        return sb.toString();
    }
}

