source: trunk/lib/classifier/NCC.h @ 526

Last change on this file since 526 was 526, checked in by Markus Ringnér, 16 years ago

Fixed bug in tScore and in MatrixLookup. Added support for scoring inputs in SupervisedClassifier and for using this in training and prediction in NCC.

File size: 2.0 KB
Line 
1// $Id$
2
3#ifndef _theplu_classifier_ncc_
4#define _theplu_classifier_ncc_
5
6#include <c++_tools/gslapi/matrix.h>
7
8#include <c++_tools/classifier/SupervisedClassifier.h>
9
10#include <map>
11
12namespace theplu {
13
14  namespace statistics {
15    class Distance;
16    class Score;
17  }
18
19namespace classifier { 
20
21  class Target;
22  class DataLookup1D;
23  class DataLookup2D;
24  class MatrixLookup;
25
26  ///
27  /// Class for Nearest Centroid Classification.
28  ///
29
30  class NCC : public SupervisedClassifier
31  {
32 
33  public:
34    ///
35    /// Constructor taking the training data, the target vector and
36    /// the distance measure as input.
37    ///
38    NCC(const MatrixLookup&, const Target&, const statistics::Distance&);
39
40
41    ///
42    /// Constructor taking the training data, the target vector, the
43    /// distance measure, the score used to rank data inputs, and the
44    /// number of top ranked data inputs to use in the classification.
45    ///
46    NCC(const MatrixLookup&, const Target&, const statistics::Distance&, 
47        statistics::Score&, const size_t);
48
49    virtual ~NCC();
50
51    ///
52    /// @return the centroids for each class as columns in a matrix.
53    ///
54    const gslapi::matrix& centroids(void) const {return centroids_;}
55
56    inline SupervisedClassifier* 
57    make_classifier(const DataLookup2D&, const Target&) const;
58   
59    ///
60    /// Train the classifier using the training data. Centroids are
61    /// calculated for each class.
62    ///
63    /// @return true if training succedeed.
64    ///
65    bool train();
66
67
68    ///
69    /// Calculate the distance to each centroid for a test sample
70    ///
71    void predict(const DataLookup1D&, gslapi::vector&) const;
72   
73    ///
74    /// Calculate the distance to each centroid for test samples
75    ///
76    void predict(const DataLookup2D&, gslapi::matrix&) const;
77
78
79  private:
80    gslapi::matrix centroids_;
81    const statistics::Distance& distance_;                 
82    const MatrixLookup& matrix_;
83
84  };
85
86  ///
87  /// The output operator for the NCC class.
88  ///
89  //  std::ostream& operator<< (std::ostream&, const NCC&);
90 
91 
92}} // of namespace classifier and namespace theplu
93
94#endif
Note: See TracBrowser for help on using the repository browser.