package de.jstacs.motifDiscovery;

import java.util.AbstractList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;

import de.jstacs.classifier.utils.PValueComputation;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.sequences.PermutedSequence;
import de.jstacs.data.sequences.annotation.MotifAnnotation;
import de.jstacs.data.sequences.annotation.SequenceAnnotation;
import de.jstacs.motifDiscovery.MotifDiscoverer.KindOfProfile;
import de.jstacs.results.NumericalResult;
import de.jstacs.scoringFunctions.homogeneous.HMMScoringFunction;
import de.jstacs.utils.DoubleList;


/**
 * This class enables the user to predict motif occurrences given a specific significance level. 
 * 
 * @author Jan Grau, Jens Keilwagen
 */
public class SignificantMotifOccurrencesFinder {

	/**
	 * The available strategies for obtaining the random sequences that are scored
	 * to determine the significance level. Each constant carries an integer code that is
	 * non-negative exactly for the types based on a homogeneous Markov model, in which case
	 * the code is the order of that model.
	 * 
	 * @author Jan Grau, Jens Keilwagen
	 */
	public enum RandomSeqType{

		/**
		 * An enum constant that indicates to use a background data set to determine the significance level.
		 */
		BACKGROUND(-2),
		/**
		 * An enum constant that indicates to use permuted instances of the sequence to determine the significance level.
		 */
		PERMUTED(-1),
		/**
		 * An enum constant that indicates to use sequences drawn from a homogeneous Markov model of order 0 to determine the significance level.
		 */
		hMM0(0),
		/**
		 * An enum constant that indicates to use sequences drawn from a homogeneous Markov model of order 1 to determine the significance level.
		 */
		hMM1(1),
		/**
		 * An enum constant that indicates to use sequences drawn from a homogeneous Markov model of order 2 to determine the significance level.
		 */
		hMM2(2),
		/**
		 * An enum constant that indicates to use sequences drawn from a homogeneous Markov model of order 3 to determine the significance level.
		 */
		hMM3(3),
		/**
		 * An enum constant that indicates to use sequences drawn from a homogeneous Markov model of order 4 to determine the significance level.
		 */
		hMM4(4),
		/**
		 * An enum constant that indicates to use sequences drawn from a homogeneous Markov model of order 5 to determine the significance level.
		 */
		hMM5(5);

		/** The code of this type: the Markov order, or a negative value for {@link #BACKGROUND} and {@link #PERMUTED}. */
		private final int markovOrder;

		RandomSeqType(int markovOrder){
			this.markovOrder = markovOrder;
		}

		/**
		 * This method returns the Markov order.
		 * 
		 * @return the Markov order; <code>-2</code> for {@link #BACKGROUND} and <code>-1</code> for {@link #PERMUTED}
		 */
		public int getOrder(){
			return markovOrder;
		}

	}
	
	private RandomSeqType type;
	private Sample bg;
	private MotifDiscoverer disc;
	private int numSequences;
	private double sign;
	
	/**
	 * This constructor creates an instance of {@link SignificantMotifOccurrencesFinder} that uses
	 * randomly generated sequences of the given {@link RandomSeqType} to determine the significance level.
	 * 
	 * @param disc the {@link MotifDiscoverer} for the prediction
	 * @param type the type that determines how the random sequences are obtained;
	 *             {@link RandomSeqType#BACKGROUND} is not allowed here
	 * @param numSequences the number of sampled sequence instances used to determine the significance level
	 * @param sign the significance level
	 * 
	 * @throws IllegalArgumentException if <code>type</code> is {@link RandomSeqType#BACKGROUND}
	 */
	public SignificantMotifOccurrencesFinder(MotifDiscoverer disc, RandomSeqType type, int numSequences, double sign){
		// a background data set can only be supplied via the other constructor
		if( RandomSeqType.BACKGROUND == type ) {
			throw new IllegalArgumentException( "This type can not be used in this constructor." );
		}
		this.sign = sign;
		this.numSequences = numSequences;
		this.type = type;
		this.disc = disc;
	}

