Changeset 123


Ignore:
Timestamp:
Jul 23, 2004, 11:55:54 AM (18 years ago)
Author:
Peter
Message:

modified to be balanced between classes

Location:
trunk/src
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/src/CrossValidation.cc

    r115 r123  
    1111namespace cpptools { 
    1212
    13   CrossValidation::CrossValidation(const size_t nof_samples, const size_t k)
    14     :count_(0),index_(std::vector<size_t>()), k_(k)
     13  CrossValidation::CrossValidation(const theplu::gslapi::vector& target,
     14                                   const size_t k)
     15    :count_(0),index_negative_(std::vector<size_t>()),
     16     index_positive_(std::vector<size_t>()), k_(k)
    1517 
    1618  {
    17     index_.resize(nof_samples);
    18     for (size_t i=0; i<nof_samples; i++)
    19       index_[i]=i;
     19    for (size_t i=0; i<target.size(); i++){
     20      if (target(i)==1)
     21        index_positive_.push_back(i);
     22      else
     23        index_negative_.push_back(i);
     24    }
     25
    2026    my_uniform_rng a;
    21     random_shuffle(index_.begin(), index_.end(), a);
     27    random_shuffle(index_negative_.begin(), index_negative_.end(), a);
     28    random_shuffle(index_positive_.begin(), index_positive_.end(), a);
    2229  }
    2330
     
    2734      count_=0;
    2835      my_uniform_rng a;
    29       random_shuffle(index_.begin(), index_.end(), a);
     36      random_shuffle(index_negative_.begin(), index_negative_.end(), a);
     37      random_shuffle(index_positive_.begin(), index_positive_.end(), a);
    3038    }
     39     
     40    count_++;
     41    std::vector<size_t> training_set;
     42
     43    size_t begin = int(index_positive_.size()*(count_-1)/k_);
     44    size_t end = int(index_positive_.size()*count_/k_);
     45    for (size_t i=0; i<index_positive_.size(); i++)
     46      if (i<begin || i>=end)
     47        training_set.push_back(index_positive_[i]);
    3148   
    32     size_t begin = int(index_.size()*count_/k_);
    33     count_++;
    34     size_t end = int(index_.size()*count_/k_);
    35     std::vector<size_t> training_set;
    36     for (size_t i=0; i<index_.size(); i++)
     49    begin = int(index_negative_.size()*(count_-1)/k_);
     50    end = int(index_negative_.size()*count_/k_);
     51    for (size_t i=0; i<index_negative_.size(); i++)
    3752      if (i<begin || i>=end)
    38         training_set.push_back(index_[i]);
     53        training_set.push_back(index_negative_[i]);
     54   
    3955    return training_set ;
    4056   
  • trunk/src/CrossValidation.h

    r115 r123  
    66// C++ tools include
    77/////////////////////
     8#include "vector.h"
    89
    910// Standard C++ includes
     
    2324  public:
    2425    ///
    25     /// Constructor
     26    /// Constructor taking \a target and \k for k-fold cross validation
    2627    ///
    27     CrossValidation(const size_t, const size_t = 3);
     28    CrossValidation(const theplu::gslapi::vector& target, const size_t k = 3);
    2829
    2930    ///
    30     /// Function generating a training set
     31    /// Function generating a training set. This is done in a balanced
     32    /// way, meaning the proportions between the classes the
     33    /// trainingset is close to the proportions in the whole dataset.
    3134    ///
    3235    std::vector<size_t> next();
     
    3437  private:
    3538    int count_;
    36     std::vector<size_t> index_;
     39    std::vector<size_t> index_negative_;
     40    std::vector<size_t> index_positive_;
    3741    int k_;
    3842
Note: See TracChangeset for help on using the changeset viewer.