package de.jstacs.scoringFunctions.directedGraphicalModels;

import de.jstacs.NonParsableException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.Mutable;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.directedGraphicalModels.structureLearning.measures.InhomogeneousMarkov;
import de.jstacs.scoringFunctions.directedGraphicalModels.structureLearning.measures.Measure;
import de.jstacs.scoringFunctions.mix.motifSearch.DurationScoringFunction;

/**
 * This class implements a {@link NormalizableScoringFunction} for an inhomogeneous Markov model.
 * The modeled length can be modified (see {@link #modify(double[], double[], double[][][][], double[][][][], int, int)}),
 * which might be very important for de-novo motif discovery.
 * 
 * @author Jan Grau, Jens Keilwagen
 */
public class MutableMarkovModelScoringFunction extends BayesianNetworkScoringFunction implements Mutable{

	/**
	 * The number of KL divergences drawn per position when testing the significance of that position.
	 */
	private static final int NUM_SAMPLES = 10000;

	/**
	 * The XML tag enclosing the representation written by {@link #toXML()}.
	 */
	private static final String XML_TAG = "MutableMarkovModelScoringFunction";

	/**
	 * The prior on the modeled sequence length, or <code>null</code> if no such prior is used.
	 * <p>
	 * Deliberately declared WITHOUT an initializer: field initializers run after the superclass constructor
	 * returns, so an initializer would overwrite the value restored by {@link #fromXML(StringBuffer)} if (as
	 * presumably is the case) the superclass XML constructor invokes that method.
	 */
	private DurationScoringFunction lengthPenalty;

	/**
	 * This constructor creates an instance with a prior for the modeled length.
	 * 
	 * @param alphabet the {@link AlphabetContainer} of the {@link MutableMarkovModelScoringFunction}
	 * @param length the initial length of the modeled sequences
	 * @param ess the equivalent sample size
	 * @param plugInParameters a switch whether to use plug-in parameters or not
	 * @param order the order of the Markov model
	 * @param lengthPenalty the prior on the modeled sequence length, may be <code>null</code>
	 * 
	 * @throws Exception if super class constructor throws an {@link Exception} or if the <code>lengthPenalty</code> does not allow the initial length
	 */
	public MutableMarkovModelScoringFunction(AlphabetContainer alphabet,
			int length, double ess, boolean plugInParameters,
			int order, DurationScoringFunction lengthPenalty ) throws Exception {
		this( alphabet, length, ess, plugInParameters, new InhomogeneousMarkov( order ), lengthPenalty );
	}
	
	/**
	 * This constructor creates an instance without any prior for the modeled length.
	 * 
	 * @param alphabet the {@link AlphabetContainer} of the {@link MutableMarkovModelScoringFunction}
	 * @param length the initial length of the modeled sequences
	 * @param ess the equivalent sample size
	 * @param plugInParameters a switch whether to use plug-in parameters or not
	 * @param structureMeasure a {@link Measure} for the structure
	 * 
	 * @throws Exception if super class constructor throws an {@link Exception}
	 */
	public MutableMarkovModelScoringFunction( AlphabetContainer alphabet, int length, double ess, boolean plugInParameters,
			InhomogeneousMarkov structureMeasure ) throws Exception {
		super( alphabet, length, ess, plugInParameters, structureMeasure );
	}

	/**
	 * This constructor creates an instance with a prior for the modeled length.
	 * 
	 * @param alphabet the {@link AlphabetContainer} of the {@link MutableMarkovModelScoringFunction}
	 * @param length the initial length of the modeled sequences
	 * @param ess the equivalent sample size
	 * @param plugInParameters a switch whether to use plug-in parameters or not
	 * @param structureMeasure a {@link Measure} for the structure
	 * @param lengthPenalty the prior on the modeled sequence length, may be <code>null</code>
	 * 
	 * @throws Exception if super class constructor throws an {@link Exception} or if the <code>lengthPenalty</code> does not allow the initial length
	 */
	public MutableMarkovModelScoringFunction( AlphabetContainer alphabet, int length, double ess, boolean plugInParameters,
			InhomogeneousMarkov structureMeasure, DurationScoringFunction lengthPenalty ) throws Exception {
		this( alphabet, length, ess, plugInParameters, structureMeasure );
		this.lengthPenalty = lengthPenalty;
		if( lengthPenalty != null && !lengthPenalty.isPossible( length ) ) {
			throw new IllegalArgumentException( "This motif length is not possible: " + length );
		}
	}