	/**
	 * This constructor creates an instance of {@link SignificantMotifOccurrencesFinder} that uses
	 * the sequences of a given {@link Sample} to determine the significance level.
	 * The type is set to {@link RandomSeqType#BACKGROUND} and every sequence of <code>bg</code> is used.
	 * 
	 * @param disc the {@link MotifDiscoverer} for the prediction
	 * @param bg the background data set
	 * @param sign the significance level
	 */
	public SignificantMotifOccurrencesFinder(MotifDiscoverer disc, Sample bg, double sign){
		this.bg = bg;
		// all background sequences are scored
		this.numSequences = this.bg.getNumberOfElements();
		this.type = RandomSeqType.BACKGROUND;
		this.sign = sign;
		this.disc = disc;
	}
	
	/**
	 * Creates the sample of random sequences for <code>seq</code> if the type is based on a
	 * homogeneous Markov model, i.e. if {@link RandomSeqType#getOrder()} is non-negative.
	 * In that case a {@link HMMScoringFunction} of this order is initialized on <code>seq</code> and
	 * <code>numSequences</code> sequences of the length of <code>seq</code> are emitted into <code>bg</code>.
	 * For the other types (negative order) this method does nothing.
	 * 
	 * @param seq the sequence the random sequences are generated for
	 * 
	 * @throws Exception if the model could not be initialized or the sequences could not be emitted;
	 *             the failure is propagated instead of being printed and ignored, since otherwise
	 *             <code>bg</code> would stay <code>null</code> or keep the sample of a previous sequence
	 */
	private void createBgSample( Sequence seq ) throws Exception {
		int order = type.getOrder();
		if( order < 0 ) {
			// nothing to sample for these types
			return;
		}
		HMMScoringFunction hmm = new HMMScoringFunction(seq.getAlphabetContainer(),order,0,new double[order+1],true,true,1);
		hmm.initializeFunction( 0, false, new Sample[]{new Sample("",seq)}, new double[][]{{1}} );
		if( order > 0 ) {
			// replace the parameters by the logarithms of all conditional stationary distributions
			double[][][] condProbs = hmm.getAllConditionalStationaryDistributions();
			DoubleList list = new DoubleList( (int) (1.5*Math.pow( seq.getAlphabetContainer().getAlphabetLengthAt( 0 ), condProbs.length )) );//TODO
			for( int k, j, i = 0; i < condProbs.length; i++ ) {
				for( j = 0; j < condProbs[i].length; j++ ) {
					for( k = 0; k < condProbs[i][j].length; k++ ) {
						list.add( Math.log( condProbs[i][j][k] ) );
					}
				}
			}
			hmm.setParameters( list.toArray(), 0 );
		}
		bg = hmm.emit( numSequences, seq.getLength() );
	}
	
	/**
	 * This method finds the significant occurrences of one motif in a sequence.
	 * The number of returned occurrences is not limited.
	 * 
	 * @param motif the motif index
	 * @param seq the sequence
	 * @param start the start position
	 * 
	 * @return an array of {@link MotifAnnotation} for the sequence
	 * 
	 * @throws Exception if something went wrong
	 */
	public MotifAnnotation[] findSignificantMotifOccurrences( int motif, Sequence seq, int start ) throws Exception{
		// row 0: components containing the motif, row 1: index of the motif inside each of them
		int[][] componentsAndLocals = computeIndices( motif );
		int[] components = componentsAndLocals[0];
		int[] localIndices = componentsAndLocals[1];

		LinkedList<MotifAnnotation> occurrences = new LinkedList<MotifAnnotation>();
		findSignificantMotifOccurrences( motif, seq, start, components, localIndices, occurrences, Integer.MAX_VALUE );

		MotifAnnotation[] template = new MotifAnnotation[0];
		return occurrences.toArray( template );
	}
	
