source: trunk/yat/classifier/NCC.h @ 1050

Last change on this file since 1050 was 1050, checked in by Peter, 15 years ago

Simplifying distance structure

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 8.3 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér, Peter Johansson
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
29#include "DataLookup1D.h"
30#include "DataLookup2D.h"
31#include "DataLookupWeighted1D.h"
32#include "MatrixLookup.h"
33#include "MatrixLookupWeighted.h"
34#include "SupervisedClassifier.h"
35#include "Target.h"
36
37#include "yat/statistics/Averager.h"
38#include "yat/statistics/AveragerWeighted.h"
39#include "yat/utility/Iterator.h"
40#include "yat/utility/IteratorWeighted.h"
41#include "yat/utility/matrix.h"
42#include "yat/utility/vector.h"
43#include "yat/utility/stl_utility.h"
44#include "yat/utility/yat_assert.h"
45
46#include<iostream>
47#include<iterator>
48#include <map>
49#include <cmath>
50#include <stdexcept>
51
52namespace theplu {
53namespace yat {
54namespace classifier { 
55
56
57  ///
58  /// @brief Class for Nearest Centroid Classification.
59  ///
60
61  template <typename Distance>
62  class NCC : public SupervisedClassifier
63  {
64 
65  public:
66    ///
67    /// Constructor taking the training data and the target vector as
68    /// input
69    ///
70    NCC(const MatrixLookup&, const Target&);
71   
72    ///
73    /// Constructor taking the training data with weights and the
74    /// target vector as input.
75    ///
76    NCC(const MatrixLookupWeighted&, const Target&);
77
78    virtual ~NCC();
79
80    ///
81    /// @return the centroids for each class as columns in a matrix.
82    ///
83    const utility::matrix& centroids(void) const;
84
85    const DataLookup2D& data(void) const;
86
87    SupervisedClassifier* make_classifier(const DataLookup2D&, 
88                                          const Target&) const;
89   
90    ///
91    /// Train the classifier using the training data. Centroids are
92    /// calculated for each class.
93    ///
94    /// @return true if training succedeed.
95    ///
96    void train();
97
98   
99    ///
100    /// Calculate the distance to each centroid for test samples
101    ///
102    void predict(const DataLookup2D&, utility::matrix&) const;
103   
104   
105  private:
106
107    void predict_unweighted(const MatrixLookup&, utility::matrix&) const;
108    void predict_weighted(const MatrixLookupWeighted&, utility::matrix&) const;   
109
110    utility::matrix* centroids_;
111    bool centroids_nan_;
112    Distance distance_;
113
114    // data_ has to be of type DataLookup2D to accomodate both
115    // MatrixLookup and MatrixLookupWeighted
116    const DataLookup2D& data_;
117  };
118
119  ///
120  /// The output operator for the NCC class.
121  ///
122  //  std::ostream& operator<< (std::ostream&, const NCC&);
123 
124
125  // templates
126
127  template <typename Distance>
128  NCC<Distance>::NCC(const MatrixLookup& data, const Target& target) 
129    : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data) 
130  {
131  }
132
133  template <typename Distance>
134  NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
135    : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data)
136  {
137  }
138
139  template <typename Distance>
140  NCC<Distance>::~NCC()   
141  {
142    if(centroids_)
143      delete centroids_;
144  }
145
146  template <typename Distance>
147  const utility::matrix& NCC<Distance>::centroids(void) const
148  {
149    return *centroids_;
150  }
151 
152
153  template <typename Distance>
154  const DataLookup2D& NCC<Distance>::data(void) const
155  {
156    return data_;
157  }
158 
159  template <typename Distance>
160  SupervisedClassifier* 
161  NCC<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
162  {     
163    NCC* ncc=0;
164    try {
165      if(data.weighted()) {
166        ncc=new NCC<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
167                              target);
168      }
169      else {
170        ncc=new NCC<Distance>(dynamic_cast<const MatrixLookup&>(data),
171                              target);
172      }
173    }
174    catch (std::bad_cast) {
175      std::string str = "Error in NCC<Distance>::make_classifier: DataLookup2D of unexpected class.";
176      throw std::runtime_error(str);
177    }
178    return ncc;
179  }
180
181
182  template <typename Distance>
183  void NCC<Distance>::train()
184  {   
185    if(centroids_) 
186      delete centroids_;
187    centroids_= new utility::matrix(data_.rows(), target_.nof_classes());
188    // data_ is a MatrixLookup or a MatrixLookupWeighted
189    if(data_.weighted()) {
190      const MatrixLookupWeighted* weighted_data = 
191        dynamic_cast<const MatrixLookupWeighted*>(&data_);     
192      for(size_t i=0; i<data_.rows(); i++) {
193        std::vector<statistics::AveragerWeighted> class_averager;
194        class_averager.resize(target_.nof_classes());
195        for(size_t j=0; j<data_.columns(); j++) {
196          class_averager[target_(j)].add(weighted_data->data(i,j),
197                                         weighted_data->weight(i,j));
198        }
199        for(size_t c=0;c<target_.nof_classes();c++) {
200          (*centroids_)(i,c) = class_averager[c].mean();
201          if(class_averager[c].sum_w()==0)
202            centroids_nan_=true;
203        }
204      }
205    }
206    else {
207      const MatrixLookup* unweighted_data = 
208        dynamic_cast<const MatrixLookup*>(&data_);     
209      for(size_t i=0; i<data_.rows(); i++) {
210        std::vector<statistics::Averager> class_averager;
211        class_averager.resize(target_.nof_classes());
212        for(size_t j=0; j<data_.columns(); j++) {
213          class_averager[target_(j)].add((*unweighted_data)(i,j));
214        }
215        for(size_t c=0;c<target_.nof_classes();c++) {
216          (*centroids_)(i,c) = class_averager[c].mean();
217        }
218      }
219    }
220  }
221
222  template <typename Distance>
223  void NCC<Distance>::predict(const DataLookup2D& test,                     
224                              utility::matrix& prediction) const
225  {   
226    utility::yat_assert<std::runtime_error>
227      (centroids_,"NCC::predict called for untrained classifier");
228    utility::yat_assert<std::runtime_error>
229      (data_.rows()==test.rows(),
230       "NCC::predict test data with incorrect number of rows");
231   
232    prediction.clone(utility::matrix(centroids_->columns(), test.columns()));       
233
234    // unweighted test data
235    if (const MatrixLookup* test_unweighted =
236        dynamic_cast<const MatrixLookup*>(&test)) {
237      // If weighted training data resulting in NaN in centroids: weighted calculations
238      if(centroids_nan_) { 
239        //        predict_weighted(MatrixLookupWeighted(*test_unweighted),prediction);
240        std::string str = 
241        "Error in NCC<Distance>::predict: weighted training unweighted test not implemented yet";
242      throw std::runtime_error(str);
243      }
244      // If unweighted training data: unweighted calculations
245      else {
246        predict_unweighted(*test_unweighted,prediction);
247      }
248    }
249    // weighted test data: weighted calculations
250    else if (const MatrixLookupWeighted* test_weighted =
251             dynamic_cast<const MatrixLookupWeighted*>(&test)) { 
252      predict_weighted(*test_weighted,prediction);
253    }
254    else {
255      std::string str = 
256        "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
257      throw std::runtime_error(str);
258    }
259  }
260 
261  template <typename Distance>
262  void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 
263                                         utility::matrix& prediction) const
264  {
265    MatrixLookup unweighted_centroids(*centroids_);
266    for(size_t j=0; j<test.columns();j++) {       
267      DataLookup1D in(test,j,false);
268      for(size_t k=0; k<centroids_->columns();k++) {
269        DataLookup1D centroid(unweighted_centroids,k,false);           
270        utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
271        prediction(k,j) = distance_(in.begin(), in.end(), centroid.begin());
272      }
273    }
274  }
275
276  template <typename Distance>
277  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 
278                                          utility::matrix& prediction) const
279  {
280    MatrixLookupWeighted weighted_centroids(*centroids_);
281    for(size_t j=0; j<test.columns();j++) {       
282      DataLookupWeighted1D in(test,j,false);
283      for(size_t k=0; k<centroids_->columns();k++) {
284        DataLookupWeighted1D centroid(weighted_centroids,k,false);
285        utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
286        prediction(k,j) = distance_(in.begin(), in.end(), centroid.begin());
287      }
288    }
289  }
290
291     
292}}} // of namespace classifier, yat, and theplu
293
294#endif
Note: See TracBrowser for help on using the repository browser.