	/**
	 * The standard constructor for the interface {@link de.jstacs.Storable}.
	 * Recreates a {@link MutableMarkovModelScoringFunction} from its XML
	 * representation as saved by the method {@link #toXML()}.
	 * 
	 * @param xml
	 *            the XML representation as {@link StringBuffer}
	 * 
	 * @throws NonParsableException
	 *             if the XML code could not be parsed
	 */
	public MutableMarkovModelScoringFunction( StringBuffer xml ) throws NonParsableException {
		super( xml );
	}

	/**
	 * Restores this instance from the XML produced by {@link #toXML()}: extracts the content of
	 * {@link #XML_TAG}, reads the (possibly <code>null</code>) length penalty from the tag
	 * <code>lengthPenalty</code> and passes the extracted content on to the superclass.
	 * 
	 * @param source the XML representation
	 * 
	 * @throws NonParsableException if the XML code could not be parsed
	 */
	@Override
	protected void fromXML( StringBuffer source ) throws NonParsableException {
		StringBuffer sb = XMLParser.extractForTag( source, XML_TAG );
		lengthPenalty = (DurationScoringFunction) XMLParser.extractStorableOrNullForTag( sb, "lengthPenalty" );
		super.fromXML( sb );
	}

	/**
	 * Returns the XML representation of this instance: the superclass representation followed by the
	 * (possibly <code>null</code>) length penalty under the tag <code>lengthPenalty</code>, all enclosed in
	 * {@link #XML_TAG}.
	 * 
	 * @return the XML representation
	 */
	@Override
	public StringBuffer toXML() {
		StringBuffer sb = super.toXML();
		XMLParser.appendStorableOrNullWithTags( sb, lengthPenalty, "lengthPenalty" );
		XMLParser.addTags( sb, XML_TAG );
		return sb;
	}

	/**
	 * Returns the log prior term of the superclass plus, if a length penalty is set, the log score of the
	 * length penalty for the current modeled length.
	 * 
	 * @return the log prior term
	 * 
	 * @see de.jstacs.scoringFunctions.directedGraphicalModels.BayesianNetworkScoringFunction#getLogPriorTerm()
	 */
	@Override
	public double getLogPriorTerm() {
		double logPrior = super.getLogPriorTerm();
		if( lengthPenalty != null ) {
			logPrior += lengthPenalty.getLogScore( length );
		}
		return logPrior;
	}

