source: trunk/lib/classifier/CrossSplitter.h @ 540

Last change on this file since 540 was 540, checked in by Peter, 16 years ago

added weights in CrossSplitter? - not supported in interface though.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 4.0 KB
Line 
1// $Id: CrossSplitter.h 540 2006-03-05 15:02:11Z peter $
2
3#ifndef _theplu_classifier_cross_splitter_
4#define _theplu_classifier_cross_splitter_
5
6#include <c++_tools/classifier/Target.h>
7#include <c++_tools/classifier/MatrixLookup.h>
8
9
10#include <cassert>
11#include <vector>
12
13namespace theplu {
14namespace classifier { 
15  class DataLookup2D;
16
17
18  ///
19  /// Class splitting a set into training set and validation set in a
20  /// crossvalidation manner. This is done in a balanced way, meaning
21  /// the proportions between the classes in the trainingset is close
22  /// to the proportions in the whole dataset. In the first \a k
23  /// rounds each sample is returned k-1 times, for next round the
24  /// samples are shuffled and... In total there are N partitions, in
25  /// other words, each sample is in validation roughly N/k
26  ///   
27
28  class CrossSplitter
29  {
30 
31  public:
32    ///
33    /// @brief Constructor
34    /// 
35    /// @parameter Target targets
36    /// @parameter data data to split up in validation and training.
37    /// @parameter N total number of partitions
38    /// @parameter k for k-fold crossvalidation
39    ///
40    CrossSplitter(const Target& target, const DataLookup2D& data, 
41                  const size_t N, const size_t k);
42
43    ///
44    /// @brief Constructor with weights
45    /// 
46    /// @parameter Target targets
47    /// @parameter data data to split up in validation and training.
48    /// @parameter weights accompanying data.
49    /// @parameter N total number of partitions
50    /// @parameter k for k-fold crossvalidation
51    ///
52    CrossSplitter(const Target& target, const DataLookup2D& data, 
53                  const DataLookup2D& weights,
54                  const size_t N, const size_t k);
55
56    ///
57    /// Destructor
58    ///
59    ~CrossSplitter();
60
61    ///
62    /// @return true if in a valid state
63    ///
64    inline bool more(void) const { return state_<size(); }
65
66    ///
67    /// Function turning the object to the next state.
68    ///
69    inline void next(void) { state_++; }
70
71    ///
72    /// rewind the sampler to initial state
73    ///
74    inline void reset(void) { state_=0; }
75
76
77    ///
78    /// @return number of partitions
79    ///
80    inline u_long size(void) const { return training_data_.size(); }
81
82    ///
83    /// @return the target for the total set
84    ///
85    inline const Target& target(void) const { return target_; }
86
87
88    ///
89    /// @return training data
90    ///
91    /// @note if state is invalid the result is undefined
92    ///
93    inline const DataLookup2D& training_data(void) const 
94    { assert(more()); return *(training_data_[state_]); } 
95
96    /// @return training index
97    ///
98    /// @note if state is invalid the result is undefined
99    ///
100    inline const std::vector<size_t>& training_index(void) const
101    { assert(more()); return training_index_[state_]; }
102
103
104    ///
105    /// @return training target
106    ///
107    /// @note if state is invalid the result is undefined
108    ///
109    inline const Target& training_target(void) const 
110    { assert(more()); return training_target_[state_]; }
111
112    ///
113    /// @return validation data
114    ///
115    /// @note if state is invalid the result is undefined
116    ///
117    inline const DataLookup2D& validation_data(void) const
118    { assert(more()); return *(validation_data_[state_]); }
119
120
121    /// @return validation index
122    ///
123    /// @note if state is invalid the result is undefined
124    ///
125    inline const std::vector<size_t>& validation_index(void) const
126    { assert(more()); return validation_index_[state_]; }
127
128    ///
129    /// @return validation target
130    ///
131    /// @note if state is invalid the result is undefined
132    ///
133    inline const Target& validation_target(void) const 
134    { assert(more()); return validation_target_[state_]; }
135
136  private:
137    void build(const Target& target, size_t N, size_t k); 
138
139    const size_t k_;
140    u_long state_;
141    Target target_;
142   
143    std::vector<const DataLookup2D*> training_data_;
144    std::vector<const MatrixLookup*> training_weight_;
145    std::vector<std::vector<size_t> > training_index_;
146    std::vector<Target> training_target_;
147
148    std::vector<const DataLookup2D*> validation_data_;
149    std::vector<const MatrixLookup*> validation_weight_;
150    std::vector<std::vector<size_t> > validation_index_;
151    std::vector<Target> validation_target_;
152  };
153
154}} // of namespace classifier and namespace theplu
155
156#endif
157
Note: See TracBrowser for help on using the repository browser.