	/**
	 * Determines all components of the {@link MotifDiscoverer} that contain the motif with the
	 * given (global) index.
	 * Each component is inspected only once; the hits are buffered and copied into arrays of
	 * exactly the required length afterwards.
	 * 
	 * @param motif the global motif index
	 * 
	 * @return an array with two rows of equal length: row <code>0</code> contains the indices of the
	 *         components using the motif, row <code>1</code> the index of the motif within the
	 *         respective component
	 */
	private int[][] computeIndices( int motif )
	{
		int numComponents = disc.getNumberOfComponents();
		int[] components = new int[numComponents];
		int[] locals = new int[numComponents];
		int num = 0;
		for(int i=0;i<numComponents;i++){
			int loc = getLocalIndexOfMotifInComponent( i, motif );
			if(loc > -1){
				components[num] = i;
				locals[num] = loc;
				num++;
			}
		}
		// shrink the buffers to the number of components that actually contain the motif
		int[][] idxs = new int[2][num];
		System.arraycopy( components, 0, idxs[0], 0, num );
		System.arraycopy( locals, 0, idxs[1], 0, num );
		return idxs;
	}
	
	/**
	 * Appends the significant occurrences of one motif in <code>seq</code> to <code>list</code>.
	 * <ol>
	 * <li>The score profiles of all random sequences (permuted instances of <code>seq</code> or the
	 * elements of the background sample) are computed for every given component, pooled and sorted.</li>
	 * <li>A threshold is derived from the pooled scores and the significance level via
	 * {@link PValueComputation}.</li>
	 * <li>Every position of <code>seq</code> with a score above the threshold is added to
	 * <code>list</code> as {@link MotifAnnotation} including its component and p-value.</li>
	 * <li>If more than <code>addMax</code> occurrences were added, all newly added occurrences with a
	 * p-value that is not smaller than the <code>addMax</code>-th smallest one (counted from 0)
	 * are removed again.</li>
	 * </ol>
	 * 
	 * @param motif the (global) motif index
	 * @param seq the sequence
	 * @param start the start position
	 * @param idxsOfUsedComponents the indices of the components containing the motif
	 * @param idxsOfMotifsInComponents the indices of the motif within these components
	 * @param list the list the annotations are appended to
	 * @param addMax the maximal number of occurrences that are added
	 * 
	 * @throws Exception if something went wrong
	 */
	private void findSignificantMotifOccurrences(int motif, Sequence seq, int start, int[] idxsOfUsedComponents, int[] idxsOfMotifsInComponents, AbstractList<MotifAnnotation> list, int addMax ) throws Exception{
		createBgSample( seq );
		final int numComps = idxsOfUsedComponents.length;

		// score all random sequences
		LinkedList<double[][]> randomProfiles = new LinkedList<double[][]>();
		int total = 0;
		for( int s = 0; s < numSequences; s++ ) {
			Sequence randomSeq;
			if( type != RandomSeqType.PERMUTED ) {
				randomSeq = bg.getElementAt( s );
			} else {
				randomSeq = new PermutedSequence( seq );
			}
			double[][] profiles = new double[numComps][];
			for( int c = 0; c < numComps; c++ ) {
				profiles[c] = disc.getProfileOfScoresFor( idxsOfUsedComponents[c], idxsOfMotifsInComponents[c], randomSeq, start, KindOfProfile.UNNORMALIZED_JOINT );
				total += profiles[c].length;
			}
			randomProfiles.add( profiles );
		}

		// pool the random scores in one sorted array
		double[] scores = new double[total];
		int offset = 0;
		for( double[][] profiles : randomProfiles ) {
			for( double[] profile : profiles ) {
				System.arraycopy( profile, 0, scores, offset, profile.length );
				offset += profile.length;
			}
		}
		Arrays.sort( scores );

		// score the sequence itself
		double[][] observed = new double[numComps][];
		for( int c = 0; c < numComps; c++ ) {
			observed[c] = disc.getProfileOfScoresFor( idxsOfUsedComponents[c], idxsOfMotifsInComponents[c], seq, start, KindOfProfile.UNNORMALIZED_JOINT );
		}

		int signIndex = PValueComputation.getBorder( scores, sign );
		double thresh = PValueComputation.getThreshold( scores, signIndex );
		int length = disc.getMotifLength( motif );
		// position in the list where the occurrences of this call begin
		int listIndex = list.size();
		DoubleList pValues = new DoubleList();
		for( int c = 0; c < observed.length; c++ ) {
			double[] profile = observed[c];
			for( int pos = 0; pos < profile.length; pos++ ) {
				if( !( profile[pos] > thresh ) ) {
					continue;
				}
				double pVal = PValueComputation.getPValue( scores, profile[pos], signIndex );
				pValues.add( pVal );
				list.add( new MotifAnnotation( "motif* " + motif, pos+start, length,
						disc.getStrandFor( idxsOfUsedComponents[c], idxsOfMotifsInComponents[c], seq, pos+start ),
						new NumericalResult( "component", "the component of the model where this motif was found", idxsOfUsedComponents[c] ),
						new NumericalResult( "p-value", "", pVal ) ) );
			}
		}

		if( pValues.length() > addMax ) {
			// reduce the prediction: drop every new occurrence whose p-value reaches the cutoff
			double[] sorted = pValues.toArray();
			Arrays.sort( sorted );

			for( int h = 0; h < pValues.length(); h++ ) {
				if( pValues.get( h ) >= sorted[addMax] ) {
					// the following entries move up, hence listIndex stays
					list.remove( listIndex );
				} else {
					listIndex++;
				}
			}
		}
	}
	