	/**
	 * Changes the modeled length by removing or adding positions at either end.
	 * <p>
	 * The parameters of all positions are normalized first, even if the length is not changed afterwards.
	 * The new length is <code>getLength() - offsetLeft + offsetRight</code>:
	 * <ul>
	 * <li>a positive <code>offsetLeft</code> drops that many positions at the left end; a negative one
	 * prepends <code>-offsetLeft</code> new positions that are initialized via <code>fill</code> with the
	 * selected left contrast,</li>
	 * <li>a positive <code>offsetRight</code> appends that many new positions that are initialized via
	 * <code>fill</code> with the selected right contrast; a negative one drops positions at the right end.</li>
	 * </ul>
	 * For each side, the contrast with the maximal weight (the first one on ties) is selected.
	 * 
	 * @param weightsLeft the weights of the left contrasts; must have at least <code>fillEmptyWithLeft.length</code> entries
	 * @param weightsRight the weights of the right contrasts; must have at least <code>fillEmptyWithRight.length</code> entries
	 * @param fillEmptyWithLeft the candidate contrasts for positions added at the left end
	 * @param fillEmptyWithRight the candidate contrasts for positions added at the right end
	 * @param offsetLeft the shift of the left border (positive shrinks, negative extends)
	 * @param offsetRight the shift of the right border (positive extends, negative shrinks)
	 * 
	 * @return <code>true</code> if the model has the requested length afterwards; <code>false</code> if the
	 *         alphabet is not simple (length unchanged, parameters normalized) or if building the new model failed,
	 *         e.g. because the <code>lengthPenalty</code> does not allow the new length (previous length and
	 *         parameters are restored)
	 */
	public boolean modify( double[] weightsLeft, double[] weightsRight, double[][][][] fillEmptyWithLeft, double[][][][] fillEmptyWithRight, int offsetLeft, int offsetRight ) {
		double[][][] fillEmptyWithLeftSel = selectByMaxWeight( weightsLeft, fillEmptyWithLeft );
		double[][][] fillEmptyWithRightSel = selectByMaxWeight( weightsRight, fillEmptyWithRight );

		prepareNormalizedParameters();
		if( !getAlphabetContainer().isSimple() ) {
			return false;
		}
		if( offsetLeft == 0 && offsetRight == 0 ) {
			return true;
		}

		ParameterTree[] backTrees = trees;
		int backLength = this.getLength();
		try {
			this.length = backLength - offsetLeft + offsetRight;
			if( lengthPenalty != null && !lengthPenalty.isPossible( length ) ) {
				throw new IllegalArgumentException( "This motif length is not possible: " + length );
			}
			// rebuilds the trees for the new length without data; presumably null samples/weights
			// yield fresh parameters that are overwritten below - TODO confirm
			this.createTrees( new Sample[]{ null, null }, new double[][]{ null, null } );
			int indexNew = 0, indexOld = 0;

			// left side: either skip dropped old positions or fill the prepended ones
			if( offsetLeft >= 0 ) {
				indexOld = offsetLeft;
			} else {
				for( ; indexNew < -offsetLeft && indexNew < trees.length; indexNew++ ) {
					trees[indexNew].fill( fillEmptyWithLeftSel );
				}
			}

			// copy the retained old positions
			for( ; indexOld < backTrees.length && indexNew < trees.length; indexNew++, indexOld++ ) {
				trees[indexNew].copy( backTrees[indexOld] );
			}

			// right side: fill all remaining (appended) positions
			for( ; indexNew < trees.length; indexNew++ ) {
				trees[indexNew].fill( fillEmptyWithRightSel );
			}
			normalizationConstant = null;
			return true;
		} catch( Exception ignored ) {
			// best effort by contract: report failure via the return value and roll back to the previous model
			this.length = backLength;
			trees = backTrees;
			normalizationConstant = null;
			return false;
		}
	}

	/**
	 * Returns the first candidate whose weight is maximal, i.e., the candidate whose weight is strictly
	 * greater than the weights of all preceding candidates and not exceeded later.
	 * 
	 * @param weights the weights; must have at least <code>candidates.length</code> entries
	 * @param candidates the candidates
	 * 
	 * @return the selected candidate, or <code>null</code> if there is no candidate or no weight exceeds
	 *         negative infinity (e.g. all weights are <code>NaN</code>)
	 */
	private static double[][][] selectByMaxWeight( double[] weights, double[][][][] candidates ) {
		double[][][] selected = null;
		double best = Double.NEGATIVE_INFINITY;
		for( int i = 0; i < candidates.length; i++ ) {
			if( weights[i] > best ) {
				selected = candidates[i];
				best = weights[i];
			}
		}
		return selected;
	}

