source: trunk/yat/classifier/NCC.h @ 1013

Last change on this file since 1013 was 1013, checked in by Markus Ringnér, 14 years ago

Adding functionality tests for NCC

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 7.5 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér, Peter Johansson
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
29#include "DataLookup1D.h"
30#include "DataLookup2D.h"
31#include "DataLookupWeighted1D.h"
32#include "MatrixLookup.h"
33#include "MatrixLookupWeighted.h"
34#include "SupervisedClassifier.h"
35#include "Target.h"
36
37#include "yat/statistics/Averager.h"
38#include "yat/statistics/AveragerWeighted.h"
39#include "yat/statistics/vector_distance.h"
40
41#include "yat/utility/Iterator.h"
42#include "yat/utility/IteratorWeighted.h"
43#include "yat/utility/matrix.h"
44#include "yat/utility/vector.h"
45#include "yat/utility/stl_utility.h"
46#include "yat/utility/yat_assert.h"
47
48#include<iostream>
49#include<iterator>
50#include <map>
51#include <cmath>
52#include <stdexcept>
53
54namespace theplu {
55namespace yat {
56namespace classifier { 
57
58
59  ///
60  /// @brief Class for Nearest Centroid Classification.
61  ///
62
63  template <typename Distance>
64  class NCC : public SupervisedClassifier
65  {
66 
67  public:
68    ///
69    /// Constructor taking the training data and the target vector as
70    /// input
71    ///
72    NCC(const MatrixLookup&, const Target&);
73   
74    ///
75    /// Constructor taking the training data with weights and the
76    /// target vector as input.
77    ///
78    NCC(const MatrixLookupWeighted&, const Target&);
79
80    virtual ~NCC();
81
82    ///
83    /// @return the centroids for each class as columns in a matrix.
84    ///
85    const utility::matrix& centroids(void) const;
86
87    const DataLookup2D& data(void) const;
88
89    SupervisedClassifier* make_classifier(const DataLookup2D&, 
90                                          const Target&) const;
91   
92    ///
93    /// Train the classifier using the training data. Centroids are
94    /// calculated for each class.
95    ///
96    /// @return true if training succedeed.
97    ///
98    bool train();
99
100   
101    ///
102    /// Calculate the distance to each centroid for test samples
103    ///
104    void predict(const DataLookup2D&, utility::matrix&) const;
105   
106   
107  private:
108
109    utility::matrix* centroids_;
110
111    // data_ has to be of type DataLookup2D to accomodate both
112    // MatrixLookup and MatrixLookupWeighted
113    const DataLookup2D& data_;
114
115  };
116
117  ///
118  /// The output operator for the NCC class.
119  ///
120  //  std::ostream& operator<< (std::ostream&, const NCC&);
121 
122
123  // templates
124
125  template <typename Distance>
126  NCC<Distance>::NCC(const MatrixLookup& data, const Target& target) 
127    : SupervisedClassifier(target), centroids_(0), data_(data) 
128  {
129  }
130
131  template <typename Distance>
132  NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
133    : SupervisedClassifier(target), centroids_(0), data_(data) 
134  {
135  }
136
137  template <typename Distance>
138  NCC<Distance>::~NCC()   
139  {
140    if(centroids_)
141      delete centroids_;
142  }
143
144  template <typename Distance>
145  const utility::matrix& NCC<Distance>::centroids(void) const
146  {
147    return *centroids_;
148  }
149 
150
151  template <typename Distance>
152  const DataLookup2D& NCC<Distance>::data(void) const
153  {
154    return data_;
155  }
156 
157  template <typename Distance>
158  SupervisedClassifier* 
159  NCC<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
160  {     
161    NCC* ncc=0;
162    try {
163      if(data.weighted()) {
164        ncc=new NCC<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
165                              target);
166      }
167      else {
168        ncc=new NCC<Distance>(dynamic_cast<const MatrixLookup&>(data),
169                              target);
170      }
171      ncc->centroids_=0;
172    }
173    catch (std::bad_cast) {
174      std::string str = "Error in NCC<Distance>::make_classifier: DataLookup2D of unexpected class.";
175      throw std::runtime_error(str);
176    }
177    return ncc;
178  }
179
180
181  template <typename Distance>
182  bool NCC<Distance>::train()
183  {   
184    if(centroids_) 
185      delete centroids_;
186    centroids_= new utility::matrix(data_.rows(), target_.nof_classes());
187    // data_ is a MatrixLookup or a MatrixLookupWeighted
188    if(data_.weighted()) {
189      const MatrixLookupWeighted* weighted_data = 
190        dynamic_cast<const MatrixLookupWeighted*>(&data_);     
191      for(size_t i=0; i<data_.rows(); i++) {
192        std::vector<statistics::AveragerWeighted> class_averager;
193        class_averager.resize(target_.nof_classes());
194        for(size_t j=0; j<data_.columns(); j++) {
195          class_averager[target_(j)].add(weighted_data->data(i,j),
196                                         weighted_data->weight(i,j));
197        }
198        for(size_t c=0;c<target_.nof_classes();c++) {
199          (*centroids_)(i,c) = class_averager[c].mean();
200        }
201      }
202    }
203    else {
204      const MatrixLookup* unweighted_data = 
205        dynamic_cast<const MatrixLookup*>(&data_);     
206      for(size_t i=0; i<data_.rows(); i++) {
207        std::vector<statistics::Averager> class_averager;
208        class_averager.resize(target_.nof_classes());
209        for(size_t j=0; j<data_.columns(); j++) {
210          class_averager[target_(j)].add((*unweighted_data)(i,j));
211        }
212        for(size_t c=0;c<target_.nof_classes();c++) {
213          (*centroids_)(i,c) = class_averager[c].mean();
214        }
215      }
216    }
217    return true;
218  }
219
220  template <typename Distance>
221  void NCC<Distance>::predict(const DataLookup2D& test,                     
222                              utility::matrix& prediction) const
223  {   
224    utility::yat_assert<std::runtime_error>
225      (centroids_,"NCC::predict called for untrained classifier");
226    utility::yat_assert<std::runtime_error>
227      (data_.rows()==test.rows(),
228       "NCC::predict test data with incorrect number of rows");
229   
230    prediction.clone(utility::matrix(centroids_->columns(), test.columns()));       
231
232    // unweighted test data
233    if (const MatrixLookup* test_unweighted =
234        dynamic_cast<const MatrixLookup*>(&test)) {
235      MatrixLookup unweighted_centroids(*centroids_);
236      for(size_t j=0; j<test.columns();j++) {       
237        DataLookup1D in(*test_unweighted,j,false);
238        for(size_t k=0; k<centroids_->columns();k++) {
239          DataLookup1D centroid(unweighted_centroids,k,false);           
240          utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
241          prediction(k,j)=statistics::
242            vector_distance(in.begin(),in.end(),centroid.begin(),
243                            typename statistics::vector_distance_traits<Distance>::distance());
244        }
245      }
246    }
247    // weighted test data
248    else if (const MatrixLookupWeighted* test_weighted =
249            dynamic_cast<const MatrixLookupWeighted*>(&test)) { 
250      MatrixLookupWeighted weighted_centroids(*centroids_);
251      for(size_t j=0; j<test.columns();j++) {       
252        DataLookupWeighted1D in(*test_weighted,j,false);
253        for(size_t k=0; k<centroids_->columns();k++) {
254          DataLookupWeighted1D centroid(weighted_centroids,k,false);
255          utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
256          prediction(k,j)=statistics::
257            vector_distance(in.begin(),in.end(),centroid.begin(),
258                            typename statistics::vector_distance_traits<Distance>::distance());
259        }
260      }
261    }
262    else {
263      std::string str = 
264        "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
265      throw std::runtime_error(str);
266    }
267  }
268     
269}}} // of namespace classifier, yat, and theplu
270
271#endif
Note: See TracBrowser for help on using the repository browser.