source: trunk/lib/classifier/CrossSplitter.h @ 558

Last change on this file since 558 was 558, checked in by Peter, 16 years ago

ConsensusInputRanker now supports weights

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 4.6 KB
Line 
1// $Id: CrossSplitter.h 558 2006-03-10 15:58:16Z peter $
2
3#ifndef _theplu_classifier_cross_splitter_
4#define _theplu_classifier_cross_splitter_
5
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/classifier/MatrixLookup.h>
8
9
10#include <cassert>
11#include <vector>
12
13namespace theplu {
14namespace classifier { 
15  class DataLookup2D;
16
17
18  ///
19  /// Class splitting a set into training set and validation set in a
20  /// crossvalidation manner. This is done in a balanced way, meaning
21  /// the proportions between the classes in the trainingset is close
22  /// to the proportions in the whole dataset. In the first \a k
23  /// rounds each sample is returned k-1 times, for next round the
24  /// samples are shuffled and... In total there are N partitions, in
25  /// other words, each sample is in validation roughly N/k
26  ///   
27
28  class CrossSplitter
29  {
30 
31  public:
32    ///
33    /// @brief Constructor
34    /// 
35    /// @parameter Target targets
36    /// @parameter data data to split up in validation and training.
37    /// @parameter N total number of partitions
38    /// @parameter k for k-fold crossvalidation
39    ///
40    CrossSplitter(const Target& target, const DataLookup2D& data, 
41                  const size_t N, const size_t k);
42
43    ///
44    /// @brief Constructor with weights
45    /// 
46    /// @parameter Target targets
47    /// @parameter data data to split up in validation and training.
48    /// @parameter weights accompanying data.
49    /// @parameter N total number of partitions
50    /// @parameter k for k-fold crossvalidation
51    ///
52    CrossSplitter(const Target& target, const DataLookup2D& data, 
53                  const MatrixLookup& weight,
54                  const size_t N, const size_t k);
55
56    ///
57    /// Destructor
58    ///
59    ~CrossSplitter();
60
61    ///
62    /// @return true if in a valid state
63    ///
64    inline bool more(void) const { return state_<size(); }
65
66    ///
67    /// Function turning the object to the next state.
68    ///
69    inline void next(void) { state_++; }
70
71    ///
72    /// rewind the sampler to initial state
73    ///
74    inline void reset(void) { state_=0; }
75
76
77    ///
78    /// @return number of partitions
79    ///
80    inline u_long size(void) const { return training_data_.size(); }
81
82    ///
83    /// @return the target for the total set
84    ///
85    inline const Target& target(void) const { return target_; }
86
87
88    ///
89    /// @return training data
90    ///
91    /// @note if state is invalid the result is undefined
92    ///
93    inline const DataLookup2D& training_data(void) const 
94    { assert(more()); return *(training_data_[state_]); } 
95
96    /// @return training index
97    ///
98    /// @note if state is invalid the result is undefined
99    ///
100    inline const std::vector<size_t>& training_index(void) const
101    { assert(more()); return training_index_[state_]; }
102
103
104    ///
105    /// @return training target
106    ///
107    /// @note if state is invalid the result is undefined
108    ///
109    inline const Target& training_target(void) const 
110    { assert(more()); return training_target_[state_]; }
111
112    ///
113    /// @return training data weight
114    ///
115    /// @note if state is invalid the result is undefined
116    ///
117    inline const DataLookup2D& training_weight(void) const 
118    { assert(more()); return *(training_weight_[state_]); } 
119
120    ///
121    /// @return validation data
122    ///
123    /// @note if state is invalid the result is undefined
124    ///
125    inline const DataLookup2D& validation_data(void) const
126    { assert(more()); return *(validation_data_[state_]); }
127
128
129    /// @return validation index
130    ///
131    /// @note if state is invalid the result is undefined
132    ///
133    inline const std::vector<size_t>& validation_index(void) const
134    { assert(more()); return validation_index_[state_]; }
135
136    ///
137    /// @return validation target
138    ///
139    /// @note if state is invalid the result is undefined
140    ///
141    inline const Target& validation_target(void) const 
142    { assert(more()); return validation_target_[state_]; }
143
144    ///
145    /// @return validation data weights
146    ///
147    /// @note if state is invalid the result is undefined
148    ///
149    inline const DataLookup2D& validation_weight(void) const 
150    { assert(more()); return *(validation_weight_[state_]); } 
151
152    ///
153    /// @return true if weighted
154    ///
155    inline bool weighted(void) const { return weighted_; }
156
157  private:
158    void build(const Target& target, size_t N, size_t k); 
159
160    const size_t k_;
161    u_long state_;
162    Target target_;
163    const bool weighted_;
164   
165    std::vector<const DataLookup2D*> training_data_;
166    std::vector<const MatrixLookup*> training_weight_;
167    std::vector<std::vector<size_t> > training_index_;
168    std::vector<Target> training_target_;
169
170    std::vector<const DataLookup2D*> validation_data_;
171    std::vector<const MatrixLookup*> validation_weight_;
172    std::vector<std::vector<size_t> > validation_index_;
173    std::vector<Target> validation_target_;
174  };
175
176}} // of namespace classifier and namespace theplu
177
178#endif
179
Note: See TracBrowser for help on using the repository browser.