source: trunk/lib/classifier/CrossSplitter.cc @ 540

Last change on this file since 540 was 540, checked in by Peter, 16 years ago

added weights in CrossSplitter? - not supported in interface though.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 2.9 KB
Line 
1// $Id: CrossSplitter.cc 540 2006-03-05 15:02:11Z peter $
2
3
4#include <c++_tools/classifier/CrossSplitter.h>
5#include <c++_tools/classifier/DataLookup2D.h>
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/random/random.h>
8
9#include <algorithm>
10#include <cassert>
11#include <utility>
12#include <vector>
13
14namespace theplu {
15namespace classifier { 
16
17  CrossSplitter::CrossSplitter(const Target& target, const DataLookup2D& data, 
18                               const size_t N, const size_t k)
19    : k_(k), state_(0), target_(target)
20  { 
21    assert(target.size()>1);
22    assert(target.size()==data.columns());
23
24    build(target, N, k);
25
26    for (size_t i=0; i<N; i++){
27     
28      // Dynamically allocated. Must be deleted in destructor.
29      training_data_.push_back(data.training_data(training_index_[i]));
30      validation_data_.push_back(data.validation_data(training_index_[i], 
31                                                    validation_index_[i]));
32      training_weight_.push_back(new MatrixLookup(0,1));
33      validation_weight_.push_back(new MatrixLookup(0,1));
34
35
36      training_target_.push_back(Target(target,training_index_[i]));
37      validation_target_.push_back(Target(target,validation_index_[i]));
38    }
39    assert(training_data_.size()==N);
40    assert(training_weight_.size()==N);
41    assert(training_target_.size()==N);
42    assert(validation_data_.size()==N);
43    assert(validation_weight_.size()==N);
44    assert(validation_target_.size()==N);
45  }
46
47  CrossSplitter::~CrossSplitter()
48  {
49    assert(training_data_.size()==validation_data_.size());
50    for (size_t i=0; i<training_data_.size(); i++) 
51      delete training_data_[i];
52    for (size_t i=0; i<validation_data_.size(); i++) 
53      delete validation_data_[i];
54  }
55
56  void CrossSplitter::build(const Target& target, size_t N, size_t k)
57  {
58    std::vector<std::pair<size_t,size_t> > v;
59    for (size_t i=0; i<target.size(); i++)
60      v.push_back(std::make_pair(target(i),i));
61    // sorting with respect to class
62    std::sort(v.begin(),v.end());
63   
64    // my_begin[i] is index of first sample of class i
65    std::vector<size_t> my_begin;
66    my_begin.reserve(target.nof_classes());
67    my_begin.push_back(0);
68    for (size_t i=1; i<target.size(); i++)
69      while (v[i].first > my_begin.size()-1)
70        my_begin.push_back(i);
71    my_begin.push_back(target.size());
72
73    random::DiscreteUniform rnd;
74
75    for (size_t i=0; i<N; ) {
76      // shuffle indices within class each class
77      for (size_t j=0; j<target.nof_classes(); j++)
78        random_shuffle(v.begin()+my_begin[j],v.begin()+my_begin[j+1],rnd);
79     
80      for (size_t part=0; part<k && i<N; i++, part++) {
81        std::vector<size_t> training_index;
82        std::vector<size_t> validation_index;
83        for (size_t j=0; j<v.size(); j++) {
84          if (j%k==part)
85            validation_index.push_back(v[j].second);
86          else
87            training_index.push_back(v[j].second);
88        }
89
90        training_index_.push_back(training_index);
91        validation_index_.push_back(validation_index);
92      }
93    }
94    assert(training_index_.size()==N);
95    assert(validation_index_.size()==N);
96}
97
98}} // of namespace classifier and namespace theplu
Note: See TracBrowser for help on using the repository browser.