Recipes

From Jstacs

In this section, we show some more complex code examples. All of these examples can be downloaded as a zip file and may serve as starting points for your own applications.


Creation of user-specific alphabet

In this example, we create a new ComplementableDiscreteAlphabet using the generic implementation. We then use this Alphabet to create a Sequence and compute its reverse complement.


String[] symbols = {"A", "C", "G", "T", "-"};
//new alphabet
DiscreteAlphabet abc = new DiscreteAlphabet( true, symbols );

//new alphabet that allows building the reverse complement of a sequence
int[] revComp = new int[symbols.length];
revComp[0] = 3; //symbols[0]^rc = symbols[3]
revComp[1] = 2; //symbols[1]^rc = symbols[2]
revComp[2] = 1; //symbols[2]^rc = symbols[1]
revComp[3] = 0; //symbols[3]^rc = symbols[0]
revComp[4] = 4; //symbols[4]^rc = symbols[4]
GenericComplementableDiscreteAlphabet abc2 = new GenericComplementableDiscreteAlphabet( true, symbols, revComp );

Sequence seq = Sequence.create( new AlphabetContainer( abc2 ), "ACGT-" );
Sequence rc = seq.reverseComplement();
System.out.println( seq );
System.out.println( rc );
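
The index mapping used above can also be illustrated without Jstacs. The following is a minimal plain-Java sketch, not the Jstacs implementation: the class name ReverseComplementSketch and its method are hypothetical, and only the symbols and the revComp mapping are taken from the example.

```java
import java.util.Arrays;
import java.util.List;

/** Plain-Java sketch of reverse complementing via an index mapping (no Jstacs needed). */
class ReverseComplementSketch {

    // symbols and complement mapping copied from the example above
    static final List<String> SYMBOLS = Arrays.asList("A", "C", "G", "T", "-");
    static final int[] REV_COMP = {3, 2, 1, 0, 4};

    /** Reverses the sequence and replaces each symbol by its complement. */
    static String reverseComplement(String seq) {
        StringBuilder sb = new StringBuilder();
        for (int i = seq.length() - 1; i >= 0; i--) {
            int index = SYMBOLS.indexOf(String.valueOf(seq.charAt(i)));
            if (index < 0) {
                throw new IllegalArgumentException("Unknown symbol: " + seq.charAt(i));
            }
            sb.append(SYMBOLS.get(REV_COMP[index]));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(reverseComplement("ACGT-")); // prints -ACGT
    }
}
```

The gap symbol maps to itself, so it only changes position when the sequence is reversed.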


Learning a position weight matrix from data

In this example, we show how to load sequence data into Jstacs and how to learn a position weight matrix (inhomogeneous Markov model of order 0) on these data.


//read data from FastA file
DNADataSet ds = new DNADataSet( args[0] );
AlphabetContainer con = ds.getAlphabetContainer();
//create position weight matrix model
TrainableStatisticalModel pwm = TrainableStatisticalModelFactory.createPWM( con, ds.getElementLength(), 4 );
//train it on the input data
pwm.train( ds );
//print the trained model
System.out.println(pwm);
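
To make concrete what training a position weight matrix amounts to, here is a small plain-Java sketch that counts symbols per position and adds pseudocounts. It is an illustration only: the class PwmEstimationSketch is hypothetical, and spreading the equivalent sample size evenly over the symbols of each position is an assumption about how the value 4 in the example could be interpreted, not a statement about the Jstacs internals.

```java
import java.util.Arrays;

/** Sketch: estimating a position weight matrix from aligned sequences with pseudocounts. */
class PwmEstimationSketch {

    static final String ALPHABET = "ACGT";

    /**
     * Returns probs[position][symbol]. The equivalent sample size (ess) is
     * spread evenly over the symbols of each position (assumption of this sketch).
     */
    static double[][] estimate(String[] seqs, double ess) {
        int length = seqs[0].length();
        int size = ALPHABET.length();
        double[][] probs = new double[length][size];
        // start every cell with its pseudocount
        for (double[] row : probs) {
            Arrays.fill(row, ess / size);
        }
        // add the observed counts
        for (String s : seqs) {
            if (s.length() != length) {
                throw new IllegalArgumentException("Sequences must have equal length");
            }
            for (int l = 0; l < length; l++) {
                int a = ALPHABET.indexOf(s.charAt(l));
                if (a < 0) {
                    throw new IllegalArgumentException("Unknown symbol: " + s.charAt(l));
                }
                probs[l][a] += 1.0;
            }
        }
        // normalize each position to a probability distribution
        double total = seqs.length + ess;
        for (double[] row : probs) {
            for (int a = 0; a < size; a++) {
                row[a] /= total;
            }
        }
        return probs;
    }

    public static void main(String[] args) {
        double[][] pwm = estimate(new String[]{"ACG", "ACT", "AAG", "TCG"}, 4.0);
        for (double[] row : pwm) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

With four sequences and an equivalent sample size of 4, the first position of this toy data set gives A a probability of (3 + 1) / (4 + 4) = 0.5.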


Learning a homogeneous Markov model from data

In this example, we show how to load sequence data into Jstacs and how to learn a homogeneous Markov model of order 1 on these data.


//read data from FastA file
DNADataSet ds = new DNADataSet( args[0] );
//create homogeneous Markov model of order 1
TrainableStatisticalModel hmm = TrainableStatisticalModelFactory.createHomogeneousMarkovModel( ds.getAlphabetContainer(), 4, (byte)1 );
//train it on the input data
hmm.train( ds );
//print the trained model
System.out.println(hmm);
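
For intuition, the estimation behind a homogeneous Markov model of order 1 can be sketched in plain Java as counting transitions between neighbouring symbols. The class MarkovEstimationSketch and its pseudocount parameter are hypothetical and chosen for illustration; how Jstacs derives its hyper-parameters from the equivalent sample size is not shown here.

```java
import java.util.Arrays;

/** Sketch: estimating the transition probabilities of a first-order homogeneous Markov model. */
class MarkovEstimationSketch {

    static final String ALPHABET = "ACGT";

    /** Returns trans[previous][current]; a positive pseudocount is added to every transition. */
    static double[][] estimateTransitions(String[] seqs, double pseudocount) {
        if (pseudocount <= 0) {
            throw new IllegalArgumentException("Pseudocount must be positive");
        }
        int n = ALPHABET.length();
        double[][] trans = new double[n][n];
        for (double[] row : trans) {
            Arrays.fill(row, pseudocount);
        }
        // count all pairs of neighbouring symbols, independent of their position
        for (String s : seqs) {
            for (int i = 1; i < s.length(); i++) {
                int prev = ALPHABET.indexOf(s.charAt(i - 1));
                int cur = ALPHABET.indexOf(s.charAt(i));
                if (prev < 0 || cur < 0) {
                    throw new IllegalArgumentException("Unknown symbol in: " + s);
                }
                trans[prev][cur] += 1.0;
            }
        }
        // normalize each row, i.e., each conditional distribution
        for (double[] row : trans) {
            double sum = 0;
            for (double v : row) {
                sum += v;
            }
            for (int b = 0; b < n; b++) {
                row[b] /= sum;
            }
        }
        return trans;
    }

    public static void main(String[] args) {
        double[][] t = estimateTransitions(new String[]{"ACAC", "ACGT"}, 0.25);
        for (double[] row : t) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

Because the model is homogeneous, the same transition table is used at every position, which is why sequences of different lengths can contribute to the same counts.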


Generating data from a homogeneous Markov model

In this example, we show how to learn a homogeneous Markov model of order 2 from data (similar to the previous example), and use the learned model to generate new data following the same distribution as the original data.


//read data from FastA file
DNADataSet ds = new DNADataSet( args[0] );
//create homogeneous Markov model of order 2
TrainableStatisticalModel hmm = TrainableStatisticalModelFactory.createHomogeneousMarkovModel( ds.getAlphabetContainer(), 4, (byte)2 );
//train it on the input data
hmm.train( ds );

//generate 100 sequences of length 20
DataSet generated = hmm.emitDataSet( 100, 20 );
//print these data
System.out.println(generated);
//and save them to a plain text file
generated.save( new File(args[1]) );
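
Generating data from a Markov model means drawing each symbol from the conditional distribution given its predecessors. The following plain-Java sketch shows this for order 1 (the example above uses order 2); the class MarkovSamplingSketch, its methods, and the explicit start distribution are assumptions made for illustration and are not part of Jstacs.

```java
import java.util.Random;

/** Sketch: emitting sequences from a first-order homogeneous Markov model. */
class MarkovSamplingSketch {

    static final String ALPHABET = "ACGT";

    /** Draws an index according to the given probabilities. */
    static int draw(double[] probs, Random rng) {
        double u = rng.nextDouble();
        double cumulative = 0;
        for (int i = 0; i < probs.length; i++) {
            cumulative += probs[i];
            if (u < cumulative) {
                return i;
            }
        }
        return probs.length - 1; // guard against rounding errors
    }

    /** Emits one sequence: the first symbol from start, all others from trans[previous]. */
    static String emit(double[] start, double[][] trans, int length, Random rng) {
        StringBuilder sb = new StringBuilder();
        int state = -1;
        for (int i = 0; i < length; i++) {
            state = (i == 0) ? draw(start, rng) : draw(trans[state], rng);
            sb.append(ALPHABET.charAt(state));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        double[] start = {0.25, 0.25, 0.25, 0.25};
        double[][] trans = {
            {0.7, 0.1, 0.1, 0.1},
            {0.1, 0.7, 0.1, 0.1},
            {0.1, 0.1, 0.7, 0.1},
            {0.1, 0.1, 0.1, 0.7}
        };
        Random rng = new Random(42); // fixed seed for reproducible output
        for (int i = 0; i < 3; i++) {
            System.out.println(emit(start, trans, 20, rng));
        }
    }
}
```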


Learning a mixture model from data

In this example, we show how to load sequence data into Jstacs and how to learn a mixture model of two position weight matrices on these data using the expectation maximization algorithm.


//read data from FastA file
DNADataSet ds = new DNADataSet( args[0] );
//create position weight matrix model
TrainableStatisticalModel pwm = TrainableStatisticalModelFactory.createPWM( ds.getAlphabetContainer(), ds.getElementLength(), 4 );
//create mixture model of two position weight matrices
TrainableStatisticalModel mixture = TrainableStatisticalModelFactory.createMixtureModel( new double[]{4,4}, new TrainableStatisticalModel[]{pwm,pwm} );
//train it on the input data using EM
mixture.train( ds );
//print the trained model
System.out.println(mixture);
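
The expectation step of the EM algorithm assigns each sequence a posterior probability for each mixture component. As a rough plain-Java sketch of that single step (the class MixtureEStepSketch is hypothetical, and the component log-likelihoods are simply passed in instead of being computed by a model):

```java
import java.util.Arrays;

/** Sketch: the E-step of EM for one sequence of a mixture model. */
class MixtureEStepSketch {

    /**
     * Computes the posterior component probabilities from the log component
     * weights and the log-likelihoods of the sequence under each component.
     */
    static double[] responsibilities(double[] logWeights, double[] logLikelihoods) {
        int k = logWeights.length;
        double[] r = new double[k];
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < k; i++) {
            r[i] = logWeights[i] + logLikelihoods[i]; // log joint probability
            max = Math.max(max, r[i]);
        }
        // subtract the maximum before exponentiating to avoid numerical underflow
        double sum = 0;
        for (int i = 0; i < k; i++) {
            r[i] = Math.exp(r[i] - max);
            sum += r[i];
        }
        for (int i = 0; i < k; i++) {
            r[i] /= sum;
        }
        return r;
    }

    public static void main(String[] args) {
        double[] logWeights = {Math.log(0.5), Math.log(0.5)};
        double[] logLikelihoods = {Math.log(0.2), Math.log(0.6)};
        System.out.println(Arrays.toString(responsibilities(logWeights, logLikelihoods)));
    }
}
```

In the maximization step, these posterior probabilities would serve as sequence weights when re-estimating each component.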


Analysing data with different models

In this example, we show how to use the TrainableStatisticalModelFactory to create inhomogeneous and homogeneous Markov models, and Bayesian trees, and how to learn these models on a common data set.


//read data from FastA file
DNADataSet ds = new DNADataSet( args[0] );

//get alphabet, length from data
AlphabetContainer alphabet = ds.getAlphabetContainer();
int length = ds.getElementLength();
//set ESS used for all models
double ess = 4;

TrainableStatisticalModel[] models = new TrainableStatisticalModel[4];
//create position weight matrix
models[0] = TrainableStatisticalModelFactory.createPWM( alphabet, length, 4 );
//create inhomogeneous Markov model of order 1 (WAM)
models[1] = TrainableStatisticalModelFactory.createInhomogeneousMarkovModel( alphabet, length, ess, (byte)1 );
//create Bayesian tree
models[2] = TrainableStatisticalModelFactory.createBayesianNetworkModel( alphabet, length, ess, (byte)1 );
//create homogeneous Markov model of order 2
models[3] = TrainableStatisticalModelFactory.createHomogeneousMarkovModel( alphabet, ess, (byte)2 );

//train and print all models
for(int i=0;i<models.length;i++){
        models[i].train( ds );
        System.out.println(models[i]);
}


De-novo motif discovery with a sunflower hidden Markov model

In this example, we show how to use the HMMFactory to create a sunflower hidden Markov model (HMM) with two motifs of different lengths. We show how to train the sunflower HMM on input data, which are typically long sequences containing an over-represented motif. After training the HMM, we show how to compute and output the Viterbi paths for all sequences, which give an indication of the position of motif occurrences.


//load data
DataSet data = new DNADataSet(args[0]);
//define parameters of Baum-Welch training using all available processor cores
BaumWelchParameterSet pars = new BaumWelchParameterSet(10, new SmallDifferenceOfFunctionEvaluationsCondition(1E-6), AbstractMultiThreadedOptimizableFunction.getNumberOfAvailableProcessors());
//create sunflower HMM with motifs of length 8 and 12
AbstractHMM hmm = HMMFactory.createSunflowerHMM(pars, data.getAlphabetContainer(), 0, data.getElementLength(), true, 8,12);
//train the HMM using Baum-Welch
hmm.train(data);
//print the trained HMM
System.out.println(hmm);
//print Viterbi paths of all sequences
for(int i=0;i<data.getNumberOfElements();i++){
        Pair<IntList,Double> p = hmm.getViterbiPathFor(data.getElementAt(i));
        System.out.println(p.getSecondElement()+"\t"+p.getFirstElement());
}
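
To show what a Viterbi path is, the following plain-Java sketch decodes the most probable state sequence of a toy two-state HMM. It is not the sunflower HMM from above and does not use Jstacs: the class ViterbiSketch, the state layout (state 0 as background, state 1 as a motif-like state that prefers "A"), and all probabilities are made up for illustration.

```java
import java.util.Arrays;

/** Sketch: Viterbi decoding for a small HMM in log-space. */
class ViterbiSketch {

    /** Returns the most probable state path for the observation indices obs. */
    static int[] viterbi(double[] logStart, double[][] logTrans, double[][] logEmit, int[] obs) {
        int n = logStart.length;
        int len = obs.length;
        if (len == 0) {
            return new int[0];
        }
        double[][] score = new double[len][n]; // best log-score of a path ending in state t at position i
        int[][] back = new int[len][n];        // predecessor state on that best path
        for (int s = 0; s < n; s++) {
            score[0][s] = logStart[s] + logEmit[s][obs[0]];
        }
        for (int i = 1; i < len; i++) {
            for (int t = 0; t < n; t++) {
                int best = 0;
                for (int s = 1; s < n; s++) {
                    if (score[i - 1][s] + logTrans[s][t] > score[i - 1][best] + logTrans[best][t]) {
                        best = s;
                    }
                }
                score[i][t] = score[i - 1][best] + logTrans[best][t] + logEmit[t][obs[i]];
                back[i][t] = best;
            }
        }
        // backtracking from the best final state
        int last = 0;
        for (int s = 1; s < n; s++) {
            if (score[len - 1][s] > score[len - 1][last]) {
                last = s;
            }
        }
        int[] path = new int[len];
        path[len - 1] = last;
        for (int i = len - 1; i > 0; i--) {
            path[i - 1] = back[i][path[i]];
        }
        return path;
    }

    /** Toy model: state 0 emits uniformly, state 1 strongly prefers symbol 0 ("A"). */
    static int[] demoPath() {
        double[] logStart = {Math.log(0.5), Math.log(0.5)};
        double[][] logTrans = {
            {Math.log(0.9), Math.log(0.1)},
            {Math.log(0.1), Math.log(0.9)}
        };
        double[][] logEmit = {
            {Math.log(0.25), Math.log(0.25), Math.log(0.25), Math.log(0.25)},
            {Math.log(0.97), Math.log(0.01), Math.log(0.01), Math.log(0.01)}
        };
        int[] obs = {1, 2, 0, 0, 0, 0, 3, 1}; // C G A A A A T C
        return viterbi(logStart, logTrans, logEmit, obs);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(demoPath())); // prints [0, 0, 1, 1, 1, 1, 0, 0]
    }
}
```

The run of state 1 marks the stretch of "A" symbols, analogous to how motif states in a Viterbi path hint at the position of a motif occurrence.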


Learning a classifier using the generative maximum a-posteriori principle

In this example, we show how to train a classifier based on a position weight matrix model and a homogeneous Markov model on training data, and how to use the trained classifier to classify sequences.


//read data from FastA files
DataSet[] data = new DataSet[2];
data[0] = new DNADataSet( args[0] );
data[1] = new DNADataSet( args[1] );

//create position weight matrix model
TrainableStatisticalModel pwm = TrainableStatisticalModelFactory.createPWM( data[0].getAlphabetContainer(), data[0].getElementLength(), 4 );
//create homogeneous Markov model of order 1
TrainableStatisticalModel hmm = TrainableStatisticalModelFactory.createHomogeneousMarkovModel( data[1].getAlphabetContainer(), 4, (byte)1 );
//build a classifier using these models
TrainSMBasedClassifier cl = new TrainSMBasedClassifier( pwm, hmm );
//train it on the training data
cl.train( data );

//print the trained classifier
System.out.println(cl);
//classify one of the sequences
Sequence seq = data[0].getElementAt( 0 );
byte res = cl.classify( seq );
//print sequence and classification result
System.out.println(seq+" -> "+res);

//evaluate
NumericalPerformanceMeasureParameterSet params = PerformanceMeasureParameterSet.createFilledParameters();
System.out.println( cl.evaluate(params, true, data) );


Learning a classifier using the discriminative maximum supervised posterior principle

In this example, we show how to use the DifferentiableStatisticalModelFactory to create a position weight matrix and how to learn a classifier based on two position weight matrices using the discriminative maximum supervised posterior principle.


//read data from FastA files
DataSet[] data = new DataSet[2];
data[0] = new DNADataSet( args[0] );
data[1] = new DNADataSet( args[1] );
AlphabetContainer con = data[0].getAlphabetContainer();

//define differentiable PWM model
DifferentiableStatisticalModel pwm = DifferentiableStatisticalModelFactory.createPWM(con, 10, 4);
//parameters for numerical optimization
GenDisMixClassifierParameterSet pars = new GenDisMixClassifierParameterSet(con,10,(byte)10,1E-9,1E-10,1, false,KindOfParameter.PLUGIN,true,1);
//define and train classifier
AbstractClassifier cl = new MSPClassifier( pars, pwm, pwm );
cl.train( data );

System.out.println(cl);


Creating Data sets

In this example, we show different ways of creating a DataSet in Jstacs from plain text and FastA files and using the adaptor to BioJava.


String home = args[0]+File.separator;

//load DNA sequences in FastA-format
DataSet data = new DNADataSet( home+"myfile.fa" );

//create a DNA-alphabet
AlphabetContainer container = DNAAlphabetContainer.SINGLETON;

//create a DataSet using the alphabet from above in FastA-format
data = new DataSet( container, new SparseStringExtractor( home+"myfile.fa", StringExtractor.FASTA ));

//create a DataSet using the alphabet from above
data = new DataSet( container, new SparseStringExtractor( home+"myfile.txt" ));

//defining the ids, we want to obtain from NCBI Genbank:
GenbankSequenceDB db = new GenbankSequenceDB();

//at the moment the following fails due to a problem in BioJava hopefully fixed in the next legacy release
//this may fail if BioJava fails to load the sequence, e.g. if you are not connected to the internet
/*SimpleSequenceIterator it = new SimpleSequenceIterator(
                db.getSequence( "NC_001284.2" ),
                db.getSequence( "NC_000932.1" )
                );
 */


RichSequenceIterator it = IOTools.readGenbankDNA( new BufferedReader( new FileReader( home+"example.gb" ) ), null );

//conversion to Jstacs DataSet
data = BioJavaAdapter.sequenceIteratorToDataSet( it, null );
System.out.println(data);


Using TrainSMBasedClassifier

In this example, we show how to create a TrainSMBasedClassifier using two position weight matrices, train this classifier, classify previously unlabeled data, store the classifier to its XML representation, and load it back into Jstacs.


String home = args[0];

//create a DataSet for each class from the input data, using the DNA alphabet
DataSet[] data = new DataSet[2];
data[0] = new DNADataSet( args[1] );

//the length of our input sequences
int length = data[0].getElementLength();

data[1] = new DataSet( new DNADataSet( args[2] ), length );
 
//create a new PWM
BayesianNetworkTrainSM pwm = new BayesianNetworkTrainSM( new BayesianNetworkTrainSMParameterSet(
                //the alphabet and the length of the model:
                data[0].getAlphabetContainer(), length,
                //the equivalent sample size to compute hyper-parameters
                4,
                //some identifier for the model
                "my PWM",
                //we want a PWM, which is an inhomogeneous Markov model (IMM) of order 0
                ModelType.IMM, (byte) 0,
                //we want to estimate the MAP-parameters
                LearningType.ML_OR_MAP ) );
 
//create a new classifier
TrainSMBasedClassifier classifier = new TrainSMBasedClassifier( pwm, pwm );
//train the classifier
classifier.train( data );
//sequences that will be classified
DataSet toClassify = new DNADataSet( args[3] );
 
//use the trained classifier to classify new sequences
byte[] result = classifier.classify( toClassify );
System.out.println( Arrays.toString( result ) );
 
//create the XML-representation of the classifier
StringBuffer buf = new StringBuffer();
XMLParser.appendObjectWithTags( buf, classifier, "classifier" );
 
//write it to disk
FileManager.writeFile( new File(home+"myClassifier.xml"), buf );

//read XML-representation from disk
StringBuffer buf2 = FileManager.readFile( new File(home+"myClassifier.xml") );
 
//create new classifier from read StringBuffer containing XML-code
AbstractClassifier trainedClassifier = (AbstractClassifier) XMLParser.extractObjectForTags(buf2, "classifier");


Using GenDisMixClassifier

In this example, we show how to create GenDisMixClassifiers using two position weight matrices. We show how GenDisMixClassifiers can be created for all basic learning principles (ML, MAP, MCL, MSP), and how these classifiers can be trained and assessed.


//read FastA-files
DataSet[] data = {
         new DNADataSet( args[0] ),
         new DNADataSet( args[1] )
};
AlphabetContainer container = data[0].getAlphabetContainer();
int length = data[0].getElementLength();

//equivalent sample size =^= ESS
double essFg = 4, essBg = 4;
//create DifferentiableSequenceScore, here PWM
DifferentiableStatisticalModel pwmFg = new BayesianNetworkDiffSM( container, length, essFg, true, new InhomogeneousMarkov(0) );
DifferentiableStatisticalModel pwmBg = new BayesianNetworkDiffSM( container, length, essBg, true, new InhomogeneousMarkov(0) );

//create parameters of the classifier
GenDisMixClassifierParameterSet cps = new GenDisMixClassifierParameterSet(
                container,//the used alphabets
                length,//sequence length that can be modeled/classified
                Optimizer.QUASI_NEWTON_BFGS, 1E-1, 1E-1, 1,//optimization parameter
                false,//use free parameters or all
                KindOfParameter.PLUGIN,//how to start the numerical optimization
                true,//use a normalized objective function
                AbstractMultiThreadedOptimizableFunction.getNumberOfAvailableProcessors()//number of compute threads           
);

//create classifiers
LearningPrinciple[] lp = LearningPrinciple.values();
GenDisMixClassifier[] cl = new GenDisMixClassifier[lp.length+1];
//elementary learning principles
int i = 0;
for( ; i < cl.length-1; i++ ){
        System.out.println( "classifier " + i + " uses " + lp[i] );
        cl[i] = new GenDisMixClassifier( cps, new CompositeLogPrior(), lp[i], pwmFg, pwmBg );
}

//use some weighted version of log conditional likelihood, log likelihood, and log prior
double[] beta = {0.3,0.3,0.4};
System.out.println( "classifier " + i + " uses the weights " + Arrays.toString( beta ) );
cl[i] = new GenDisMixClassifier( cps, new CompositeLogPrior(), beta, pwmFg, pwmBg );

//do whatever you like

//e.g., train
for( i = 0; i < cl.length; i++ ){
        cl[i].train( data );
}

//e.g., evaluate (normally done on a test data set)
PerformanceMeasureParameterSet mp = PerformanceMeasureParameterSet.createFilledParameters();
for( i = 0; i < cl.length; i++ ){
        System.out.println( cl[i].evaluate( mp, true, data ) );
}
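
The basic learning principles can be read as special cases of one weighted objective. As a sketch, assuming the weights \boldsymbol{\beta} = (\beta_0, \beta_1, \beta_2) are applied in the order given in the code comment (log conditional likelihood, log likelihood, log prior), and writing \mathbf{c} for the class labels, \mathbf{X} for the sequences, and Q for the prior density (symbols introduced here for illustration), the trained parameters would maximize

\begin{align}

\hat{\boldsymbol{\lambda}} &= \operatorname{argmax}_{\boldsymbol{\lambda}} \left[ \beta_0 \log P(\mathbf{c}|\mathbf{X},\boldsymbol{\lambda}) + \beta_1 \log P(\mathbf{c},\mathbf{X}|\boldsymbol{\lambda}) + \beta_2 \log Q(\boldsymbol{\lambda}) \right].

\end{align}

Under this reading, ML corresponds to \boldsymbol{\beta} = (0,1,0), MCL to (1,0,0), MAP to (0,0.5,0.5), and MSP to (0.5,0,0.5), while the weights {0.3,0.3,0.4} used in the code interpolate between these principles.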


Accessing R from Jstacs

Here, we show several examples of how R can be used from within Jstacs via Rserve.


REnvironment e = null;
try {
        //create a connection to R with YOUR server name, login and password
        e = new REnvironment();//might be adjusted
 
        System.out.println( "java: " + System.getProperty( "java.version" ) );
        System.out.println();
        System.out.println( e.getVersionInformation() );
 
        // compute something in R
        REXP erg = e.eval( "sin(10)" );
        System.out.println( erg.asDouble() );
 
        //create a histogram in R in 3 steps
        //1) create the data
        e.voidEval( "a = 100;" );
        e.voidEval( "n = rnorm(a)" );
        //2) create the plot command
        String plotCmd = "hist(n,breaks=a/5)";
        //3a) plot as pdf
        e.plotToPDF( plotCmd, args[0]+"/test.pdf", true );
        //or
        //3b) create an image and show it
        BufferedImage i = e.plot( plotCmd, 640, 480 );
        REnvironment.showImage( "histogram", i, JFrame.EXIT_ON_CLOSE );
 
} catch ( Exception ex ) {
        ex.printStackTrace();
} finally {
        if( e != null ) {
                try {
                        //close REnvironment correctly
                        e.close();
                } catch ( Exception e1 ) {
                        System.err.println( "could not close REnvironment." );
                        e1.printStackTrace();
                }
        }
}


Getting ROC and PR curve from a classifier

In this example, we show how a classifier (loaded from disk) can be assessed on test data, and how we can plot ROC and PR curves of this classifier and test data set.


public static void main(String[] args) throws Exception {
        //read XML-representation from disk
        StringBuffer buf2 = FileManager.readFile( new File(args[0]+"myClassifier.xml") );
         
        //create new classifier from read StringBuffer containing XML-code
        AbstractClassifier trainedClassifier = (AbstractClassifier) XMLParser.extractObjectForTags(buf2, "classifier");

        //create a DataSet for each class from the input data, using the DNA alphabet
        DataSet[] test = new DataSet[2];
        test[0] = new DNADataSet( args[1] );
       
        //the length of our input sequences
        int length = test[0].getElementLength();

        test[1] = new DataSet( new DNADataSet( args[2] ), length );
       
         
        AbstractPerformanceMeasure[] m = { new PRCurve(), new ROCCurve() };
        PerformanceMeasureParameterSet mp = new PerformanceMeasureParameterSet( m );
        ResultSet rs = trainedClassifier.evaluate( mp, true, test );
         
        REnvironment r = null;
        try {
                r = new REnvironment();//might be adjusted
                for( int i = 0; i < rs.getNumberOfResults(); i++ )  {
                        Result res = rs.getResultAt(i);
                        if( res instanceof DoubleTableResult ) {
                                DoubleTableResult dtr = (DoubleTableResult) res;
                                ImageResult ir = DoubleTableResult.plot( r, dtr );
                                REnvironment.showImage( dtr.getName(), ir.getValue() );
                        } else {
                                System.out.println( res );
                        }
                }
        } catch( Exception e ) {
                e.printStackTrace();
        } finally {
                if( r != null ) {
                        r.close();
                }
        }
}
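
To clarify what the ROC curve summarizes, the following plain-Java sketch computes a single point of the curve and the area under it from classification scores. The class RocSketch and its methods are hypothetical and independent of the Jstacs performance measures; the scores in main are made up.

```java
/** Sketch: a ROC point and the area under the ROC curve from classification scores. */
class RocSketch {

    /** Returns {false positive rate, true positive rate} when predicting "positive" for score >= threshold. */
    static double[] rocPoint(double[] pos, double[] neg, double threshold) {
        int tp = 0;
        int fp = 0;
        for (double p : pos) {
            if (p >= threshold) {
                tp++;
            }
        }
        for (double n : neg) {
            if (n >= threshold) {
                fp++;
            }
        }
        return new double[]{(double) fp / neg.length, (double) tp / pos.length};
    }

    /**
     * Area under the ROC curve, computed as the fraction of (positive, negative)
     * pairs in which the positive scores higher; ties count one half.
     */
    static double auc(double[] pos, double[] neg) {
        double wins = 0;
        for (double p : pos) {
            for (double n : neg) {
                if (p > n) {
                    wins += 1.0;
                } else if (p == n) {
                    wins += 0.5;
                }
            }
        }
        return wins / ((double) pos.length * neg.length);
    }

    public static void main(String[] args) {
        double[] pos = {0.9, 0.8, 0.4}; // scores of foreground test sequences
        double[] neg = {0.7, 0.3, 0.2}; // scores of background test sequences
        double[] point = rocPoint(pos, neg, 0.5);
        System.out.println("FPR=" + point[0] + " TPR=" + point[1]);
        System.out.println("AUC=" + auc(pos, neg));
    }
}
```

Sweeping the threshold over all observed scores and plotting the resulting points yields the full curve; the pairwise formulation of the area is quadratic in the number of sequences and is chosen here only for brevity.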


Performing crossvalidation

In this example, we show how we can compare classifiers built on different types of models and using different learning principles in a cross validation. Specifically, we create a position weight matrix, use that matrix to create a mixture model, and we create an inhomogeneous Markov model of order 3. We do so in the world of TrainableStatisticalModels and in the world of DifferentiableStatisticalModels. We then use the mixture model as foreground model and the inhomogeneous Markov model as the background model when building classifiers. The classifiers are learned by the generative MAP principle and the discriminative MSP principle, respectively. We then assess these classifiers in a 10-fold cross validation.


//create a DataSet for each class from the input data, using the DNA alphabet
DataSet[] data = new DataSet[2];
data[0] = new DNADataSet( args[0] );

//the length of our input sequences
int length = data[0].getElementLength();

data[1] = new DataSet( new DNADataSet( args[1] ), length );
 
AlphabetContainer container = data[0].getAlphabetContainer();

//create a new PWM
BayesianNetworkTrainSM pwm = new BayesianNetworkTrainSM( new BayesianNetworkTrainSMParameterSet(
                //the alphabet and the length of the model:
                container, length,
                //the equivalent sample size to compute hyper-parameters
                4,
                //some identifier for the model
                "my PWM",
                //we want a PWM, which is an inhomogeneous Markov model (IMM) of order 0
                ModelType.IMM, (byte) 0,
                //we want to estimate the MAP-parameters
                LearningType.ML_OR_MAP ) );
 
//create a new mixture model using 2 PWMs
MixtureTrainSM mixPwms = new MixtureTrainSM(
                //the length of the mixture model
                length,
                //the two components, which are PWMs
                new TrainableStatisticalModel[]{pwm,pwm},
                //the number of starts of the EM
                10,
                //the equivalent sample sizes
                new double[]{pwm.getESS(),pwm.getESS()},
                //the hyper-parameters to draw the initial sequence-specific component weights (hidden variables)
                1,
                //stopping criterion
                new SmallDifferenceOfFunctionEvaluationsCondition(1E-6),
                //parameterization of the model, LAMBDA complies with the
                //parameterization by log-probabilities
                Parameterization.LAMBDA);
 
//create a new inhomogeneous Markov model of order 3
BayesianNetworkTrainSM mm = new BayesianNetworkTrainSM(
                new BayesianNetworkTrainSMParameterSet( container, length, 256, "my iMM(3)", ModelType.IMM, (byte) 3, LearningType.ML_OR_MAP ) );
 
//create a new PWM scoring function
BayesianNetworkDiffSM dPwm = new BayesianNetworkDiffSM(
                //the alphabet and the length of the scoring function
                container, length,
                //the equivalent sample size for the plug-in parameters
                4,
                //we use plug-in parameters
                true,
                //a PWM is an inhomogeneous Markov model of order 0
                new InhomogeneousMarkov(0));
 
//create a new mixture scoring function
MixtureDiffSM dMixPwms = new MixtureDiffSM(
                //the number of starts
                2,
                //we use plug-in parameters
                true,
                //the two components, which are PWMs
                dPwm,dPwm);
 
//create a new scoring function that is an inhomogeneous Markov model of order 3
BayesianNetworkDiffSM dMm = new BayesianNetworkDiffSM(container, length, 4, true, new InhomogeneousMarkov(3));
 
//create the classifiers
int threads = AbstractMultiThreadedOptimizableFunction.getNumberOfAvailableProcessors();
AbstractScoreBasedClassifier[] classifiers = new AbstractScoreBasedClassifier[]{
                                                           //model based with mixture model and Markov model
                                                           new TrainSMBasedClassifier( mixPwms, mm ),
                                                           //conditional likelihood based classifier
                                                           new MSPClassifier( new GenDisMixClassifierParameterSet(container, length,
                                                                           //method for optimizing the conditional likelihood and
                                                                           //other parameters of the numerical optimization
                                                                           Optimizer.QUASI_NEWTON_BFGS, 1E-2, 1E-2, 1, true, KindOfParameter.PLUGIN, false, threads),
                                                                           //mixture scoring function and Markov model scoring function
                                                                           dMixPwms,dMm )
};
 
//create a new k-fold cross validation using above classifiers
KFoldCrossValidation cv = new KFoldCrossValidation( classifiers );
 
//we use a specificity of 0.999 to compute the sensitivity and a sensitivity of 0.95 to compute FPR and PPV
NumericalPerformanceMeasureParameterSet mp = PerformanceMeasureParameterSet.createFilledParameters();
//we do a 10-fold cross validation and partition the data by means of the number of symbols
KFoldCrossValidationAssessParameterSet cvpars = new KFoldCrossValidationAssessParameterSet(PartitionMethod.PARTITION_BY_NUMBER_OF_SYMBOLS, length, true, 10);
 
//compute the result of the cross validation and print them to System.out
System.out.println( cv.assess( mp, cvpars, data ) );
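
The partitioning step of a k-fold cross validation can be sketched in plain Java as follows. Note the simplifying assumption: this sketch partitions by the number of sequences, whereas the example above partitions by the number of symbols. The class KFoldSketch and the seed parameter are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Sketch: partitioning n sequence indices into k folds of nearly equal size. */
class KFoldSketch {

    /** Shuffles the indices 0..n-1 and deals them out to k folds in turn. */
    static List<List<Integer>> partition(int n, int k, long seed) {
        if (k < 1 || k > n) {
            throw new IllegalArgumentException("Need 1 <= k <= n");
        }
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));
        List<List<Integer>> folds = new ArrayList<>();
        for (int f = 0; f < k; f++) {
            folds.add(new ArrayList<>());
        }
        for (int i = 0; i < n; i++) {
            folds.get(i % k).add(indices.get(i));
        }
        return folds;
    }

    public static void main(String[] args) {
        List<List<Integer>> folds = partition(23, 10, 42L);
        for (int f = 0; f < folds.size(); f++) {
            // fold f is the test set of iteration f, all other folds form the training set
            System.out.println("fold " + f + ": " + folds.get(f));
        }
    }
}
```

In each of the k iterations, one fold is held out for testing and the classifiers are trained on the remaining folds, so every sequence is used for testing exactly once.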


Implementing a TrainableStatisticalModel

In this example, we show how to implement a new TrainableStatisticalModel. Here, we implement a simple homogeneous Markov model of order 0 to focus on the technical side of the implementation. A homogeneous Markov model of order 0 has parameters \theta_a, where a is a symbol of the alphabet \Sigma and \sum_{a \in \Sigma} \theta_a = 1. For an input sequence \mathbf{x} = x_1,\ldots,x_L, it models the likelihood

\begin{align}

P(\mathbf{x}|\boldsymbol{\theta}) &= \prod_{l=1}^{L} \theta_{x_l}.

\end{align}

In the implementation, we use log-parameters \log \theta_a.


public class HomogeneousMarkovModel extends AbstractTrainableStatisticalModel {
 
        private double[] logProbs;//array for the parameters, i.e. the probabilities for each symbol
 
        public HomogeneousMarkovModel( AlphabetContainer alphabets ) throws Exception {
                super( alphabets, 0 ); //we have a homogeneous TrainableStatisticalModel, hence the length is set to 0
                //a homogeneous TrainableStatisticalModel can only handle simple alphabets
                if(! (alphabets.isSimple() && alphabets.isDiscrete()) ){
                        throw new Exception("Only simple and discrete alphabets allowed");
                }
                //initialize parameter array
                this.logProbs = new double[(int) alphabets.getAlphabetLengthAt( 0 )];
                Arrays.fill( logProbs, -Math.log(logProbs.length) );
        }
 
        public HomogeneousMarkovModel( StringBuffer stringBuff ) throws NonParsableException {
                super( stringBuff );
        }
 
        protected void fromXML( StringBuffer xml ) throws NonParsableException {
                //extract our XML-code
                xml = XMLParser.extractForTag( xml, "homogeneousMarkovModel" );
                //extract all the variables using XMLParser
                alphabets = (AlphabetContainer) XMLParser.extractObjectForTags( xml, "alphabets" );
                length = XMLParser.extractObjectForTags( xml, "length", int.class );
                logProbs = XMLParser.extractObjectForTags( xml, "logProbs", double[].class );
        }
 
        public StringBuffer toXML() {
                StringBuffer buf = new StringBuffer();
                //pack all the variables using XMLParser
                XMLParser.appendObjectWithTags( buf, alphabets, "alphabets" );
                XMLParser.appendObjectWithTags( buf, length, "length" );
                XMLParser.appendObjectWithTags( buf, logProbs, "logProbs" );
                //add our own tag
                XMLParser.addTags( buf, "homogeneousMarkovModel" );
                return buf;
        }
 
        public String getInstanceName() {
                return "Homogeneous Markov model of order 0";
        }
 
        public double getLogPriorTerm() throws Exception {
                //we use ML-estimation, hence no prior term
                return 0;
        }
 
        public NumericalResultSet getNumericalCharacteristics() throws Exception {
                //we do not have much to tell here
                return new NumericalResultSet(new NumericalResult("Number of parameters","The number of parameters this model uses",logProbs.length));
        }
 
        public double getLogProbFor( Sequence sequence, int startpos, int endpos ) throws NotTrainedException, Exception {
                double seqLogProb = 0.0;
                //compute the log-probability of the sequence between startpos and endpos (inclusive)
                //as sum of the single symbol log-probabilities
                for(int i=startpos;i<=endpos;i++){
                        //directly access the array by the numerical representation of the symbols
                        seqLogProb += logProbs[sequence.discreteVal( i )];
                }
                return seqLogProb;
        }
 
        public boolean isInitialized() {
                return true;
        }
 
        public void train( DataSet data, double[] weights ) throws Exception {
                //reset the parameter array
                Arrays.fill( logProbs, 0.0 );
                //default sequence weight
                double w = 1;
                //for each sequence in the data set
                for(int i=0;i<data.getNumberOfElements();i++){
                        //retrieve sequence
                        Sequence seq = data.getElementAt( i );
                        //if we do have any weights, use them
                        if(weights != null){
                                w = weights[i];
                        }
                        //for each position in the sequence
                        for(int j=0;j<seq.getLength();j++){
                                //count symbols, weighted by weights
                                logProbs[ seq.discreteVal( j ) ] += w;
                        }
                }
                //compute normalization
                double norm = 0.0;
                for(int i=0;i<logProbs.length;i++){ norm += logProbs[i]; }
                //normalize probs to obtain proper probabilities
                for(int i=0;i<logProbs.length;i++){ logProbs[i] = Math.log( logProbs[i]/norm ); }
        }
}


Implementing a DifferentiableStatisticalModel

In this example, we show how to implement a new DifferentiableStatisticalModel. Here, we implement a simple position weight matrix, i.e., an inhomogeneous Markov model of order 0. Since we want to use this position weight matrix in numerical optimization, we parameterize it in the so-called "natural parameterization", where the probability of symbol a at position l is P(X_l=a | \boldsymbol{\lambda}) = \frac{\exp(\lambda_{l,a})}{ \sum_{\tilde{a}} \exp(\lambda_{l,\tilde{a}}) }. Since we use a product-Dirichlet prior on the parameters, we transform this prior to the parameterization we use.

Here, the method getLogScore returns a log-score that can be normalized to a proper log-likelihood by subtracting a log-normalization constant. The log-score for an input sequence \mathbf{x} = x_1,\ldots,x_L essentially is

\begin{align}

S(\mathbf{x}|\boldsymbol{\lambda}) &= \sum_{l=1}^{L} \lambda_{l,x_l}.

\end{align}

The normalization constant is a partition function, i.e., the sum of the exponentiated scores over all possible input sequences:

\begin{align}

Z(\boldsymbol{\lambda}) &= \sum_{\mathbf{x} \in \Sigma^L} \exp( S(\mathbf{x}|\boldsymbol{\lambda}) ) \\

&= \sum_{\mathbf{x} \in \Sigma^L} \prod_{l=1}^{L} \exp(\lambda_{l,x_l}) \\

&= \prod_{l=1}^{L} \sum_{a \in \Sigma} \exp(\lambda_{l,a})

\end{align}

Thus, the likelihood is defined as

\begin{align}

P(\mathbf{x}|\boldsymbol{\lambda}) &= \frac{\exp(S(\mathbf{x}|\boldsymbol{\lambda}))}{Z(\boldsymbol{\lambda})}

\end{align}

and

\begin{align}

\log P(\mathbf{x}|\boldsymbol{\lambda}) &= S(\mathbf{x}|\boldsymbol{\lambda}) - \log Z(\boldsymbol{\lambda}).

\end{align}
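
The factorization of the partition function can be checked numerically on a tiny example. The following plain-Java sketch compares the product form with a brute-force sum over all sequences; the class PwmPartitionSketch, its method names, and the parameter values are assumptions made for illustration and are independent of the Jstacs class shown next.

```java
/** Sketch: log-partition function of a PWM in natural parameterization. */
class PwmPartitionSketch {

    /** Computes log(sum_i exp(v_i)) in a numerically stable way. */
    static double logSumExp(double[] v) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : v) {
            max = Math.max(max, x);
        }
        double sum = 0;
        for (double x : v) {
            sum += Math.exp(x - max);
        }
        return max + Math.log(sum);
    }

    /** log Z via the product form: sum over positions of the log-sum-exp over symbols. */
    static double logZFactorized(double[][] lambda) {
        double logZ = 0;
        for (double[] row : lambda) {
            logZ += logSumExp(row);
        }
        return logZ;
    }

    /** log Z by enumerating all sequences; feasible only for very short lengths. */
    static double logZBruteForce(double[][] lambda) {
        int length = lambda.length;
        int size = lambda[0].length;
        int total = 1;
        for (int l = 0; l < length; l++) {
            total *= size;
        }
        double z = 0;
        for (int code = 0; code < total; code++) {
            // decode the sequence from its number in base "size"
            int rest = code;
            double score = 0;
            for (int l = 0; l < length; l++) {
                score += lambda[l][rest % size];
                rest /= size;
            }
            z += Math.exp(score);
        }
        return Math.log(z);
    }

    /** Log-likelihood of the sequence x: its score minus the log-normalization constant. */
    static double logProb(double[][] lambda, int[] x) {
        double score = 0;
        for (int l = 0; l < x.length; l++) {
            score += lambda[l][x[l]];
        }
        return score - logZFactorized(lambda);
    }

    public static void main(String[] args) {
        double[][] lambda = {{0.1, -0.3, 0.7, 0.0}, {1.2, 0.4, -0.5, 0.2}, {0.0, 0.0, 0.0, 0.0}};
        System.out.println("factorized:  " + logZFactorized(lambda));
        System.out.println("brute force: " + logZBruteForce(lambda));
    }
}
```

The product form needs only L log-sums over the alphabet, while the brute-force sum grows exponentially with the sequence length, which is why the implementation relies on the factorized expression.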


public class PositionWeightMatrixDiffSM extends AbstractDifferentiableStatisticalModel {

        private double[][] parameters;// array for the parameters of the PWM in natural parameterization
        private double ess;// the equivalent sample size
        private boolean isInitialized;// if the parameters of this PWM are initialized
        private Double norm;// normalization constant, must be reset for new parameter values
       
        public PositionWeightMatrixDiffSM( AlphabetContainer alphabets, int length, double ess ) throws IllegalArgumentException {
                super( alphabets, length );
                //we allow only discrete alphabets with the same symbols at all positions
                if(!alphabets.isSimple() || !alphabets.isDiscrete()){
                        throw new IllegalArgumentException( "This PWM can handle only discrete alphabets with the same alphabet at each position." );
                }
                //create parameter-array
                this.parameters = new double[length][(int)alphabets.getAlphabetLengthAt( 0 )];
                //set fields
                this.ess = ess;
                this.isInitialized = false;
                this.norm = null;
        }

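        /**
         * Restores a PWM from its XML representation as created by toXML().
         * @param xml the XML representation
         * @throws NonParsableException if the XML representation could not be parsed
         */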
        public PositionWeightMatrixDiffSM( StringBuffer xml ) throws NonParsableException {
                //super-constructor in the end calls fromXML(StringBuffer)
                //and checks that alphabet and length are set
                super( xml );
        }

        @Override
        public int getSizeOfEventSpaceForRandomVariablesOfParameter( int index ) {
                //the event space are the symbols of the alphabet
                return parameters[0].length;
        }

        @Override
        public double getLogNormalizationConstant() {
                //only depends on current parameters
                //-> compute only once
                if(this.norm == null){
                        norm = 0.0;
                        //sum over all sequences of product over all positions
                        //can be re-ordered for a PWM to the product over all positions
                        //of the sum over the symbols. In log-space the outer
                        //product becomes a sum, the inner sum must be computed
                        //by getLogSum(double[])
                        for(int i=0;i<parameters.length;i++){
                                norm += Normalisation.getLogSum( parameters[i] );
                        }
                }
                return norm;
        }

        @Override
        public double getLogPartialNormalizationConstant( int parameterIndex ) throws Exception {
                //norm computed?
                if(norm == null){
                        getLogNormalizationConstant();
                }
                //row and column of the parameter
                //in the PWM
                int symbol = parameterIndex%(int)alphabets.getAlphabetLengthAt( 0 );
                int position = parameterIndex/(int)alphabets.getAlphabetLengthAt( 0 );
                //partial derivation only at current position, rest is factor
                return norm - Normalisation.getLogSum( parameters[position] ) + parameters[position][symbol];
        }

        @Override
        public double getLogPriorTerm() {
                double logPrior = 0;
                for(int i=0;i<parameters.length;i++){
                        for(int j=0;j<parameters[i].length;j++){
                                //prior without gamma-normalization (only depends on hyper-parameters),
                                //uniform hyper-parameters (BDeu), transformed prior density,
                                //without normalization constant (getLogNormalizationConstant()*ess subtracted later)
                                logPrior += ess/alphabets.getAlphabetLengthAt( 0 ) * parameters[i][j];
                        }
                }
                return logPrior;
        }

        @Override
        public void addGradientOfLogPriorTerm( double[] grad, int start ) throws Exception {
                for(int i=0;i<parameters.length;i++){
                        for(int j=0;j<parameters[i].length;j++,start++){
                                //partial derivations of the logPriorTerm above,
                                //added to the values already stored in grad
                                grad[start] += ess/alphabets.getAlphabetLengthAt( 0 );
                        }
                }
        }

        @Override
        public double getESS() {
                return ess;
        }

        @Override
        public void initializeFunction( int index, boolean freeParams, DataSet[] data, double[][] weights ) throws Exception {
                if(!data[index].getAlphabetContainer().checkConsistency( alphabets ) ||
                                data[index].getElementLength() != length){
                        throw new IllegalArgumentException( "Alphabet or length do not match." );
                }
                //initially set pseudo-counts
                for(int i=0;i<parameters.length;i++){
                        Arrays.fill( parameters[i], ess/alphabets.getAlphabetLengthAt( 0 ) );
                }
                //counts in data
                for(int i=0;i<data[index].getNumberOfElements();i++){
                        Sequence seq = data[index].getElementAt( i );
                        for(int j=0;j<seq.getLength();j++){
                                parameters[j][ seq.discreteVal( j ) ] += weights[index][i];
                        }
                }
                for(int i=0;i<parameters.length;i++){
                        //normalize -> MAP estimation
                        Normalisation.sumNormalisation( parameters[i] );
                        //parameters are log-probabilities from MAP estimation
                        for(int j=0;j<parameters[i].length;j++){
                                parameters[i][j] = Math.log( parameters[i][j] );
                        }
                }
                norm = null;
                isInitialized = true;
        }

        @Override
        public void initializeFunctionRandomly( boolean freeParams ) throws Exception {
                int al = (int)alphabets.getAlphabetLengthAt( 0 );
                //draw parameters from prior density -> Dirichlet
                DirichletMRGParams pars = new DirichletMRGParams( ess/al, al );
                for(int i=0;i<parameters.length;i++){
                        parameters[i] = DirichletMRG.DEFAULT_INSTANCE.generate( al, pars );
                        //parameters are log-probabilities
                        for(int j=0;j<parameters[i].length;j++){
                                parameters[i][j] = Math.log( parameters[i][j] );
                        }
                }
                norm = null;
                isInitialized = true;
        }
       

        @Override
        public double getLogScoreFor( Sequence seq, int start ) {
                double score = 0.0;
                //log-score is sum of parameter values used
                //normalization to likelihood can be achieved
                //by subtracting getLogNormalizationConstant
                for(int i=0;i<parameters.length;i++){
                        score += parameters[i][ seq.discreteVal( i+start ) ];
                }
                return score;
        }

        @Override
        public double getLogScoreAndPartialDerivation( Sequence seq, int start, IntList indices, DoubleList partialDer ) {
                double score = 0.0;
                int off = 0;
                for(int i=0;i<parameters.length;i++){
                        int v = seq.discreteVal( i+start );
                        score += parameters[i][ v ];
                        //add index of parameter used to indices
                        indices.add( off + v );
                        //derivations are just one
                        partialDer.add( 1 );
                        off += parameters[i].length;
                }
                return score;
        }

        @Override
        public int getNumberOfParameters() {
                int num = 0;
                for(int i=0;i<parameters.length;i++){
                        num += parameters[i].length;
                }
                return num;
        }

        @Override
        public double[] getCurrentParameterValues() throws Exception {
                double[] pars = new double[getNumberOfParameters()];
                for(int i=0,k=0;i<parameters.length;i++){
                        for(int j=0;j<parameters[i].length;j++,k++){
                                pars[k] = parameters[i][j];
                        }
                }
                return pars;
        }

        @Override
        public void setParameters( double[] params, int start ) {
                for(int i=0;i<parameters.length;i++){
                        for(int j=0;j<parameters[i].length;j++,start++){
                                parameters[i][j] = params[start];
                        }
                }
                norm = null;
        }

        @Override
        public String getInstanceName() {
                return "Position weight matrix";
        }


        @Override
        public boolean isInitialized() {
                return isInitialized;
        }

        @Override
        public StringBuffer toXML() {
                StringBuffer xml = new StringBuffer();
                //store all fields with XML parser
                //including alphabet and length of the super-class
                XMLParser.appendObjectWithTags( xml, alphabets, "alphabets" );
                XMLParser.appendObjectWithTags( xml, length, "length" );
                XMLParser.appendObjectWithTags( xml, parameters, "parameters" );
                XMLParser.appendObjectWithTags( xml, isInitialized, "isInitialized" );
                XMLParser.appendObjectWithTags( xml, ess, "ess" );
                XMLParser.addTags( xml, "PWM" );
                return xml;
        }

        @Override
        protected void fromXML( StringBuffer xml ) throws NonParsableException {
                xml = XMLParser.extractForTag( xml, "PWM" );
                //extract all fields
                alphabets = (AlphabetContainer)XMLParser.extractObjectForTags( xml, "alphabets" );
                length = XMLParser.extractObjectForTags( xml, "length", int.class );
                parameters = (double[][])XMLParser.extractObjectForTags( xml, "parameters" );
                isInitialized = XMLParser.extractObjectForTags( xml, "isInitialized", boolean.class );
                ess = XMLParser.extractObjectForTags( xml, "ess", double.class );
        }
}
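
The expression returned by getLogPartialNormalizationConstant in the class above can be read as the logarithm of the partial derivative of Z(\boldsymbol{\lambda}) with respect to one parameter \lambda_{l,a}. As a hedged sanity check, the following stand-alone sketch re-implements that expression without any Jstacs classes and compares it with a central finite difference. All class and method names are hypothetical, and the mapping from parameter index to position and symbol simply mirrors the one used in the class.

```java
/** Hypothetical check (not part of Jstacs): log partial normalization constant versus finite differences. */
class PartialNormalizationSketch {

    /** Numerically stable log( sum_a exp(v[a]) ). */
    static double logSumExp(double[] v) {
        double max = Double.NEGATIVE_INFINITY;
        for (double d : v) {
            max = Math.max(max, d);
        }
        double sum = 0.0;
        for (double d : v) {
            sum += Math.exp(d - max);
        }
        return max + Math.log(sum);
    }

    /** log Z(lambda) as a sum over positions. */
    static double logZ(double[][] lambda) {
        double logZ = 0.0;
        for (double[] row : lambda) {
            logZ += logSumExp(row);
        }
        return logZ;
    }

    /** log of dZ/dlambda_{l,a}; only the factor of position l changes, the rest is a constant factor. */
    static double logPartialZ(double[][] lambda, int parameterIndex) {
        int alphabetSize = lambda[0].length;
        int symbol = parameterIndex % alphabetSize;
        int position = parameterIndex / alphabetSize;
        return logZ(lambda) - logSumExp(lambda[position]) + lambda[position][symbol];
    }

    /** Central finite-difference approximation of dZ/dlambda_{l,a}. */
    static double finiteDifference(double[][] lambda, int parameterIndex, double h) {
        int alphabetSize = lambda[0].length;
        int symbol = parameterIndex % alphabetSize;
        int position = parameterIndex / alphabetSize;
        double old = lambda[position][symbol];
        lambda[position][symbol] = old + h;
        double plus = Math.exp(logZ(lambda));
        lambda[position][symbol] = old - h;
        double minus = Math.exp(logZ(lambda));
        lambda[position][symbol] = old; // restore the original parameter value
        return (plus - minus) / (2 * h);
    }

    public static void main(String[] args) {
        double[][] lambda = {{0.1, -0.3, 0.7, 0.0}, {1.2, 0.4, -0.5, 0.3}};
        for (int k = 0; k < 8; k++) {
            System.out.println(k + ": analytic " + Math.exp(logPartialZ(lambda, k))
                    + ", numeric " + finiteDifference(lambda, k, 1e-5));
        }
    }
}
```

A check of this kind is cheap to write and can catch index mix-ups between positions and symbols before the model is passed to a numerical optimizer.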