Changeset 1227 for trunk


Ignore:
Timestamp:
Mar 13, 2008, 3:43:41 AM (14 years ago)
Author:
Peter
Message:

fixes #341 and #93

Location:
trunk
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/ensemble_test.cc

    r1206 r1227  
    130130  *error << roc.score(target,out) << std::endl;
    131131
     132  {
     133    *error << "create ensemble" << std::endl;
     134    classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup>
     135      ensemble(svm, kernel_lookup, sampler);
     136    *error << "test validate() before build()\n";
     137    ensemble.validate();
     138    std::vector<std::vector<statistics::Averager> > result;
     139    *error << "test predict() before build()\n";
     140    ensemble.predict(kernel_lookup, result);
     141  }
    132142  delete kf;
    133143
  • trunk/yat/classifier/EnsembleBuilder.h

    r1221 r1227  
    7373    virtual ~EnsembleBuilder(void);
    7474
    75     ///
    76     /// Generate ensemble. Function trains each member of the Ensemble.
    77     ///
     75    /**
     76       \brief Generate ensemble.
     77       
     78       Function trains each member of the Ensemble.
     79    */
    7880    void build(void);
    7981
    8082    ///
    81     /// @return classifier
     83    /// @return ith classifier
    8284    ///
    8385    const Classifier& classifier(size_t i) const;
    8486     
    8587    ///
    86     /// @return Number of classifiers in ensemble
     88    /// @return Number of classifiers in ensemble. Prior build(void)
     89    /// is issued size is zero.
    8790    ///
    8891    u_long size(void) const;
     
    156159  void EnsembleBuilder<C, D>::build(void)
    157160  {
    158     for(u_long i=0; i<subset_->size();++i) {
    159       C* classifier = mother_.make_classifier();
    160       classifier->train(subset_->training_data(i),
    161                         subset_->training_target(i));
    162       classifier_.push_back(classifier);
    163     }   
     161    if (classifier_.empty()){
     162      for(u_long i=0; i<subset_->size();++i) {
     163        C* classifier = mother_.make_classifier();
     164        classifier->train(subset_->training_data(i),
     165                          subset_->training_target(i));
     166        classifier_.push_back(classifier);
     167      }   
     168    }
    164169  }
    165170
     
    176181  (const D& data, std::vector<std::vector<statistics::Averager> >& result)
    177182  {
    178     result.clear();
    179     result.reserve(subset_->target().nof_classes());   
    180     for(size_t i=0; i<subset_->target().nof_classes();i++)
    181       result.push_back(std::vector<statistics::Averager>(data.columns()));
     183    result = std::vector<std::vector<statistics::Averager> >
     184      (subset_->target().nof_classes(),
     185       std::vector<statistics::Averager>(data.columns()));
    182186   
    183187    utility::Matrix prediction; 
    184188
    185     for(u_long k=0;k<subset_->size();++k) {       
     189    for(u_long k=0;k<size();++k) {       
    186190      D sub_data =  test_data(data, k);
    187191      classifier(k).predict(sub_data,prediction);
     
    246250  EnsembleBuilder<C, D>::validate(void)
    247251  {
    248     validation_result_.clear();
    249 
    250     validation_result_.reserve(subset_->target().nof_classes());   
    251     for(size_t i=0; i<subset_->target().nof_classes();i++)
    252       validation_result_.push_back(std::vector<statistics::Averager>(subset_->target().size()));
    253    
     252    // Don't recalculate validation_result_
     253    if (!validation_result_.empty())
     254      return validation_result_;
     255
     256    validation_result_ = std::vector<std::vector<statistics::Averager> >
     257      (subset_->target().nof_classes(),
     258       std::vector<statistics::Averager>(subset_->target().size()));
     259
    254260    utility::Matrix prediction; 
    255     for(u_long k=0;k<subset_->size();k++) {
     261    for(u_long k=0;k<size();k++) {
    256262      classifier(k).predict(subset_->validation_data(k),prediction);
    257263     
Note: See TracChangeset for help on using the changeset viewer.