	/**
	 * Normalizes the parameters of all positions, surrounded by calls of <code>precomputeNormalization()</code>.
	 * NOTE(review): the call before normalizing is presumably required by the trees' normalization and the call
	 * afterwards refreshes the precomputed values - TODO confirm.
	 */
	private void prepareNormalizedParameters() {
		this.precomputeNormalization();
		this.normalizeParameters();
		this.precomputeNormalization();
	}

	/**
	 * Normalizes the parameters of each {@link ParameterTree}.
	 */
	private void normalizeParameters() {
		for( ParameterTree tree : trees ) {
			tree.normalizeParameters();
		}
	}

	/**
	 * Determines the number of positions at the left and at the right end of the model that are not
	 * significantly different from the given contrasts.
	 * <p>
	 * The parameters are normalized first. Each end is scanned inwards and the scan stops at the first
	 * significant position, since later positions cannot change the length of the run.
	 * 
	 * @param samples currently not used by the test; the equivalent sample size of this instance is used instead
	 * @param weightsLeft the weights of the left contrasts
	 * @param weightsRight the weights of the right contrasts
	 * @param contrastLeft the contrasts used for the positions at the left end
	 * @param contrastRight the contrasts used for the positions at the right end
	 * @param sign the significance level; a position is significant if the empirical p-value is below this value
	 * 
	 * @return <code>{numLeft, numRight}</code>, the number of consecutive non-significant positions starting at
	 *         the first position and ending at the last position, respectively. NOTE(review): if no position is
	 *         significant, both equal the current length, so their sum may exceed the length.
	 */
	public int[] determineNotSignificantPositions( double samples, double[] weightsLeft, double[] weightsRight, double[][][][] contrastLeft, double[][][][] contrastRight, double sign){
		prepareNormalizedParameters();

		int numLeft = 0;
		while( numLeft < trees.length && !testSignificance( samples, numLeft, weightsLeft, contrastLeft, sign ) ) {
			numLeft++;
		}

		int numRight = 0;
		while( numRight < trees.length && !testSignificance( samples, trees.length - numRight - 1, weightsRight, contrastRight, sign ) ) {
			numRight++;
		}
		return new int[]{ numLeft, numRight };
	}

	/**
	 * Tests whether the position <code>treeIdx</code> differs significantly from the given contrasts.
	 * <p>
	 * The observed KL divergence of the position (w.r.t. the weighted contrasts) is compared with
	 * {@link #NUM_SAMPLES} divergences drawn by the tree using the equivalent sample size of this instance.
	 * NOTE(review): presumably these are divergences of parameters sampled around the contrasts - TODO confirm.
	 * 
	 * @param samples currently unused
	 * @param treeIdx the index of the tested position
	 * @param weights the weights of the contrasts
	 * @param contrast the contrasts
	 * @param sign the significance level
	 * 
	 * @return <code>true</code> if the position is significant
	 */
	private boolean testSignificance( double samples, int treeIdx, double[] weights, double[][][][] contrast, double sign ) {
		double observed = trees[treeIdx].getKLDivergence( weights, contrast );
		double[] drawn = new double[NUM_SAMPLES];
		trees[treeIdx].drawKLDivergences( drawn, weights, contrast, ess );
		return testSign( drawn, observed, sign );
	}

	/**
	 * Computes the empirical p-value of <code>kl</code>, i.e., the fraction of entries of <code>kls</code> that
	 * are at least as large as <code>kl</code>, and compares it with <code>sign</code>.
	 * 
	 * @param kls the drawn KL divergences
	 * @param kl the observed KL divergence
	 * @param sign the significance level
	 * 
	 * @return <code>true</code> if the empirical p-value is below <code>sign</code> (always <code>false</code> for
	 *         an empty <code>kls</code>, as the p-value is then <code>NaN</code>)
	 */
	private boolean testSign( double[] kls, double kl, double sign ) {
		int atLeastAsLarge = 0;
		for( double drawn : kls ) {
			if( drawn >= kl ) {
				atLeastAsLarge++;
			}
		}
		return atLeastAsLarge / (double) kls.length < sign;
	}
}
