/*
 * Decompiled with CFR 0.152.
 */
package projects.dimont.hts;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import javax.naming.OperationNotSupportedException;
import projects.dimont.AbstractSingleMotifChIPper;
import projects.dimont.hts.DimontTool;

/**
 * Strategies for initializing the motif parameters of an
 * {@link AbstractSingleMotifChIPper} from a seed k-mer.
 *
 * <p>All strategies are applied through
 * {@link #compute(AbstractSingleMotifChIPper, int, Sequence, DataSet, double[], Pair, double, double, HTS_InitialPwms)}.
 */
public enum HTS_InitialPwms {
    /** Delegates to {@code DimontTool.setMotifParameters(model, kmer, h)}. */
    DIMONT,
    /**
     * Builds a PWM from weighted counts of the seed k-mer and its single-position
     * variants in the data set and installs its logarithms as model parameters.
     */
    INI_MOTIF,
    /**
     * Builds one PWM from the foreground sequences and one from the remaining
     * sequences of the alternative data set, subtracts the scaled second PWM from
     * the first, re-normalizes, and installs the logarithms as model parameters.
     */
    CYCLE_CORRECTION;

    /**
     * Value substituted for corrected PWM entries that are not strictly positive;
     * keeps the logarithm taken in {@link #pwmToInputParameters(double[][])} finite.
     */
    private static final double MIN_CORRECTED_PROBABILITY = 0.01;

    /** Pseudocount added to every weighted k-mer count before a column is normalized. */
    private static final double PSEUDOCOUNT = 1.0;

    /**
     * Initializes the motif parameters of {@code model} according to {@code type}.
     *
     * @param model the model whose parameters are set
     * @param motifLength number of columns of the PWM to build (not used for {@link #DIMONT})
     * @param kmer the seed k-mer
     * @param data sequences scanned for the k-mer (used for {@link #INI_MOTIF} only)
     * @param weights one weight per sequence of {@code data} (used for {@link #INI_MOTIF} only)
     * @param alternativeData used for {@link #CYCLE_CORRECTION} only: the first element holds the
     *        sequences; row 0 of the second element flags each sequence, where a value of exactly
     *        {@code 1.0} marks a foreground sequence and anything else a background sequence
     * @param pwmCorrectionFactor factor applied to the background PWM before it is subtracted from
     *        the foreground PWM (used for {@link #CYCLE_CORRECTION} only)
     * @param h forwarded unchanged to {@code DimontTool.setMotifParameters} (used for
     *        {@link #DIMONT} only)
     * @param type the initialization strategy
     * @throws IllegalArgumentException if {@code motifLength} is smaller than the length of
     *         {@code kmer} for a strategy that builds a PWM
     * @throws WrongAlphabetException if a scanned data set is not over a simple, discrete alphabet
     * @throws Exception if any of the called model or data set operations fails
     */
    static void compute(AbstractSingleMotifChIPper model, int motifLength, Sequence kmer, DataSet data, double[] weights, Pair<DataSet, double[][]> alternativeData, double pwmCorrectionFactor, double h, HTS_InitialPwms type) throws IllegalArgumentException, EmptyDataSetException, WrongAlphabetException, Exception {
        switch (type) {
            case DIMONT: {
                DimontTool.setMotifParameters(model, kmer, h);
                break;
            }
            case INI_MOTIF: {
                double[][] pwm = getPwmFromKmer(kmer.toString(), data, weights, motifLength);
                installPwm(model, pwm);
                break;
            }
            case CYCLE_CORRECTION: {
                double[][] pwm = getCycleCorrectedPwm(kmer.toString(), alternativeData, pwmCorrectionFactor, motifLength);
                installPwm(model, pwm);
                break;
            }
        }
    }

    /**
     * Resets the model, initializes its hidden parameters uniformly and sets the
     * logarithms of {@code pwm} as the parameters of function 0.
     */
    private static void installPwm(AbstractSingleMotifChIPper model, double[][] pwm) throws Exception {
        double[] parameters = pwmToInputParameters(pwm);
        model.reset();
        model.initializeHiddenUniformly();
        model.setParametersForFunction(0, parameters, 0);
    }

