source: trunk/yat/classifier/NCC.h @ 1142

Last change on this file since 1142 was 1142, checked in by Markus Ringnér, 14 years ago

Refs #335, fixed for NCC, working on KNN

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 7.9 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér, Peter Johansson
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
29#include "DataLookup1D.h"
30#include "DataLookup2D.h"
31#include "DataLookupWeighted1D.h"
32#include "MatrixLookup.h"
33#include "MatrixLookupWeighted.h"
34#include "SupervisedClassifier.h"
35#include "Target.h"
36
37#include "yat/statistics/Averager.h"
38#include "yat/statistics/AveragerWeighted.h"
39#include "yat/utility/Matrix.h"
40#include "yat/utility/Vector.h"
41#include "yat/utility/stl_utility.h"
42#include "yat/utility/yat_assert.h"
43
44#include<iostream>
45#include<iterator>
46#include <map>
47#include <cmath>
48#include <stdexcept>
49
50namespace theplu {
51namespace yat {
52namespace classifier { 
53
54
55  ///
56  /// @brief Class for Nearest Centroid Classification.
57  ///
58  /// The template argument Distance should be a class modelling
59  /// the concept \ref concept_distance.
60  ///
61  template <typename Distance>
62  class NCC : public SupervisedClassifier
63  {
64 
65  public:
66    ///
67    /// Constructor taking the training data and the target vector as
68    /// input
69    ///
70    NCC(const MatrixLookup&, const Target&);
71   
72    ///
73    /// Constructor taking the training data with weights and the
74    /// target vector as input.
75    ///
76    NCC(const MatrixLookupWeighted&, const Target&);
77
78    virtual ~NCC();
79
80    ///
81    /// @return the centroids for each class as columns in a matrix.
82    ///
83    const utility::Matrix& centroids(void) const;
84
85    const DataLookup2D& data(void) const;
86
87    SupervisedClassifier* make_classifier(const DataLookup2D&, 
88                                          const Target&) const;
89   
90    ///
91    /// Train the classifier using the training data. Centroids are
92    /// calculated for each class.
93    ///
94    void train();
95
96   
97    ///
98    /// Calculate the distance to each centroid for test samples
99    ///
100    void predict(const DataLookup2D&, utility::Matrix&) const;
101   
102   
103  private:
104
105    void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
106    void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;   
107
108    utility::Matrix* centroids_;
109    bool centroids_nan_;
110    Distance distance_;
111
112    // data_ has to be of type DataLookup2D to accomodate both
113    // MatrixLookup and MatrixLookupWeighted
114    const DataLookup2D& data_;
115  };
116
//
// The output operator for the NCC class (declaration currently
// disabled, hence a plain comment rather than a Doxygen block).
//
//  std::ostream& operator<< (std::ostream&, const NCC&);
121 
122
123  // templates
124
125  template <typename Distance>
126  NCC<Distance>::NCC(const MatrixLookup& data, const Target& target) 
127    : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data) 
128  {
129  }
130
131  template <typename Distance>
132  NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
133    : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data)
134  {
135  }
136
137  template <typename Distance>
138  NCC<Distance>::~NCC()   
139  {
140    if(centroids_)
141      delete centroids_;
142  }
143
144  template <typename Distance>
145  const utility::Matrix& NCC<Distance>::centroids(void) const
146  {
147    return *centroids_;
148  }
149 
150
151  template <typename Distance>
152  const DataLookup2D& NCC<Distance>::data(void) const
153  {
154    return data_;
155  }
156 
157  template <typename Distance>
158  SupervisedClassifier* 
159  NCC<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
160  {     
161    NCC* ncc=0;
162    try {
163      if(data.weighted()) {
164        ncc=new NCC<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
165                              target);
166      }
167      else {
168        ncc=new NCC<Distance>(dynamic_cast<const MatrixLookup&>(data),
169                              target);
170      }
171    }
172    catch (std::bad_cast) {
173      std::string str = "Error in NCC<Distance>::make_classifier: DataLookup2D of unexpected class.";
174      throw std::runtime_error(str);
175    }
176    return ncc;
177  }
178
179
180  template <typename Distance>
181  void NCC<Distance>::train()
182  {   
183    if(centroids_) 
184      delete centroids_;
185    centroids_= new utility::Matrix(data_.rows(), target_.nof_classes());
186    // data_ is a MatrixLookup or a MatrixLookupWeighted
187    if(data_.weighted()) {
188      const MatrixLookupWeighted* weighted_data = 
189        dynamic_cast<const MatrixLookupWeighted*>(&data_);     
190      for(size_t i=0; i<data_.rows(); i++) {
191        std::vector<statistics::AveragerWeighted> class_averager;
192        class_averager.resize(target_.nof_classes());
193        for(size_t j=0; j<data_.columns(); j++) {
194          class_averager[target_(j)].add(weighted_data->data(i,j),
195                                         weighted_data->weight(i,j));
196        }
197        for(size_t c=0;c<target_.nof_classes();c++) {
198          if(class_averager[c].sum_w()==0) {
199            centroids_nan_=true;
200          }
201          (*centroids_)(i,c) = class_averager[c].mean();
202        }
203      }
204    }
205    else {
206      const MatrixLookup* unweighted_data = 
207        dynamic_cast<const MatrixLookup*>(&data_);     
208      for(size_t i=0; i<data_.rows(); i++) {
209        std::vector<statistics::Averager> class_averager;
210        class_averager.resize(target_.nof_classes());
211        for(size_t j=0; j<data_.columns(); j++) {
212          class_averager[target_(j)].add((*unweighted_data)(i,j));
213        }
214        for(size_t c=0;c<target_.nof_classes();c++) {
215          (*centroids_)(i,c) = class_averager[c].mean();
216        }
217      }
218    }
219  }
220
221  template <typename Distance>
222  void NCC<Distance>::predict(const DataLookup2D& test,                     
223                              utility::Matrix& prediction) const
224  {   
225    utility::yat_assert<std::runtime_error>
226      (centroids_,"NCC::predict called for untrained classifier");
227    utility::yat_assert<std::runtime_error>
228      (data_.rows()==test.rows(),
229       "NCC::predict test data with incorrect number of rows");
230   
231    prediction.resize(centroids_->columns(), test.columns());
232
233    // unweighted test data
234    if (const MatrixLookup* test_unweighted =
235        dynamic_cast<const MatrixLookup*>(&test)) {
236      // If weighted training data has resulted in NaN in centroids: weighted calculations
237      if(centroids_nan_) { 
238        predict_weighted(MatrixLookupWeighted(*test_unweighted),prediction);
239      }
240      // If unweighted training data: unweighted calculations
241      else {
242        predict_unweighted(*test_unweighted,prediction);
243      }
244    }
245    // weighted test data: weighted calculations
246    else if (const MatrixLookupWeighted* test_weighted =
247             dynamic_cast<const MatrixLookupWeighted*>(&test)) { 
248      predict_weighted(*test_weighted,prediction);
249    }
250    else {
251      std::string str = 
252        "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
253      throw std::runtime_error(str);
254    }
255  }
256 
257  template <typename Distance>
258  void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 
259                                         utility::Matrix& prediction) const
260  {
261    MatrixLookup centroids(*centroids_);
262    for(size_t j=0; j<test.columns();j++)
263      for(size_t k=0; k<centroids_->columns();k++) 
264        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
265                                    centroids.begin_column(k));
266  }
267 
268  template <typename Distance>
269  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 
270                                          utility::Matrix& prediction) const
271  {
272    MatrixLookupWeighted weighted_centroids(*centroids_);
273    for(size_t j=0; j<test.columns();j++) 
274      for(size_t k=0; k<centroids_->columns();k++)
275        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
276                                    weighted_centroids.begin_column(k));
277  }
278
279     
280}}} // of namespace classifier, yat, and theplu
281
282#endif
Note: See TracBrowser for help on using the repository browser.