source: trunk/yat/classifier/NCC.h @ 1158

Last change on this file since 1158 was 1158, checked in by Markus Ringnér, 14 years ago

Fixed #322

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 7.0 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér, Peter Johansson
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
29#include "DataLookup1D.h"
30#include "DataLookup2D.h"
31#include "DataLookupWeighted1D.h"
32#include "MatrixLookup.h"
33#include "MatrixLookupWeighted.h"
34#include "SupervisedClassifier.h"
35#include "Target.h"
36
37#include "yat/statistics/Averager.h"
38#include "yat/statistics/AveragerWeighted.h"
39#include "yat/utility/Matrix.h"
40#include "yat/utility/Vector.h"
41#include "yat/utility/stl_utility.h"
42#include "yat/utility/yat_assert.h"
43
44#include<iostream>
45#include<iterator>
46#include <map>
47#include <cmath>
48#include <stdexcept>
49
50namespace theplu {
51namespace yat {
52namespace classifier { 
53
54
55  ///
56  /// @brief Class for Nearest Centroid Classification.
57  ///
58  /// The template argument Distance should be a class modelling
59  /// the concept \ref concept_distance.
60  ///
61  template <typename Distance>
62  class NCC : public SupervisedClassifier
63  {
64 
65  public:
66    ///
67    /// @brief Constructor
68    ///
69    NCC(void);
70   
71    ///
72    /// @brief Constructor
73    ///
74    NCC(const Distance&);
75
76
77    ///
78    /// @brief Destructor
79    ///
80    virtual ~NCC(void);
81
82    ///
83    /// @return the centroids for each class as columns in a matrix.
84    ///
85    const utility::Matrix& centroids(void) const;
86
87    NCC<Distance>* make_classifier(void) const;
88   
89    ///
90    /// Train the classifier with a training data set and
91    /// targets. Centroids are calculated for each class.
92    ///
93    void train(const MatrixLookup&, const Target&);
94
95
96    ///
97    /// Train the classifier with a weighted training data set and
98    /// targets. Centroids are calculated for each class.
99    ///
100    void train(const MatrixLookupWeighted&, const Target&);
101
102   
103    ///
104    /// Calculate the distance to each centroid for test samples
105    ///
106    void predict(const DataLookup2D&, utility::Matrix&) const;
107   
108   
109  private:
110
111    void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
112    void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;   
113
114    utility::Matrix* centroids_;
115    bool centroids_nan_;
116    Distance distance_;
117  };
118
119  ///
120  /// The output operator for the NCC class.
121  ///
122  //  std::ostream& operator<< (std::ostream&, const NCC&);
123 
124
125  // templates
126
127  template <typename Distance>
128  NCC<Distance>::NCC() 
129    : SupervisedClassifier(), centroids_(0), centroids_nan_(false)
130  {
131  }
132
133  template <typename Distance>
134  NCC<Distance>::NCC(const Distance& dist) 
135    : SupervisedClassifier(), centroids_(0), centroids_nan_(false), distance_(dist)
136  {
137  }
138
139
140  template <typename Distance>
141  NCC<Distance>::~NCC()   
142  {
143    if(centroids_)
144      delete centroids_;
145  }
146
147
148  template <typename Distance>
149  const utility::Matrix& NCC<Distance>::centroids(void) const
150  {
151    return *centroids_;
152  }
153 
154
155  template <typename Distance>
156  NCC<Distance>* 
157  NCC<Distance>::make_classifier() const 
158  {     
159    return new NCC<Distance>();
160  }
161
162  template <typename Distance>
163  void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
164  {   
165    if(centroids_) 
166      delete centroids_;
167    centroids_= new utility::Matrix(data.rows(), target.nof_classes());
168    for(size_t i=0; i<data.rows(); i++) {
169      std::vector<statistics::Averager> class_averager;
170      class_averager.resize(target.nof_classes());
171      for(size_t j=0; j<data.columns(); j++) {
172        class_averager[target(j)].add(data(i,j));
173      }
174      for(size_t c=0;c<target.nof_classes();c++) {
175        (*centroids_)(i,c) = class_averager[c].mean();
176      }
177    }
178    trained_=true;
179  }
180
181
182  template <typename Distance>
183  void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
184  {   
185    if(centroids_) 
186      delete centroids_;
187    centroids_= new utility::Matrix(data.rows(), target.nof_classes());
188    for(size_t i=0; i<data.rows(); i++) {
189      std::vector<statistics::AveragerWeighted> class_averager;
190      class_averager.resize(target.nof_classes());
191      for(size_t j=0; j<data.columns(); j++) 
192        class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
193      for(size_t c=0;c<target.nof_classes();c++) {
194        if(class_averager[c].sum_w()==0) {
195          centroids_nan_=true;
196        }
197        (*centroids_)(i,c) = class_averager[c].mean();
198      }
199    }
200    trained_=true;
201  }
202
203
204  template <typename Distance>
205  void NCC<Distance>::predict(const DataLookup2D& test,                     
206                              utility::Matrix& prediction) const
207  {   
208    utility::yat_assert<std::runtime_error>
209      (centroids_,"NCC::predict called for untrained classifier");
210    utility::yat_assert<std::runtime_error>
211      (centroids_->rows()==test.rows(),
212       "NCC::predict test data with incorrect number of rows");
213   
214    prediction.resize(centroids_->columns(), test.columns());
215
216    // unweighted test data
217    if (const MatrixLookup* test_unweighted =
218        dynamic_cast<const MatrixLookup*>(&test)) {
219      // If weighted training data has resulted in NaN in centroids: weighted calculations
220      if(centroids_nan_) { 
221        predict_weighted(MatrixLookupWeighted(*test_unweighted),prediction);
222      }
223      // If unweighted training data: unweighted calculations
224      else {
225        predict_unweighted(*test_unweighted,prediction);
226      }
227    }
228    // weighted test data: weighted calculations
229    else if (const MatrixLookupWeighted* test_weighted =
230             dynamic_cast<const MatrixLookupWeighted*>(&test)) { 
231      predict_weighted(*test_weighted,prediction);
232    }
233    else {
234      std::string str = 
235        "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
236      throw std::runtime_error(str);
237    }
238  }
239 
240  template <typename Distance>
241  void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 
242                                         utility::Matrix& prediction) const
243  {
244    MatrixLookup centroids(*centroids_);
245    for(size_t j=0; j<test.columns();j++)
246      for(size_t k=0; k<centroids_->columns();k++) 
247        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
248                                    centroids.begin_column(k));
249  }
250 
251  template <typename Distance>
252  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 
253                                          utility::Matrix& prediction) const
254  {
255    MatrixLookupWeighted weighted_centroids(*centroids_);
256    for(size_t j=0; j<test.columns();j++) 
257      for(size_t k=0; k<centroids_->columns();k++)
258        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
259                                    weighted_centroids.begin_column(k));
260  }
261
262     
263}}} // of namespace classifier, yat, and theplu
264
265#endif
Note: See TracBrowser for help on using the repository browser.