    /**
     * Computes the background-corrected PWM for {@link #CYCLE_CORRECTION}.
     *
     * <p>Sequences flagged with exactly {@code 1.0} form the foreground, all others the
     * background. Both sets are counted with unit weights. Each entry becomes
     * {@code fg - pwmCorrectionFactor * bg}; entries that are not strictly positive are replaced
     * by {@link #MIN_CORRECTED_PROBABILITY}, and every column is re-normalized to sum to one.
     */
    private static double[][] getCycleCorrectedPwm(String kmer, Pair<DataSet, double[][]> alternativeData, double pwmCorrectionFactor, int motifLength) throws Exception {
        DataSet seqs = alternativeData.getFirstElement();
        double[] foregroundFlags = alternativeData.getSecondElement()[0];

        ArrayList<Sequence> fg = new ArrayList<Sequence>();
        ArrayList<Sequence> bg = new ArrayList<Sequence>();
        for (int i = 0; i < seqs.getNumberOfElements(); ++i) {
            // Exact comparison on purpose: the flag is treated as an indicator, not a probability.
            if (foregroundFlags[i] == 1.0) {
                fg.add(seqs.getElementAt(i));
            } else {
                bg.add(seqs.getElementAt(i));
            }
        }

        double[][] pwmFg = getPwmFromKmer(kmer, new DataSet("fg", fg), unitWeights(fg.size()), motifLength);
        double[][] pwmBg = getPwmFromKmer(kmer, new DataSet("bg", bg), unitWeights(bg.size()), motifLength);

        for (int col = 0; col < pwmFg.length; ++col) {
            double[] fgColumn = pwmFg[col];
            for (int row = 0; row < fgColumn.length; ++row) {
                double corrected = fgColumn[row] - pwmCorrectionFactor * pwmBg[col][row];
                fgColumn[row] = corrected <= 0.0 ? MIN_CORRECTED_PROBABILITY : corrected;
            }
            normalize(fgColumn);
        }
        return pwmFg;
    }

    /** Returns an array of {@code n} weights that are all {@code 1.0}. */
    private static double[] unitWeights(int n) {
        double[] w = new double[n];
        Arrays.fill(w, 1.0);
        return w;
    }

    /** Divides every entry of {@code column} in place by the sum of all entries. */
    private static void normalize(double[] column) {
        double sum = ToolBox.sum(column);
        for (int i = 0; i < column.length; ++i) {
            column[i] /= sum;
        }
    }

    /**
     * Flattens a PWM column by column into a parameter vector of natural logarithms.
     *
     * @param pwm the PWM, indexed as {@code pwm[column][symbol]}; must have at least one column
     * @return array of length {@code pwm.length * pwm[0].length} where entry
     *         {@code column * pwm[0].length + symbol} is {@code Math.log(pwm[column][symbol])}
     */
    private static double[] pwmToInputParameters(double[][] pwm) {
        int rows = pwm[0].length;
        double[] parameters = new double[pwm.length * rows];
        int idx = 0;
        for (double[] column : pwm) {
            for (int row = 0; row < rows; ++row) {
                parameters[idx++] = Math.log(column[row]);
            }
        }
        return parameters;
    }

