source: trunk/lib/classifier/CrossSplitter.h @ 514

Last change on this file since 514 was 514, checked in by Peter, 17 years ago

generalised binary functionality in Target

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 3.4 KB
Line 
1// $Id: CrossSplitter.h 514 2006-02-20 09:45:34Z peter $
2
3#ifndef _theplu_classifier_cross_splitter_
4#define _theplu_classifier_cross_splitter_
5
6#include <c++_tools/classifier/Target.h>
7
8
9#include <cassert>
10#include <vector>
11
12namespace theplu {
13namespace classifier { 
14  class DataLookup2D;
15
16
17  ///
18  /// Class splitting a set into training set and validation set in a
19  /// crossvalidation manner. This is done in a balanced way, meaning
20  /// the proportions between the classes in the trainingset is close
21  /// to the proportions in the whole dataset. In the first \a k
22  /// rounds each sample is returned k-1 times, for next round the
23  /// samples are shuffled and... In total there are N partitions, in
24  /// other words, each sample is in validation roughly N/k
25  ///   
26
27  class CrossSplitter
28  {
29 
30  public:
31    ///
32    /// @brief Constructor
33    /// 
34    /// @parameter Target targets
35    /// @parameter data data to split up in validation and training.
36    /// @parameter N total number of partitions
37    /// @parameter k for k-fold crossvalidation
38    ///
39    CrossSplitter(const Target& target, const DataLookup2D& data, 
40                  const size_t N, const size_t k);
41
42    ///
43    /// Destructor
44    ///
45    ~CrossSplitter();
46
47    ///
48    /// @return true if in a valid state
49    ///
50    inline bool more(void) const { return state_<size(); }
51
52    ///
53    /// Function turning the object to the next state.
54    ///
55    inline void next(void) { state_++; }
56
57    ///
58    /// rewind the sampler to initial state
59    ///
60    inline void reset(void) { state_=0; }
61
62
63    ///
64    /// @return number of partitions
65    ///
66    inline u_long size(void) const { return training_data_.size(); }
67
68    ///
69    /// @return the target for the total set
70    ///
71    inline const Target& target(void) const { return target_; }
72
73
74    ///
75    /// @return training data
76    ///
77    /// @note if state is invalid the result is undefined
78    ///
79    inline const DataLookup2D& training_data(void) const 
80    { assert(more()); return *(training_data_[state_]); } 
81
82    /// @return training index
83    ///
84    /// @note if state is invalid the result is undefined
85    ///
86    inline const std::vector<size_t>& training_index(void) const
87    { assert(more()); return training_index_[state_]; }
88
89
90    ///
91    /// @return training target
92    ///
93    /// @note if state is invalid the result is undefined
94    ///
95    inline const Target& training_target(void) const 
96    { assert(more()); return training_target_[state_]; }
97
98    ///
99    /// @return validation data
100    ///
101    /// @note if state is invalid the result is undefined
102    ///
103    inline const DataLookup2D& validation_data(void) const
104    { assert(more()); return *(validation_data_[state_]); }
105
106
107    /// @return validation index
108    ///
109    /// @note if state is invalid the result is undefined
110    ///
111    inline const std::vector<size_t>& validation_index(void) const
112    { assert(more()); return validation_index_[state_]; }
113
114    ///
115    /// @return validation target
116    ///
117    /// @note if state is invalid the result is undefined
118    ///
119    inline const Target& validation_target(void) const 
120    { assert(more()); return validation_target_[state_]; }
121
122  private:
123    const size_t k_;
124    u_long state_;
125    Target target_;
126   
127    std::vector<const DataLookup2D*> training_data_;
128    std::vector<std::vector<size_t> > training_index_;
129    std::vector<Target> training_target_;
130
131    std::vector<const DataLookup2D*> validation_data_;
132    std::vector<std::vector<size_t> > validation_index_;
133    std::vector<Target> validation_target_;
134  };
135
136}} // of namespace classifier and namespace theplu
137
138#endif
139
Note: See TracBrowser for help on using the repository browser.