/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifier.modelBased;

import de.jstacs.NonParsableException;
import de.jstacs.classifier.AbstractScoreBasedClassifier;
import de.jstacs.classifier.ClassDimensionException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.XMLParser;
import de.jstacs.models.AbstractModel;
import de.jstacs.models.Model;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.results.StorableResult;
import java.util.LinkedList;
import javax.naming.OperationNotSupportedException;

public class ModelBasedClassifier
extends AbstractScoreBasedClassifier {
    protected Model[] models;
    private static final String XML_TAG = "ModelBasedClassifier";

    public static int getPossibleLength(Model ... models) throws IllegalArgumentException {
        int length = 0;
        int i = 0;
        while (i < models.length) {
            int l;
            if ((l = models[i++].getLength()) == 0 || l == length) continue;
            if (length == 0) {
                length = l;
                continue;
            }
            throw new IllegalArgumentException("The models can't be used for one classifier. Since at least one model has length " + length + ", while another has length " + l + ".");
        }
        return length;
    }

    protected ModelBasedClassifier(boolean cloneModels, Model ... models) throws IllegalArgumentException, CloneNotSupportedException, ClassDimensionException {
        super(models[0].getAlphabetContainer(), ModelBasedClassifier.getPossibleLength(models), models.length, -Math.log(models.length));
        int i = this.checkAndSetModels(models, cloneModels);
        if (i <= 0) {
            throw new IllegalArgumentException("Check length and AlphabetContainer of model " + -1 * i + ".");
        }
    }

    public ModelBasedClassifier(Model ... models) throws IllegalArgumentException, CloneNotSupportedException, ClassDimensionException {
        this(true, models);
    }

    public ModelBasedClassifier(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    public ModelBasedClassifier clone() throws CloneNotSupportedException {
        ModelBasedClassifier clone = (ModelBasedClassifier)super.clone();
        clone.models = (Model[])ArrayHandler.clone((Cloneable[])this.models);
        return clone;
    }

    public ResultSet getCharacteristics() throws Exception {
        LinkedList<Result> list = new LinkedList<Result>();
        for (int i = 0; i < this.models.length; ++i) {
            ResultSet set = this.models[i].getCharacteristics();
            if (set == null || set.getNumberOfResults() <= 0) continue;
            list.add(new NumericalResult("class index", "the index of the class that produces the following results", i));
            for (int j = 0; j < set.getNumberOfResults(); ++j) {
                list.add(set.getResultAt(j));
            }
        }
        list.add(new StorableResult("classifer", "the xml representation of the classifier", this));
        return new ResultSet(list);
    }

    public String getInstanceName() {
        return "model-based classifier";
    }

    public Model getModel(int classIndex) throws CloneNotSupportedException {
        return this.models[classIndex].clone();
    }

    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        LinkedList<NumericalResult> list = new LinkedList<NumericalResult>();
        for (int i = 0; i < this.models.length; ++i) {
            NumericalResultSet set = this.models[i].getNumericalCharacteristics();
            if (set == null || set.getNumberOfResults() <= 0) continue;
            list.add(new NumericalResult("class index", "the index of the class that produces the following results", i));
            for (int j = 0; j < set.getNumberOfResults(); ++j) {
                list.add(set.getResultAt(j));
            }
        }
        return new NumericalResultSet(list);
    }

    public boolean isTrained() {
        int i;
        for (i = 0; i < this.models.length && this.models[i].isTrained(); ++i) {
        }
        return i == this.models.length;
    }

    public final boolean setNewAlphabetContainerInstance(AlphabetContainer abc) {
        if (super.setNewAlphabetContainerInstance(abc)) {
            for (int i = 0; i < this.models.length; ++i) {
                this.models[i].setNewAlphabetContainerInstance(abc);
            }
            return true;
        }
        return false;
    }

    public void train(Sample[] s, double[][] weights) throws Exception {
        if (weights != null && s.length != weights.length) {
            throw new IllegalArgumentException("data and weights do not match");
        }
        if (this.models.length != s.length) {
            throw new ClassDimensionException();
        }
        double[] c = new double[this.models.length];
        for (int i = 0; i < this.models.length; ++i) {
            if (weights == null || weights[i] == null) {
                this.models[i].train(s[i]);
                continue;
            }
            this.models[i].train(s[i], weights[i]);
        }
        this.setClassWeights(false, c);
    }

    protected StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = super.getFurtherClassifierInfos();
        XMLParser.appendStorableArrayWithTags(xml, this.models, "models");
        return xml;
    }

    protected double getScore(Sequence seq, int i, boolean check) throws Exception {
        if (check) {
            this.check(seq);
        }
        return this.models[i].getLogProbFor(seq) + this.getClassWeight(i);
    }

    public double[] getScores(Sample s) throws Exception {
        if (this.getNumberOfClasses() != 2) {
            throw new OperationNotSupportedException("This method is only for 2-class-classifiers.");
        }
        if (s == null) {
            return new double[0];
        }
        this.check(s);
        double[] score0 = this.models[0].getLogProbFor(s);
        double[] score1 = this.models[1].getLogProbFor(s);
        double c0 = this.getClassWeight(0);
        double c1 = this.getClassWeight(1);
        for (int i = 0; i < score0.length; ++i) {
            int n = i;
            score0[n] = score0[n] + (c0 - (score1[i] + c1));
        }
        return score0;
    }

    public byte[] classify(Sample s) throws Exception {
        this.check(s);
        double[] best = this.models[0].getLogProbFor(s);
        double[] current = new double[best.length];
        byte[] clazz = new byte[best.length];
        double cw = this.getClassWeight(0);
        int i = 0;
        while (i < best.length) {
            int n = i++;
            best[n] = best[n] + cw;
        }
        for (int j = 1; j < this.getNumberOfClasses(); j = (int)((byte)(j + 1))) {
            cw = this.getClassWeight(j);
            this.models[j].getLogProbFor(s, current);
            for (i = 0; i < best.length; ++i) {
                if (!(current[i] + cw > best[i])) continue;
                best[i] = current[i] + cw;
                clazz[i] = j;
            }
        }
        return clazz;
    }

    protected String getXMLTag() {
        return XML_TAG;
    }

    protected void extractFurtherClassifierInfosFromXML(StringBuffer xml) throws NonParsableException {
        int i;
        super.extractFurtherClassifierInfosFromXML(xml);
        try {
            i = this.checkAndSetModels(ArrayHandler.cast(AbstractModel.class, XMLParser.extractStorableArrayForTag(xml, "models")), false);
        }
        catch (CloneNotSupportedException e) {
            NonParsableException n = new NonParsableException("Clone not supported: " + e.getMessage());
            n.setStackTrace(e.getStackTrace());
            throw n;
        }
        if (i <= 0) {
            throw new NonParsableException("Check length and AlphabetContainer of model " + -1 * i + ".");
        }
    }

    private int checkAndSetModels(Model[] models, boolean clone) throws CloneNotSupportedException {
        int i = 0;
        int length = this.getLength();
        AlphabetContainer abc = this.getAlphabetContainer();
        this.models = new Model[models.length];
        while (i < models.length) {
            int l = models[i].getLength();
            if (l != 0 && length != l || !models[i].setNewAlphabetContainerInstance(abc)) {
                return -i;
            }
            if (clone) {
                this.models[i] = models[i++].clone();
                continue;
            }
            this.models[i] = models[i++];
        }
        return i;
    }

    public CategoricalResult[] getClassifierAnnotation() {
        CategoricalResult[] res = new CategoricalResult[this.models.length + 1];
        res[0] = new CategoricalResult("classifier", "a <b>short</b> description of the classifier", this.getInstanceName());
        int i = 0;
        while (i < this.models.length) {
            res[i + 1] = new CategoricalResult("class info " + i, "some information about the class", this.models[i++].getInstanceName());
        }
        return res;
    }
}

