source: trunk/yat/classifier/NCC.h @ 1162

Last change on this file since 1162 was 1162, checked in by Peter, 14 years ago

removing trained_

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 7.1 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér, Peter Johansson
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
29#include "MatrixLookup.h"
30#include "MatrixLookupWeighted.h"
31#include "SupervisedClassifier.h"
32#include "Target.h"
33
34#include "yat/statistics/Averager.h"
35#include "yat/statistics/AveragerWeighted.h"
36#include "yat/utility/Matrix.h"
37#include "yat/utility/Vector.h"
38#include "yat/utility/stl_utility.h"
39#include "yat/utility/yat_assert.h"
40
41#include<iostream>
42#include<iterator>
43#include <map>
44#include <cmath>
45#include <stdexcept>
46
47namespace theplu {
48namespace yat {
49namespace classifier { 
50
51
52  ///
53  /// @brief Class for Nearest Centroid Classification.
54  ///
55  /// The template argument Distance should be a class modelling
56  /// the concept \ref concept_distance.
57  ///
58  template <typename Distance>
59  class NCC : public SupervisedClassifier
60  {
61 
62  public:
63    ///
64    /// @brief Constructor
65    ///
66    NCC(void);
67   
68    ///
69    /// @brief Constructor
70    ///
71    NCC(const Distance&);
72
73
74    ///
75    /// @brief Destructor
76    ///
77    virtual ~NCC(void);
78
79    ///
80    /// @return the centroids for each class as columns in a matrix.
81    ///
82    const utility::Matrix& centroids(void) const;
83
84    NCC<Distance>* make_classifier(void) const;
85   
86    ///
87    /// Train the classifier with a training data set and
88    /// targets. Centroids are calculated for each class.
89    ///
90    void train(const MatrixLookup&, const Target&);
91
92
93    ///
94    /// Train the classifier with a weighted training data set and
95    /// targets. Centroids are calculated for each class.
96    ///
97    void train(const MatrixLookupWeighted&, const Target&);
98
99   
100    ///
101    /// Calculate the distance to each centroid for test samples
102    ///
103    void predict(const MatrixLookup&, utility::Matrix&) const;
104   
105    ///
106    /// Calculate the distance to each centroid for weighted test samples
107    ///
108    void predict(const MatrixLookupWeighted&, utility::Matrix&) const;
109
110   
111  private:
112
113    void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
114    void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;   
115
116    utility::Matrix* centroids_;
117    bool centroids_nan_;
118    Distance distance_;
119  };
120
  //
  // The output operator for the NCC class (currently commented out,
  // kept here for reference only).
  //
  //  std::ostream& operator<< (std::ostream&, const NCC&);
125 
126
127  // templates
128
129  template <typename Distance>
130  NCC<Distance>::NCC() 
131    : SupervisedClassifier(), centroids_(0), centroids_nan_(false)
132  {
133  }
134
135  template <typename Distance>
136  NCC<Distance>::NCC(const Distance& dist) 
137    : SupervisedClassifier(), centroids_(0), centroids_nan_(false), distance_(dist)
138  {
139  }
140
141
142  template <typename Distance>
143  NCC<Distance>::~NCC()   
144  {
145    if(centroids_)
146      delete centroids_;
147  }
148
149
150  template <typename Distance>
151  const utility::Matrix& NCC<Distance>::centroids(void) const
152  {
153    return *centroids_;
154  }
155 
156
157  template <typename Distance>
158  NCC<Distance>* 
159  NCC<Distance>::make_classifier() const 
160  {     
161    return new NCC<Distance>();
162  }
163
164  template <typename Distance>
165  void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
166  {   
167    if(centroids_) 
168      delete centroids_;
169    centroids_= new utility::Matrix(data.rows(), target.nof_classes());
170    for(size_t i=0; i<data.rows(); i++) {
171      std::vector<statistics::Averager> class_averager;
172      class_averager.resize(target.nof_classes());
173      for(size_t j=0; j<data.columns(); j++) {
174        class_averager[target(j)].add(data(i,j));
175      }
176      for(size_t c=0;c<target.nof_classes();c++) {
177        (*centroids_)(i,c) = class_averager[c].mean();
178      }
179    }
180  }
181
182
183  template <typename Distance>
184  void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
185  {   
186    if(centroids_) 
187      delete centroids_;
188    centroids_= new utility::Matrix(data.rows(), target.nof_classes());
189    for(size_t i=0; i<data.rows(); i++) {
190      std::vector<statistics::AveragerWeighted> class_averager;
191      class_averager.resize(target.nof_classes());
192      for(size_t j=0; j<data.columns(); j++) 
193        class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
194      for(size_t c=0;c<target.nof_classes();c++) {
195        if(class_averager[c].sum_w()==0) {
196          centroids_nan_=true;
197        }
198        (*centroids_)(i,c) = class_averager[c].mean();
199      }
200    }
201  }
202
203
204  template <typename Distance>
205  void NCC<Distance>::predict(const MatrixLookup& test,                     
206                              utility::Matrix& prediction) const
207  {   
208    utility::yat_assert<std::runtime_error>
209      (centroids_,"NCC::predict called for untrained classifier");
210    utility::yat_assert<std::runtime_error>
211      (centroids_->rows()==test.rows(),
212       "NCC::predict test data with incorrect number of rows");
213   
214    prediction.resize(centroids_->columns(), test.columns());
215
216    // If weighted training data has resulted in NaN in centroids: weighted calculations
217    if(centroids_nan_) { 
218      predict_weighted(MatrixLookupWeighted(test),prediction);
219    }
220    // If unweighted training data: unweighted calculations
221    else {
222      predict_unweighted(test,prediction);
223    }
224  }
225
226  template <typename Distance>
227  void NCC<Distance>::predict(const MatrixLookupWeighted& test,                     
228                              utility::Matrix& prediction) const
229  {   
230    utility::yat_assert<std::runtime_error>
231      (centroids_,"NCC::predict called for untrained classifier");
232    utility::yat_assert<std::runtime_error>
233      (centroids_->rows()==test.rows(),
234       "NCC::predict test data with incorrect number of rows");
235   
236    prediction.resize(centroids_->columns(), test.columns());
237    predict_weighted(test,prediction);
238  }
239
240 
241  template <typename Distance>
242  void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 
243                                         utility::Matrix& prediction) const
244  {
245    MatrixLookup centroids(*centroids_);
246    for(size_t j=0; j<test.columns();j++)
247      for(size_t k=0; k<centroids_->columns();k++) 
248        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
249                                    centroids.begin_column(k));
250  }
251 
252  template <typename Distance>
253  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 
254                                          utility::Matrix& prediction) const
255  {
256    MatrixLookupWeighted weighted_centroids(*centroids_);
257    for(size_t j=0; j<test.columns();j++) 
258      for(size_t k=0; k<centroids_->columns();k++)
259        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
260                                    weighted_centroids.begin_column(k));
261  }
262
263     
264}}} // of namespace classifier, yat, and theplu
265
266#endif
Note: See TracBrowser for help on using the repository browser.