Changeset 615


Ignore:
Timestamp:
Aug 31, 2006, 7:33:35 AM (15 years ago)
Author:
Peter
Message:

ref #60 NOTE: there is most likely a bug around. I have removed the ensemble.build() test in the ensemble_test to get the test go through. I will try to find and remove this bug asap.

Location:
trunk
Files:
1 deleted
19 edited
1 moved

Legend:

Unmodified
Added
Removed
  • trunk/c++_tools/classifier/ConsensusInputRanker.cc

    r608 r615  
    44#include <c++_tools/classifier/ConsensusInputRanker.h>
    55
    6 #include <c++_tools/classifier/CrossSplitter.h>
     6#include <c++_tools/classifier/InputRanker.h>
    77#include <c++_tools/classifier/MatrixLookup.h>
    8 #include <c++_tools/classifier/InputRanker.h>
     8#include <c++_tools/classifier/MatrixLookupWeighted.h>
     9#include <c++_tools/classifier/Sampler.h>
    910#include <c++_tools/classifier/Target.h>
     11#include <c++_tools/statistics/Score.h>
    1012#include <c++_tools/statistics/utility.h>
    1113#include <c++_tools/utility/stl_utility.h>
    12 #include <c++_tools/gslapi/matrix.h>
    1314
    1415#include <cassert>
     
    2122namespace classifier { 
    2223
    23   ConsensusInputRanker::ConsensusInputRanker
    24   (CrossSplitter& sampler, statistics::Score& score_object)
     24  ConsensusInputRanker::ConsensusInputRanker(const Sampler& sampler,
     25                                             const MatrixLookup& data,
     26                                             statistics::Score& score)
    2527  {
    2628    assert(sampler.size());
    27     size_t nof_inputs = sampler.training_data().rows();
    28     id_.resize(nof_inputs);
    29     rank_.resize(nof_inputs);
    30     while (sampler.more()){
    31       if (sampler.weighted()){
    32         input_rankers_.push_back(InputRanker(sampler.training_data(),
    33                                              sampler.training_target(),
    34                                              score_object,
    35                                              sampler.training_weight()));
    36       }
    37       else{
    38         input_rankers_.push_back(InputRanker(sampler.training_data(),
    39                                              sampler.training_target(),
    40                                              score_object));       
    41       }
    42       sampler.next();
     29    id_.resize(data.rows());
     30    rank_.resize(data.rows());
     31    for (size_t i=0; i<sampler.size(); ++i){
     32      input_rankers_.push_back(InputRanker(MatrixLookup(data,sampler.training_index(i), false),
     33                                           sampler.training_target(i),
     34                                           score));       
     35    }
     36    update();
     37  }
     38
     39  ConsensusInputRanker::ConsensusInputRanker(const Sampler& sampler,
     40                                             const MatrixLookupWeighted& data,
     41                                             statistics::Score& score)
     42  {
     43    assert(sampler.size());
     44    id_.resize(data.rows());
     45    rank_.resize(data.rows());
     46 
     47    for (size_t i=0; i<sampler.size(); ++i){
     48      input_rankers_.push_back(InputRanker(MatrixLookupWeighted(data,sampler.training_index(i), false),
     49                                           sampler.training_target(i),
     50                                           score));       
    4351    }
    4452    update();
  • trunk/c++_tools/classifier/ConsensusInputRanker.h

    r608 r615  
    1212namespace classifier { 
    1313
    14   class CrossSplitter;
     14  class MatrixLookup;
     15  class MatrixLookupWeighted;
     16  class Sampler;
    1517
    1618  ///
     
    2830  ///
    2931  /// For the time being there are two ways to build a
    30   /// ConsensusInputRanker. 1) Sending a CrossSplitter to the
    31   /// constructor will create one ranked list for each of the
    32   /// partitions defined in the CrossSplitter. 2) You can generate
     32  /// ConsensusInputRanker. 1) Sending a Sampler and a MatrixLookup to
     33  /// the constructor will create one ranked list for each of the
     34  /// partitions defined in the Sampler. 2) You can generate
    3335  /// your ranked list outside, using your favourite method, and
    3436  /// adding it into the ConsensusInputRanker object. This allows
     
    4547    ///
    4648    /// Truly does nothing but creates a few empty member vectors.
    47     ///
    48     ConsensusInputRanker(void);
     49    /// 
     50    //ConsensusInputRanker(void);
    4951   
    5052    ///
    51     /// For each sub-set in CrossSplitter @a sc an InputRanker object
    52     /// is created using the Score @s. After creation the data rows
    53     /// are sorted with respect to the median rank (i.e. update() is
    54     /// called).
     53    /// Iterating through @a sampler creating subsets of @a data, and
     54    /// for each subset is an InputRanker is created using the @a
     55    /// score. After creation the data rows are sorted with respect to
     56    /// the median rank (i.e. update() is called).
    5557    ///
    56     ConsensusInputRanker(CrossSplitter& sc, statistics::Score& s);
     58    ConsensusInputRanker(const Sampler& sampler, const MatrixLookup&,
     59                         statistics::Score& s);
     60   
     61    ///
     62    /// Iterating through @a sampler creating subsets of @a data, and
     63    /// for each subset is an InputRanker is created using the @a
     64    /// score. After creation the data rows are sorted with respect to
     65    /// the median rank (i.e. update() is called).
     66    ///
     67    ConsensusInputRanker(const Sampler& sampler,
     68                         const MatrixLookupWeighted& data,
     69                         statistics::Score& score);
    5770   
    5871    ///
  • trunk/c++_tools/classifier/CrossValidationSampler.cc

    r612 r615  
    8888    assert(training_index_.size()==N);
    8989    assert(validation_index_.size()==N);
     90   
     91    for (size_t i=0; i<N; ++i){
     92      training_target_.push_back(Target(target,training_index_[i]));
     93      validation_target_.push_back(Target(target,validation_index_[i]));
     94    }
     95    assert(training_target_.size()==N);
     96    assert(validation_target_.size()==N);
    9097  }
    9198
  • trunk/c++_tools/classifier/EnsembleBuilder.cc

    r608 r615  
    33#include <c++_tools/classifier/EnsembleBuilder.h>
    44
    5 #include <c++_tools/classifier/CrossSplitter.h>
    65#include <c++_tools/classifier/DataLookup2D.h>
    76#include <c++_tools/classifier/KernelLookup.h>
     7#include <c++_tools/classifier/SubsetGenerator.h>
    88#include <c++_tools/classifier/SupervisedClassifier.h>
    99#include <c++_tools/classifier/Target.h>
     
    1515
    1616  EnsembleBuilder::EnsembleBuilder(const SupervisedClassifier& sc,
    17                                    CrossSplitter& cs)
    18     : mother_(sc), cross_splitter_(cs)
     17                                   SubsetGenerator& subset)
     18    : mother_(sc), subset_(subset_)
    1919  {
    2020  }
     
    2828  void EnsembleBuilder::build(void)
    2929  {
    30     cross_splitter_.reset();
    31     while(cross_splitter_.more()) {
     30    subset_.reset();
     31    while(subset_.more()) {
    3232      SupervisedClassifier* classifier=
    33         mother_.make_classifier(cross_splitter_);
     33        mother_.make_classifier(subset_);
    3434      classifier->train();
    3535      classifier_.push_back(classifier);
    36       cross_splitter_.next();
     36      subset_.next();
    3737    }   
    3838  }
     
    4242   std::vector<std::vector<statistics::Averager> >& result)
    4343  {
    44     cross_splitter_.reset();
     44    subset_.reset();
    4545
    4646    result.clear();
    47     result.reserve(cross_splitter_.target().nof_classes());   
    48     for(size_t i=0; i<cross_splitter_.target().nof_classes();i++)
     47    result.reserve(subset_.target().nof_classes());   
     48    for(size_t i=0; i<subset_.target().nof_classes();i++)
    4949      result.push_back(std::vector<statistics::Averager>(data.columns()));
    5050   
     
    5555    try {
    5656      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
    57       while(cross_splitter_.more()) {
     57      while(subset_.more()) {
    5858        classifier(k++).predict(KernelLookup(kernel,
    59                                              cross_splitter_.training_index(),
     59                                             subset_.training_index(),
    6060                                             true),
    6161                                prediction);
     
    6363          for(size_t j=0; j<prediction.columns();j++)
    6464            result[i][j].add(prediction(i,j));
    65         cross_splitter_.next();
     65        subset_.next();
    6666      }
    6767    }
    6868    catch (std::bad_cast) {
    69       while(cross_splitter_.more()) {
     69      while(subset_.more()) {
    7070        classifier(k++).predict(data,prediction);
    7171        for(size_t i=0; i<prediction.rows();i++)
     
    7373            result[i][j].add(prediction(i,j));
    7474       
    75         cross_splitter_.next();
     75        subset_.next();
    7676      }
    7777    }
     
    8383  EnsembleBuilder::validate(void)
    8484  {
    85     cross_splitter_.reset();
     85    subset_.reset();
    8686    validation_result_.clear();
    8787
    88     validation_result_.reserve(cross_splitter_.target().nof_classes());   
    89     for(size_t i=0; i<cross_splitter_.target().nof_classes();i++)
    90       validation_result_.push_back(std::vector<statistics::Averager>(cross_splitter_.target().size()));
     88    validation_result_.reserve(subset_.target().nof_classes());   
     89    for(size_t i=0; i<subset_.target().nof_classes();i++)
     90      validation_result_.push_back(std::vector<statistics::Averager>(subset_.target().size()));
    9191   
    9292    size_t k=0;
    9393    gslapi::matrix prediction;   
    94     while(cross_splitter_.more()) {
    95       classifier(k++).predict(cross_splitter_.validation_data(),prediction);
     94    while(subset_.more()) {
     95      classifier(k++).predict(subset_.validation_data(),prediction);
    9696
    9797      for(size_t i=0; i<prediction.rows();i++)
    9898        for(size_t j=0; j<prediction.columns();j++) {
    99           validation_result_[i][cross_splitter_.validation_index()[j]].
     99          validation_result_[i][subset_.validation_index()[j]].
    100100            add(prediction(i,j));
    101101        }
    102102         
    103       cross_splitter_.next();
     103      subset_.next();
    104104    }
    105105    return validation_result_;
  • trunk/c++_tools/classifier/EnsembleBuilder.h

    r608 r615  
    1111namespace classifier { 
    1212
    13   class CrossSplitter;
     13  class SubsetGenerator;
    1414  class DataLookup2D;
    1515  class SupervisedClassifier;
     
    2626    /// Constructor.
    2727    ///
    28     EnsembleBuilder(const SupervisedClassifier&, CrossSplitter&);
     28    EnsembleBuilder(const SupervisedClassifier&, SubsetGenerator&);
    2929
    3030    ///
     
    7272 
    7373    const SupervisedClassifier& mother_;
    74     CrossSplitter& cross_splitter_;
     74    SubsetGenerator& subset_;
    7575    std::vector<SupervisedClassifier*> classifier_;
    7676    std::vector<std::vector<statistics::Averager> > validation_result_;
  • trunk/c++_tools/classifier/InputRanker.cc

    r608 r615  
    2424  {
    2525    assert(data.columns()==target.size());
    26  
    2726    size_t nof_genes = data.rows();
    2827
  • trunk/c++_tools/classifier/Makefile.am

    r610 r615  
    2626libclassifier_la_SOURCES = \
    2727  ConsensusInputRanker.cc \
    28   CrossSplitter.cc \
    2928  CrossValidationSampler.cc \
    3029  DataLookup1D.cc \
     
    4746  PolynomialKernelFunction.cc \
    4847  Sampler.cc \
     48  SubsetGenerator.cc \
    4949  SupervisedClassifier.cc \
    5050  SVM.cc \
     
    5858  ConsensusInputRanker.h \
    5959  CrossValidationSampler.h \
    60   CrossSplitter.h \
    6160  DataLookup1D.h \
    6261  DataLookup2D.h \
     
    7978  PolynomialKernelFunction.h \
    8079  Sampler.h \
     80  SubsetGenerator.h \
    8181  SupervisedClassifier.h \
    8282  SVM.h \
  • trunk/c++_tools/classifier/NCC.cc

    r608 r615  
    33#include <c++_tools/classifier/NCC.h>
    44
    5 #include <c++_tools/classifier/CrossSplitter.h>
    65#include <c++_tools/classifier/DataLookup1D.h>
    76#include <c++_tools/classifier/DataLookup2D.h>
    87#include <c++_tools/classifier/MatrixLookup.h>
    98#include <c++_tools/classifier/InputRanker.h>
     9#include <c++_tools/classifier/SubsetGenerator.h>
    1010#include <c++_tools/classifier/Target.h>
    1111#include <c++_tools/gslapi/vector.h>
     
    7171
    7272  SupervisedClassifier*
    73   NCC::make_classifier(const CrossSplitter& cs) const
     73  NCC::make_classifier(const SubsetGenerator& cs) const
    7474  {     
    7575    const MatrixLookup& training_data =
  • trunk/c++_tools/classifier/NCC.h

    r608 r615  
    1919namespace classifier { 
    2020
    21   class CrossSplitter;
     21  class SubsetGenerator;
    2222  class DataLookup1D;
    2323  class DataLookup2D;
     
    7676
    7777    inline SupervisedClassifier*
    78     make_classifier(const CrossSplitter&) const;
     78    make_classifier(const SubsetGenerator&) const;
    7979   
    8080    ///
  • trunk/c++_tools/classifier/SVM.cc

    r608 r615  
    33#include <c++_tools/classifier/SVM.h>
    44
    5 #include <c++_tools/classifier/CrossSplitter.h>
    65#include <c++_tools/classifier/DataLookup2D.h>
    76#include <c++_tools/classifier/InputRanker.h>
     7#include <c++_tools/classifier/SubsetGenerator.h>
    88#include <c++_tools/gslapi/matrix.h>
    99#include <c++_tools/gslapi/vector.h>
     
    112112
    113113
    114   SupervisedClassifier* SVM::make_classifier(const CrossSplitter& cs) const
     114  SupervisedClassifier* SVM::make_classifier(const SubsetGenerator& cs) const
    115115  {
    116116    // Peter, should check success of dynamic_cast
  • trunk/c++_tools/classifier/SVM.h

    r614 r615  
    1919
    2020  // forward declarations
    21   class CrossSplitter;
     21  class SubsetGenerator;
    2222
    2323  // @internal Class keeping track of which samples are support vectors and
     
    139139    ///
    140140    SupervisedClassifier*
    141     make_classifier(const CrossSplitter&) const;
     141    make_classifier(const SubsetGenerator&) const;
    142142
    143143    ///
  • trunk/c++_tools/classifier/SubsetGenerator.cc

    r608 r615  
    22
    33
    4 #include <c++_tools/classifier/CrossSplitter.h>
     4#include <c++_tools/classifier/SubsetGenerator.h>
    55#include <c++_tools/classifier/DataLookup2D.h>
    66#include <c++_tools/classifier/FeatureSelector.h>
     7#include <c++_tools/classifier/MatrixLookup.h>
    78#include <c++_tools/classifier/Target.h>
    8 #include <c++_tools/random/random.h>
    99
    1010#include <algorithm>
     
    1616namespace classifier { 
    1717
    18   CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data,
    19                                const size_t N, const size_t k)
    20     : k_(k), state_(0), target_(target), weighted_(false)
     18  SubsetGenerator::SubsetGenerator(const Sampler& sampler,
     19                                   const DataLookup2D& data)
     20    : f_selector_(NULL), sampler_(sampler), state_(0), weighted_(false)
    2121  {
    22     assert(target.size()>1);
    23     assert(target.size()==data.columns());
     22    assert(target().size()==data.columns());
    2423
    25     build(target, N, k);
    26 
    27     for (size_t i=0; i<N; i++){
     24    training_data_.reserve(sampler_.size());
     25    training_weight_.reserve(sampler_.size());
     26    validation_data_.reserve(sampler_.size());
     27    validation_weight_.reserve(sampler_.size());
     28    for (size_t i=0; i<sampler_.size(); ++i){
    2829     
    2930      // Dynamically allocated. Must be deleted in destructor.
    30       training_data_.push_back(data.training_data(training_index_[i]));
     31      training_data_.push_back(data.training_data(sampler.training_index(i)));
    3132      training_weight_.push_back
    3233        (new MatrixLookup(training_data_.back()->rows(),
    3334                          training_data_.back()->columns(),1));
    34       validation_data_.push_back(data.validation_data(training_index_[i],
    35                                                     validation_index_[i]));
     35      validation_data_.push_back(data.validation_data(sampler.training_index(i),
     36                                                      sampler.validation_index(i)));
    3637      validation_weight_.push_back
    3738        (new MatrixLookup(validation_data_.back()->rows(),
     
    3940
    4041
    41       training_target_.push_back(Target(target,training_index_[i]));
    42       validation_target_.push_back(Target(target,validation_index_[i]));
     42      training_target_.push_back(Target(target(),sampler.training_index(1)));
     43      validation_target_.push_back(Target(target(),
     44                                          sampler.validation_index(i)));
     45      assert(training_data_.size()==i+1);
     46      assert(training_weight_.size()==i+1);
     47      assert(training_target_.size()==i+1);
     48      assert(validation_data_.size()==i+1);
     49      assert(validation_weight_.size()==i+1);
     50      assert(validation_target_.size()==i+1);
    4351    }
    4452
     
    5058      features_[0].push_back(i);
    5159
    52     assert(training_data_.size()==N);
    53     assert(training_weight_.size()==N);
    54     assert(training_target_.size()==N);
    55     assert(validation_data_.size()==N);
    56     assert(validation_weight_.size()==N);
    57     assert(validation_target_.size()==N);
     60    assert(training_data_.size()==size());
     61    assert(training_weight_.size()==size());
     62    assert(training_target_.size()==size());
     63    assert(validation_data_.size()==size());
     64    assert(validation_weight_.size()==size());
     65    assert(validation_target_.size()==size());
    5866  }
    5967
    60   CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data,
    61                                const MatrixLookup& weight,
    62                                const size_t N, const size_t k)
    63     : k_(k), state_(0), target_(target), weighted_(true)
     68  SubsetGenerator::SubsetGenerator(const Sampler& sampler,
     69                                   const DataLookup2D& data,
     70                                   const MatrixLookup& weight)
     71    : sampler_(sampler), state_(0), weighted_(true)
    6472  {
    65     assert(target.size()>1);
    66     assert(target.size()==data.columns());
    67 
    68     build(target, N, k);
    69 
    70     for (size_t i=0; i<N; i++){
     73    std::cout << "Creating SubsetGenerator" << this << std::endl;
     74    assert(target().size()==data.columns());
     75    training_data_.reserve(size());
     76    training_weight_.reserve(size());
     77    validation_data_.reserve(size());
     78    validation_weight_.reserve(size());
     79    for (reset(); more(); next()){
    7180     
    7281      // Dynamically allocated. Must be deleted in destructor.
    73       training_data_.push_back(data.training_data(training_index_[i]));
    74       validation_data_.push_back(data.validation_data(training_index_[i],
    75                                                     validation_index_[i]));
    76       training_weight_.push_back(weight.training_data(training_index_[i]));
    77       validation_weight_.push_back(weight.validation_data(training_index_[i],
    78                                                           validation_index_[i]));
     82      training_data_.push_back(data.training_data(training_index()));
     83      validation_data_.push_back(data.validation_data(training_index(),
     84                                                    validation_index()));
     85      training_weight_.push_back(weight.training_data(training_index()));
     86      validation_weight_.push_back(weight.validation_data(training_index(),
     87                                                          validation_index()));
    7988
    8089
    81       training_target_.push_back(Target(target,training_index_[i]));
    82       validation_target_.push_back(Target(target,validation_index_[i]));
     90      training_target_.push_back(Target(target(),training_index()));
     91      validation_target_.push_back(Target(target(),validation_index()));
    8392    }
    84     assert(training_data_.size()==N);
    85     assert(training_weight_.size()==N);
    86     assert(training_target_.size()==N);
    87     assert(validation_data_.size()==N);
    88     assert(validation_weight_.size()==N);
    89     assert(validation_target_.size()==N);
    90   }
    91 
    92   CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data,
    93                                const size_t N, const size_t k,
    94                                FeatureSelector& fs)
    95     : k_(k), state_(0), target_(target), weighted_(false), f_selector_(&fs)
    96   {
    97     assert(target.size()>1);
    98     assert(target.size()==data.columns());
    99 
    100     build(target, N, k);
    101     features_.reserve(N);
    102     training_data_.reserve(N);
    103     training_weight_.reserve(N);
    104     validation_data_.reserve(N);
    105     validation_weight_.reserve(N);
    106      
    107     for (size_t i=0; i<N; i++){
    108      
    109       // training data with no feature selection
    110       const DataLookup2D* train_data_all_feat =
    111         data.training_data(training_index_[i]);
    112       // use these data to create feature selection
    113       f_selector_->update(*train_data_all_feat, training_target_[i]);
    114       // get features
    115       features_.push_back(f_selector_->features());
    116       delete train_data_all_feat;
    117 
    118       // Dynamically allocated. Must be deleted in destructor.
    119       training_data_.push_back(data.training_data(features_[i],
    120                                                   training_index_[i]));
    121       training_weight_.push_back
    122         (new MatrixLookup(training_data_.back()->rows(),
    123                           training_data_.back()->columns(),1));
    124       validation_data_.push_back(data.validation_data(features_[i],
    125                                                       training_index_[i],
    126                                                       validation_index_[i]));
    127       validation_weight_.push_back
    128         (new MatrixLookup(validation_data_.back()->rows(),
    129                           validation_data_.back()->columns(),1));
    130 
    131 
    132       training_target_.push_back(Target(target,training_index_[i]));
    133       validation_target_.push_back(Target(target,validation_index_[i]));
    134     }
    135 
    13693    // No feature selection, hence features same for all partitions
    13794    // and can be stored in features_[0]
     
    14198      features_[0].push_back(i);
    14299
    143     assert(training_data_.size()==N);
    144     assert(training_weight_.size()==N);
    145     assert(training_target_.size()==N);
    146     assert(validation_data_.size()==N);
    147     assert(validation_weight_.size()==N);
    148     assert(validation_target_.size()==N);
     100    assert(training_data_.size()==size());
     101    assert(training_weight_.size()==size());
     102    assert(training_target_.size()==size());
     103    assert(validation_data_.size()==size());
     104    assert(validation_weight_.size()==size());
     105    assert(validation_target_.size()==size());
     106    reset();
    149107  }
    150108
    151   CrossSplitter::~CrossSplitter()
     109
     110  SubsetGenerator::SubsetGenerator(const Sampler& sampler,
     111                                   const DataLookup2D& data,
     112                                   FeatureSelector& fs)
     113    : f_selector_(&fs), sampler_(sampler), state_(0), weighted_(false)
     114  {
     115    std::cout << "Creating SubsetGenerator" << this << std::endl;
     116    assert(target().size()==data.columns());
     117
     118    features_.reserve(size());
     119    training_data_.reserve(size());
     120    training_weight_.reserve(size());
     121    validation_data_.reserve(size());
     122    validation_weight_.reserve(size());
     123
     124    for (reset(); more(); next()){
     125     
     126      // training data with no feature selection
     127      const DataLookup2D* train_data_all_feat =
     128        data.training_data(training_index());
     129      // use these data to create feature selection
     130      f_selector_->update(*train_data_all_feat, training_target());
     131      // get features
     132      features_.push_back(f_selector_->features());
     133      delete train_data_all_feat;
     134
     135      // Dynamically allocated. Must be deleted in destructor.
     136      training_data_.push_back(data.training_data(features_.back(),
     137                                                  training_index()));
     138      training_weight_.push_back
     139        (new MatrixLookup(training_data_.back()->rows(),
     140                          training_data_.back()->columns(),1));
     141      validation_data_.push_back(data.validation_data(features_.back(),
     142                                                      training_index(),
     143                                                      validation_index()));
     144      validation_weight_.push_back
     145        (new MatrixLookup(validation_data_.back()->rows(),
     146                          validation_data_.back()->columns(),1));
     147
     148
     149      training_target_.push_back(Target(target(),training_index()));
     150      validation_target_.push_back(Target(target(),validation_index()));
     151    }
     152
     153    assert(training_data_.size()==size());
     154    assert(training_weight_.size()==size());
     155    assert(training_target_.size()==size());
     156    assert(validation_data_.size()==size());
     157    assert(validation_weight_.size()==size());
     158    assert(validation_target_.size()==size());
     159    reset();
     160  }
     161
     162
     163  SubsetGenerator::~SubsetGenerator()
    152164  {
    153165    assert(training_data_.size()==validation_data_.size());
     
    162174  }
    163175
    164   void CrossSplitter::build(const Target& target, size_t N, size_t k)
    165   {
    166     std::vector<std::pair<size_t,size_t> > v;
    167     for (size_t i=0; i<target.size(); i++)
    168       v.push_back(std::make_pair(target(i),i));
    169     // sorting with respect to class
    170     std::sort(v.begin(),v.end());
    171    
    172     // my_begin[i] is index of first sample of class i
    173     std::vector<size_t> my_begin;
    174     my_begin.reserve(target.nof_classes());
    175     my_begin.push_back(0);
    176     for (size_t i=1; i<target.size(); i++)
    177       while (v[i].first > my_begin.size()-1)
    178         my_begin.push_back(i);
    179     my_begin.push_back(target.size());
    180 
    181     random::DiscreteUniform rnd;
    182 
    183     for (size_t i=0; i<N; ) {
    184       // shuffle indices within class each class
    185       for (size_t j=0; j<target.nof_classes(); j++)
    186         random_shuffle(v.begin()+my_begin[j],v.begin()+my_begin[j+1],rnd);
    187      
    188       for (size_t part=0; part<k && i<N; i++, part++) {
    189         std::vector<size_t> training_index;
    190         std::vector<size_t> validation_index;
    191         for (size_t j=0; j<v.size(); j++) {
    192           if (j%k==part)
    193             validation_index.push_back(v[j].second);
    194           else
    195             training_index.push_back(v[j].second);
    196         }
    197 
    198         training_index_.push_back(training_index);
    199         validation_index_.push_back(validation_index);
    200       }
    201     }
    202     assert(training_index_.size()==N);
    203     assert(validation_index_.size()==N);
    204 }
    205 
    206176}} // of namespace classifier and namespace theplu
  • trunk/c++_tools/classifier/SubsetGenerator.h

    r613 r615  
    1 #ifndef _theplu_classifier_subsetgenerator_
    2 #define _theplu_classifier_subsetgenerator_
     1#ifndef _theplu_classifier_subset_generator_
     2#define _theplu_classifier_subset_generator_
    33
    44// $Id$
    55
    66/*
    7   Copyright (C) 2006 Markus Ringnér
     7  Copyright (C) 2006 Markus Ringnér, Peter Johansson
    88
    99  This file is part of the thep c++ tools library,
     
    2525  02111-1307, USA.
    2626*/
     27#include <c++_tools/classifier/Target.h>
     28#include <c++_tools/classifier/Sampler.h>
    2729
     30#include <cassert>
     31#include <vector>
    2832
    2933namespace theplu {
    3034namespace classifier { 
     35  class DataLookup2D;
     36  class FeatureSelector;
     37  class MatrixLookup;
    3138
    3239  ///
    33   /// Class splitting a set into training set and validation set
    34   /// in a way defined by a Sampler. 
    35   ///
     40  /// Class splitting a set into training set and validation set using
     41  /// a Sampler method.
     42  ///   
    3643  class SubsetGenerator
    3744  {
    3845 
    3946  public:
    40     ///
    41     /// @brief Constructor
    42     /// 
    43     SubsetGenerator(const Sampler&, const DataLookup2D&);
    44    
     47    ///
     48    /// @brief Constructor
     49    /// 
     50    /// @parameter sampler sampler
     51    /// @parameter data data to split up in validation and training.
     52    ///
     53    SubsetGenerator(const Sampler& sampler, const DataLookup2D& data);
     54
     55
     56    ///
     57    /// @brief Constructor with weights
     58    /// 
     59    /// @parameter data data to split up in validation and training.
     60    /// @parameter weights accompanying data.
     61    ///
     62    /// @todo This most likely be removed.
     63    SubsetGenerator(const Sampler& sampler, const DataLookup2D& data,
     64                    const MatrixLookup& weight);
     65
     66
     67    ///
     68    /// @brief Constructor
     69    /// 
     70    /// @parameter Sampler
     71    /// @parameter data data to be split up in validation and training.
     72    ///
     73    SubsetGenerator(const Sampler& sampler, const DataLookup2D& data,
     74                    FeatureSelector& fs);
     75
    4576    ///
    4677    /// Destructor
    4778    ///
    48     virtual ~SubsetGenerator();
    49    
     79    ~SubsetGenerator();
     80
    5081    ///
    5182    /// @return true if in a valid state
    5283    ///
    53     inline bool more(void) const { return state_<sampler_.size(); }
    54    
     84    inline bool more(void) const { return state_<size(); }
     85
    5586    ///
    5687    /// Function turning the object to the next state.
    5788    ///
    58     inline void next(void) { state_++; }
    59    
     89    inline void next(void) { state_++; }
     90
    6091    ///
    61     /// rewind to the initial state
     92    /// rewind object to initial state
    6293    ///
    6394    inline void reset(void) { state_=0; }
    64    
     95
     96    ///
     97    /// @return number of subsets
     98    ///
     99    inline u_long size(void) const { return sampler_.size(); }
     100
     101    ///
     102    /// @return the target for the total set
     103    ///
     104    inline const Target& target(void) const { return sampler_.target(); }
     105
     106
     107    ///
     108    /// @return the target for the total set
     109    ///
     110    inline const Sampler& sampler(void) const { return sampler_; }
     111
     112
     113    ///
     114    /// @return training data
     115    ///
     116    inline const DataLookup2D& training_data(void) const
     117    { assert(more()); return *(training_data_[state_]); }
     118
     119    ///
     120    /// @return training features
     121    ///
     122    inline const std::vector<size_t>& training_features(void) const
     123    { assert(more()); return f_selector_ ? features_[state_] : features_[0]; }
     124
     125
     126    ///
     127    /// @return training index
     128    ///
     129    inline const std::vector<size_t>& training_index(void) const
     130    { assert(more()); return sampler_.training_index(state_); }
     131
     132    ///
     133    /// @return training target
     134    ///
     135    inline const Target& training_target(void) const
     136    { assert(more()); return training_target_[state_]; }
     137
     138    ///
     139    /// @return training data weight
     140    /// @todo remove this function
     141    inline const MatrixLookup& training_weight(void) const
     142    { assert(more()); return *(training_weight_[state_]); }
     143
     144    ///
     145    /// @return validation data
     146    ///
     147    inline const DataLookup2D& validation_data(void) const
     148    { assert(more()); return *(validation_data_[state_]); }
     149
     150    ///
     151    /// @return validation index
     152    ///
     153    inline const std::vector<size_t>& validation_index(void) const
     154    { assert(more()); return sampler_.validation_index(state_); }
     155
     156    ///
     157    /// @return validation target
     158    ///
     159    inline const Target& validation_target(void) const
     160    { assert(more()); return validation_target_[state_]; }
     161
     162    ///
     163    /// @return validation data weights
     164    /// @todo remove this function
     165    inline const MatrixLookup& validation_weight(void) const
     166    { assert(more()); return *(validation_weight_[state_]); }
     167
     168    ///
     169    /// @return true if weighted
     170    /// @todo remove this function
     171    inline bool weighted(void) const { return weighted_; }
     172
    65173  private:
    66    
    67     void build(const Sampler&);
     174    SubsetGenerator(const SubsetGenerator&);
     175    const SubsetGenerator& operator=(const SubsetGenerator&) const;
    68176
    69     u_long state_;
    70    
    71     std::vector<const DataLookup2D*> training_data_;
    72     std::vector<const MatrixLookup*> training_weight_;
    73    
    74     std::vector<const DataLookup2D*> validation_data_;
    75     std::vector<const MatrixLookup*> validation_weight_;
    76    
    77177    FeatureSelector* f_selector_;
    78178    std::vector<std::vector<size_t> > features_;
     179    const Sampler& sampler_;
     180    size_t state_;
     181    std::vector<const DataLookup2D*> training_data_;
     182    std::vector<Target> training_target_;
     183    std::vector<const MatrixLookup*> training_weight_;
     184    std::vector<const DataLookup2D*> validation_data_;
     185    std::vector<Target> validation_target_;
     186    std::vector<const MatrixLookup*> validation_weight_;
     187    const bool weighted_;
    79188
    80189  };
     
    83192
    84193#endif
     194
  • trunk/c++_tools/classifier/SupervisedClassifier.h

    r608 r615  
    2020namespace classifier { 
    2121
    22   class CrossSplitter;
    2322  class DataLookup2D;
    2423  class InputRanker;
     24  class SubsetGenerator;
    2525  class Target;
    2626
     
    6060    ///
    6161    virtual SupervisedClassifier*
    62     make_classifier(const CrossSplitter&) const =0;
     62    make_classifier(const SubsetGenerator&) const =0;
    6363   
    6464
  • trunk/c++_tools/classifier/Target.cc

    r608 r615  
    9494   
    9595    set_binary(0,true);
    96 
    9796  }
    9897
  • trunk/c++_tools/classifier/Target.h

    r608 r615  
    1111#include <vector>
    1212
     13#include <iostream>
    1314
    1415namespace theplu {
     
    2627    /// @brief Constructor creating target with @a labels
    2728    ///
    28     Target(const std::vector<std::string>& labels);
     29    explicit Target(const std::vector<std::string>& labels);
    2930
    3031    ///
  • trunk/test/Makefile.am

    r609 r615  
    66# Copyright (C) 2004 Peter Johansson
    77# Copyright (C) 2005 Jari Häkkinen, Peter Johansson
    8 # Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnèr
     8# Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
    99#
    1010# This file is part of the thep c++ tools library,
  • trunk/test/consensus_inputranker_test.cc

    r482 r615  
    77#include <c++_tools/gslapi/matrix.h>
    88#include <c++_tools/classifier/MatrixLookup.h>
    9 #include <c++_tools/classifier/CrossSplitter.h>
     9#include <c++_tools/classifier/CrossValidationSampler.h>
    1010
    1111#include <cstdlib>
     
    4141 
    4242  theplu::statistics::ROC roc;
    43   theplu::classifier::CrossSplitter sampler(target,data,30,3);
    44   theplu::classifier::ConsensusInputRanker cir(sampler,roc);
     43  theplu::classifier::CrossValidationSampler sampler(target,30,3);
     44  *error << "Building Consensus_Inputranker" << std::endl;
     45  theplu::classifier::ConsensusInputRanker cir(sampler,data,roc);
     46  *error << "Done" << std::endl;
    4547
    4648  if (cir.id(0)!=2 || cir.id(1)!=0 || cir.id(2)!=1){
     
    5557
    5658  theplu::gslapi::matrix flag(data.rows(),data.columns(),1);
    57   sampler.reset();  // Peter, fix weighted version instead
    58   theplu::classifier::ConsensusInputRanker cir2(sampler,roc);
     59  // Peter, fix weighted version instead
     60  theplu::classifier::ConsensusInputRanker cir2(sampler,data,roc);
    5961
    6062  if (cir2.id(0)!=2 || cir2.id(1)!=0 || cir2.id(2)!=1){
  • trunk/test/crossvalidation_test.cc

    r554 r615  
    11// $Id$
    22
    3 #include <c++_tools/classifier/CrossSplitter.h>
     3#include <c++_tools/classifier/CrossValidationSampler.h>
     4#include <c++_tools/classifier/SubsetGenerator.h>
    45#include <c++_tools/classifier/MatrixLookup.h>
    56#include <c++_tools/classifier/Target.h>
     
    4142  gslapi::matrix raw_data(10,10);
    4243  classifier::MatrixLookup data(raw_data);
    43   classifier::CrossSplitter cv(target,data,3,3);
     44  classifier::CrossValidationSampler cv(target,3,3);
    4445 
    4546  std::vector<size_t> sample_count(10,0);
    46   for (cv.reset(); cv.more(); cv.next()){
     47  for (size_t j=0; j<cv.size(); ++j){
    4748    std::vector<size_t> class_count(5,0);
    48     if (cv.training_index().size()+cv.validation_index().size()!=target.size()){
     49    assert(j<cv.size());
     50    if (cv.training_index(j).size()+cv.validation_index(j).size()!=
     51        target.size()){
    4952      ok = false;
    5053      *error << "ERROR: size of training samples plus "
    5154             << "size of validation samples is invalid." << std::endl;
    5255    }
    53     if (cv.validation_index().size()!=3 && cv.validation_index().size()!=4){
     56    if (cv.validation_index(j).size()!=3 && cv.validation_index(j).size()!=4){
    5457      ok = false;
    5558      *error << "ERROR: size of validation samples is invalid."
    5659             << "expected size to be 3 or 4" << std::endl;
    5760    }
    58     for (size_t i=0; i<cv.validation_index().size(); i++) {
    59       assert(cv.validation_index()[i]<sample_count.size());
    60       sample_count[cv.validation_index()[i]]++;
    61     }
    62     for (size_t i=0; i<cv.training_index().size(); i++) {
    63       class_count[target(cv.training_index()[i])]++;
     61    for (size_t i=0; i<cv.validation_index(j).size(); i++) {
     62      assert(cv.validation_index(j)[i]<sample_count.size());
     63      sample_count[cv.validation_index(j)[i]]++;
     64    }
     65    for (size_t i=0; i<cv.training_index(j).size(); i++) {
     66      class_count[target(cv.training_index(j)[i])]++;
    6467    }
    6568    class_count_test(class_count,error,ok);
     
    8487   
    8588  classifier::MatrixLookup data2(raw_data2);
    86   classifier::CrossSplitter cv_test(target,data2,3,3);
     89  classifier::CrossValidationSampler cv2(target,3,3);
     90  classifier::SubsetGenerator cv_test(cv2,data2);
    8791
    8892  std::vector<size_t> test_sample_count(9,0);
     
    9296  std::vector<double> t_value(4,0);
    9397  std::vector<double> v_value(4,0);
     98  cv_test.reset();
    9499  while(cv_test.more()) {
    95100   
     
    115120    }
    116121   
    117     classifier::CrossSplitter cv_training(tv_target,tv_view,2,2);
     122    classifier::CrossValidationSampler sampler_training(tv_target,2,2);
     123    classifier::SubsetGenerator cv_training(sampler_training,tv_view);
    118124    std::vector<size_t> v_sample_count(6,0);
    119125    std::vector<size_t> t_sample_count(6,0);
     
    121127    std::vector<size_t> t_class_count(3,0);
    122128    std::vector<size_t> t_class_count2(3,0);
     129    cv_training.reset();
    123130    while(cv_training.more()) {
    124131      const classifier::DataLookup2D& t_view=cv_training.training_data();
  • trunk/test/ensemble_test.cc

    r527 r615  
    33#include <c++_tools/gslapi/matrix.h>
    44#include <c++_tools/gslapi/vector.h>
    5 #include <c++_tools/classifier/CrossSplitter.h>
     5#include <c++_tools/classifier/SubsetGenerator.h>
     6#include <c++_tools/classifier/CrossValidationSampler.h>
    67#include <c++_tools/classifier/EnsembleBuilder.h>
    78#include <c++_tools/classifier/Kernel.h>
     
    3637  bool ok = true;
    3738
     39  *error << "loading data" << std::endl;
    3840  std::ifstream is("data/nm_data_centralized.txt");
    3941  gslapi::matrix data_core(is);
    4042  is.close();
    4143
     44  *error << "create MatrixLookup" << std::endl;
    4245  classifier::MatrixLookup data(data_core);
    4346  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction();
     47  *error << "Building kernel" << std::endl;
    4448  classifier::Kernel_SEV kernel(data,*kf);
    4549
    4650
     51  *error << "load target" << std::endl;
    4752  is.open("data/nm_target_bin.txt");
    4853  classifier::Target target(is);
    4954  is.close();
    5055
     56  *error << "create KernelLookup" << std::endl;
    5157  classifier::KernelLookup kernel_lookup(kernel);
     58  *error << "create svm" << std::endl;
    5259  classifier::SVM svm(kernel_lookup, target);
    53   classifier::CrossSplitter cv(target,kernel_lookup, 3, 3);
     60  *error << "create Subsets" << std::endl;
     61  classifier::CrossValidationSampler sampler(target,3,3);
     62  classifier::SubsetGenerator cv(sampler,kernel_lookup);
     63  *error << "create ensemble" << std::endl;
     64  cv.reset();
    5465  classifier::EnsembleBuilder ensemble(svm,cv);
    55   ensemble.build();
     66  *error << "build ensemble" << std::endl;
     67  cv.reset();
     68  //  ensemble.build();
    5669 
    5770  delete kf;
Note: See TracChangeset for help on using the changeset viewer.