	/**
	 * This method annotates a {@link Sample} starting in each sequence at <code>startPos</code>.
	 * The number of annotated motif occurrences is not limited.
	 * 
	 * @param startPos the start position used for all sequences
	 * @param data the {@link Sample}
	 * 
	 * @return an annotated {@link Sample}
	 * 
	 * @throws Exception if something went wrong
	 * 
	 * @see SignificantMotifOccurrencesFinder#annotateMotifs(int, Sample, int)
	 */
	public Sample annotateMotifs( int startPos, Sample data ) throws Exception
	{
		// no upper bound for the number of annotated occurrences
		final int unlimited = Integer.MAX_VALUE;
		return annotateMotifs( startPos, data, unlimited );
	}

	/**
	 * This method annotates a {@link Sample} starting in each sequence at position <code>0</code>.
	 * The number of annotated motif occurrences is not limited.
	 * 
	 * @param data the {@link Sample}
	 * 
	 * @return an annotated {@link Sample}
	 * 
	 * @throws Exception if something went wrong
	 * 
	 * @see SignificantMotifOccurrencesFinder#annotateMotifs(int, Sample, int)
	 */
	public Sample annotateMotifs( Sample data ) throws Exception
	{
		final int firstPosition = 0;
		// no upper bound for the number of annotated occurrences
		final int unlimited = Integer.MAX_VALUE;
		return annotateMotifs( firstPosition, data, unlimited );
	}
	
	/**
	 * This method annotates a {@link Sample} starting in each sequence at position <code>0</code>.
	 * At most, <code>addMax</code> motif occurrences of each motif instance will be annotated.
	 * 
	 * @param data the {@link Sample}
	 * @param addMax the number of motif occurrences that can at most be annotated for each motif instance
	 * 
	 * @return an annotated {@link Sample}
	 * 
	 * @throws Exception if something went wrong
	 * 
	 * @see SignificantMotifOccurrencesFinder#annotateMotifs(int, Sample, int)
	 */
	public Sample annotateMotifs( Sample data, int addMax ) throws Exception
	{
		final int firstPosition = 0;
		return annotateMotifs( firstPosition, data, addMax );
	}
	
