source: trunk/c++_tools/classifier/NCC.h @ 593

Last change on this file since 593 was 593, checked in by Markus Ringnér, 15 years ago

Fixed std includes to compile with g++ 4.1.

File size: 2.7 KB
Line 
1// $Id$
2
3#ifndef _theplu_classifier_ncc_
4#define _theplu_classifier_ncc_
5
6#include <c++_tools/gslapi/matrix.h>
7
8#include <c++_tools/classifier/SupervisedClassifier.h>
9
10#include <map>
11
12namespace theplu {
13
14  namespace statistics {
15    class Distance;
16    class Score;
17  }
18
19namespace classifier { 
20
21  class CrossSplitter;
22  class DataLookup1D;
23  class DataLookup2D;
24  class MatrixLookup;
25  class Target;
26
27  ///
28  /// Class for Nearest Centroid Classification.
29  ///
30
31  class NCC : public SupervisedClassifier
32  {
33 
34  public:
35    ///
36    /// Constructor taking the training data, the target vector, and
37    /// the distance measure as input.
38    ///
39    NCC(const MatrixLookup&, const Target&, const statistics::Distance&);
40   
41    ///
42    /// Constructor taking the training data, the target vector, the
43    /// distance measure, and a weight matrix for the training data as
44    /// input.
45    ///
46    NCC(const MatrixLookup&, const Target&, const statistics::Distance&,
47        const MatrixLookup&);
48
49   
50
51    ///
52    /// Constructor taking the training data, the target vector, the
53    /// distance measure, the score used to rank data inputs, and the
54    /// number of top ranked data inputs to use in the classification
55    /// as input
56    ///
57    NCC(const MatrixLookup&, const Target&, const statistics::Distance&, 
58        statistics::Score&, const size_t);
59
60    ///
61    /// Constructor taking the training data, the target vector, the
62    /// distance measure, a weight matrix for the training data, the
63    /// score used to rank data inputs, and the number of top ranked
64    /// data inputs to use in the classification as input
65    ///
66    NCC(const MatrixLookup&, const Target&, const statistics::Distance&, 
67        const MatrixLookup&, statistics::Score&, const size_t);
68
69
70    virtual ~NCC();
71
72    ///
73    /// @return the centroids for each class as columns in a matrix.
74    ///
75    const gslapi::matrix& centroids(void) const {return centroids_;}
76
77    inline SupervisedClassifier* 
78    make_classifier(const CrossSplitter&) const;
79   
80    ///
81    /// Train the classifier using the training data. Centroids are
82    /// calculated for each class.
83    ///
84    /// @return true if training succedeed.
85    ///
86    bool train();
87
88
89    ///
90    /// Calculate the distance to each centroid for a test sample
91    ///
92    void predict(const DataLookup1D&, gslapi::vector&) const;
93   
94    ///
95    /// Calculate the distance to each centroid for test samples
96    ///
97    void predict(const DataLookup2D&, gslapi::matrix&) const;
98
99
100  private:
101    gslapi::matrix centroids_;
102    const statistics::Distance& distance_;                 
103    const MatrixLookup& matrix_;
104    bool weighted_;
105    const MatrixLookup* weights_;
106  };
107
108  ///
109  /// The output operator for the NCC class.
110  ///
111  //  std::ostream& operator<< (std::ostream&, const NCC&);
112 
113 
114}} // of namespace classifier and namespace theplu
115
116#endif
Note: See TracBrowser for help on using the repository browser.