    /**
     * Builds a PWM of {@code motifLength} columns from weighted occurrences of a seed k-mer and
     * its single-position variants.
     *
     * <p>For every k-mer position and every alphabet symbol, the variant with that symbol
     * substituted at that position is registered together with its reverse complement. Each of
     * the first {@code weights.length} sequences of {@code data} is scanned; if all registered
     * k-mers found in it map to the same variant, the sequence's weight is added to that variant's
     * count, otherwise the sequence is ignored. The k-mer columns are placed in the middle of the
     * PWM (left flank {@code floor((motifLength - k) / 2)}), filled with count plus pseudocount and
     * normalized; flanking columns are uniform over the alphabet.
     *
     * @param kmer the seed k-mer
     * @param data the sequences to scan
     * @param weights one weight per scanned sequence
     * @param motifLength number of columns of the returned PWM; must be at least {@code kmer.length()}
     * @return the PWM, indexed as {@code [column][symbol]}
     * @throws WrongAlphabetException if the alphabet of {@code data} is not simple and discrete
     * @throws OperationNotSupportedException if a variant cannot be reverse-complemented
     * @throws IllegalArgumentException if {@code motifLength} is smaller than {@code kmer.length()}
     */
    private static double[][] getPwmFromKmer(String kmer, DataSet data, double[] weights, int motifLength) throws WrongAlphabetException, OperationNotSupportedException {
        AlphabetContainer con = data.getAlphabetContainer();
        if (!con.isSimple() || !con.isDiscrete()) {
            throw new WrongAlphabetException();
        }
        int k = kmer.length();
        if (motifLength < k) {
            // Fail fast: otherwise the k-mer columns would be written at negative indices.
            throw new IllegalArgumentException("motifLength (" + motifLength + ") must not be smaller than the k-mer length (" + k + ")");
        }
        DiscreteAlphabet alphabet = (DiscreteAlphabet)con.getAlphabetAt(0);
        int alphabetSize = (int)alphabet.length();

        // Maps every variant and its reverse complement to the variant itself.
        Map<String, String> variantByKmer = new HashMap<String, String>(2 * k * (alphabetSize - 1) + 2, 1.0f);
        Map<String, Double> weightedCounts = new HashMap<String, Double>(k * (alphabetSize - 1) + 1, 1.0f);
        String[][] variants = new String[k][alphabetSize];
        for (int col = 0; col < k; ++col) {
            for (int row = 0; row < alphabetSize; ++row) {
                String variant = kmer.substring(0, col) + alphabet.getSymbolAt(row) + kmer.substring(col + 1);
                String revCompl = Sequence.create(con, variant).reverseComplement().toString();
                variants[col][row] = variant;
                variantByKmer.put(variant, variant);
                variantByKmer.put(revCompl, variant);
                weightedCounts.put(variant, 0.0);
            }
        }

        for (int n = 0; n < weights.length; ++n) {
            Sequence seq = data.getElementAt(n);
            String s = seq.toString();
            int lastStart = seq.getLength() - k;
            boolean kmerFound = false;
            String lastKmer = "";
            for (int start = 0; start <= lastStart; ++start) {
                String variant = variantByKmer.get(s.substring(start, start + k));
                if (variant == null) {
                    continue;
                }
                if (lastKmer.isEmpty() || lastKmer.equals(variant)) {
                    kmerFound = true;
                    lastKmer = variant;
                    // Skip ahead: together with the loop increment the next window starts 3 positions later.
                    start += 2;
                } else {
                    // A second, different variant in the same sequence: do not count this sequence at all.
                    kmerFound = false;
                    break;
                }
            }
            if (kmerFound) {
                weightedCounts.put(lastKmer, weightedCounts.get(lastKmer) + weights[n]);
            }
        }

        double[][] pwm = new double[motifLength][alphabetSize];
        // Flanking columns are uniform over the alphabet (0.25 for a 4-letter alphabet).
        double uniform = 1.0 / alphabetSize;
        for (double[] column : pwm) {
            Arrays.fill(column, uniform);
        }
        // motifLength >= k holds here, so integer division equals floor((motifLength - k) / 2.0).
        int motifStart = (motifLength - k) / 2;
        for (int pos = 0; pos < k; ++pos) {
            double[] column = pwm[motifStart + pos];
            for (int row = 0; row < alphabetSize; ++row) {
                column[row] = weightedCounts.get(variants[pos][row]) + PSEUDOCOUNT;
            }
            normalize(column);
        }
        return pwm;
    }
}