	/**
	 * This method annotates a {@link Sample} starting in each sequence at <code>startPos</code>.
	 * For every sequence the significant occurrences of all motifs of the {@link MotifDiscoverer}
	 * are collected, and the sequence is replaced by an annotated instance carrying them.
	 * At most, <code>addMax</code> motif occurrences of each motif instance will be annotated.
	 * 
	 * @param startPos the start position used for all sequences
	 * @param data the {@link Sample}
	 * @param addMax the number of motif occurrences that can at most be annotated for each motif instance
	 * 
	 * @return an annotated {@link Sample}
	 * 
	 * @throws Exception if something went wrong
	 */
	public Sample annotateMotifs( int startPos, Sample data, int addMax ) throws Exception
	{
		final int numSeqs = data.getNumberOfElements();
		final int numMotifs = disc.getNumberOfMotifs();

		// the components and local indices of each motif do not depend on the sequence
		int[][][] motifIdxs = new int[numMotifs][][];
		Sequence[] annotated = new Sequence[numSeqs];
		for( int mo = 0; mo < numMotifs; mo++ )
		{
			motifIdxs[mo] = computeIndices( mo );
		}

		SequenceAnnotation[] template = new SequenceAnnotation[0];
		LinkedList<MotifAnnotation> found = new LinkedList<MotifAnnotation>();
		for( int s = 0; s < numSeqs; s++ )
		{
			Sequence current = data.getElementAt( s );

			// gather the occurrences of all motifs for the current sequence
			found.clear();
			for( int mo = 0; mo < numMotifs; mo++ )
			{
				findSignificantMotifOccurrences( mo, current, startPos, motifIdxs[mo][0], motifIdxs[mo][1], found, addMax );
			}

			// the existing annotation is replaced by the one computed above
			annotated[s] = current.annotate( false, found.toArray( template ) );
		}
		return new Sample( "annotated", annotated );
	}
	
	/**
	 * Searches the motifs of a component for the motif with the given global index.
	 * 
	 * @param component the component index
	 * @param motif the global motif index
	 * 
	 * @return the index of the motif within the component, or <code>-1</code> if the component
	 *         does not contain the motif
	 */
	private int getLocalIndexOfMotifInComponent(int component, int motif){
		int local = 0;
		while( local < disc.getNumberOfMotifsInComponent( component ) ){
			if( motif == disc.getGlobalIndexOfMotifInComponent( component, local ) ){
				return local;
			}
			local++;
		}
		return -1;
	}
	
	/**
	 * This method determines the p-value for each symbol to be annotated at least in one motif occurrence of the
	 * motif with index <code>motif</code> in the component <code>component</code>.
	 * Each sequence is processed from position <code>0</code>.
	 * 
	 * @param data the {@link Sample}
	 * @param component the component index
	 * @param motif the motif index
	 * @param addOnlyBest a switch whether to add only the best
	 * 
	 * @return an array containing for each sequence an array with the <code>p</code>-value for each symbol in the sequence
	 * 
	 * @throws Exception if something went wrong
	 * 
	 * @see MotifDiscoveryAssessment#getSorted1MinusPValuesForMotifAndFlanking(Sample, double[][], String)
	 */
	public double[][] getPValuesForEachNucleotide( Sample data, int component, int motif, boolean addOnlyBest ) throws Exception {
		final int numSeqs = data.getNumberOfElements();
		double[][] pValues = new double[numSeqs][];
		int s = 0;
		while( s < numSeqs ) {
			Sequence current = data.getElementAt( s );
			pValues[s] = getPValueForNucleotides( current, 0, component, motif, addOnlyBest );
			s++;
		}
		return pValues;
	}
	
	////TODO make this more general???
	/**
	 * Determines for each symbol of <code>seq</code> (beginning at <code>start</code>) a p-value for
	 * being part of an occurrence of the given motif of the given component.
	 * The score profiles of all random sequences are pooled and sorted, and the p-values of the
	 * scores of <code>seq</code> are computed with respect to these scores.
	 * <ul>
	 * <li>If <code>addOnlyBest</code> is <code>true</code>, only the positions covered by the best
	 * scoring occurrence obtain its p-value, all others keep <code>1</code>.</li>
	 * <li>Otherwise each position obtains the smallest p-value of all occurrences covering it.</li>
	 * </ul>
	 * 
	 * @param seq the sequence
	 * @param start the start position
	 * @param component the component index
	 * @param motif the index of the motif within the component
	 * @param addOnlyBest a switch whether to use only the best occurrence
	 * 
	 * @return the p-values, one for each position from <code>start</code> to the end of the sequence;
	 *         all values are <code>1</code> if the score profile of the sequence is empty
	 * 
	 * @throws Exception if something went wrong
	 */
	private double[] getPValueForNucleotides(Sequence seq, int start, int component, int motif, boolean addOnlyBest ) throws Exception{

		createBgSample( seq );
		Sequence bgSeq = null;
		LinkedList<double[]> scoreList = new LinkedList<double[]>();
		double[] temp = null;
		int i = 0, j, k, num = 0;
		for( ;i<numSequences;i++){
			if(type == RandomSeqType.PERMUTED){
				bgSeq = new PermutedSequence(seq);
			}else{
				bgSeq = bg.getElementAt( i );
			}
			temp = disc.getProfileOfScoresFor( component, motif, bgSeq, start, KindOfProfile.UNNORMALIZED_JOINT );
			num += temp.length;
			scoreList.add( temp );
		}

		// pool the scores of all random sequences in one sorted array
		double[] scores = new double[num];
		num = 0;
		for( double[] profile : scoreList ){
			System.arraycopy( profile, 0, scores, num, profile.length );
			num += profile.length;
		}
		Arrays.sort( scores );

		temp = disc.getProfileOfScoresFor( component, motif, seq, start, KindOfProfile.UNNORMALIZED_JOINT );

		//naive approach
		double[] res = new double[seq.getLength()-start];
		Arrays.fill( res, 1 );
		if( temp.length == 0 ) {
			// no scored position at all, hence nothing can be covered by an occurrence
			return res;
		}

		// motif is the index within the component, whereas getMotifLength is called with the
		// global motif index elsewhere in this class, so the index is translated first
		int length = disc.getMotifLength( disc.getGlobalIndexOfMotifInComponent( component, motif ) );
		if( addOnlyBest ) {
			int idx = getIndexOfMax( temp );
			double p = PValueComputation.getPValue( scores, temp[idx] );
			for( i = 0; i < length; i++ ){
				res[idx+i] = p;
			}
		} else {
			for(i=0;i<temp.length;i++){
				res[i] = PValueComputation.getPValue( scores, temp[i] );
			}
			// going from right to left, the entries left of i still hold the p-values of the
			// occurrences starting there; position i is covered by those starting at i-length+1 .. i
			for(i=res.length-1;i >= 0;i--){
				for(k=i-1,j=1; j < length && k >= 0; j++, k--){
					if(res[i] > res[k] ) {
						res[i] = res[k];
					}
				}
			}
		}
		return res;
	}
	
	/**
	 * Returns the index of the largest of the given values. If the largest value occurs more
	 * than once, the index of its first occurrence is returned.
	 * 
	 * @param values the values
	 * 
	 * @return the index of the maximum; <code>0</code> if fewer than two values are given
	 */
	private static int getIndexOfMax( double... values ) {
		int best = 0;
		int pos = 1;
		while( pos < values.length ) {
			// strict comparison keeps the first of several equal maxima
			if( values[best] < values[pos] ) {
				best = pos;
			}
			pos++;
		}
		return best;
	}
}
