source: trunk/yat/classifier/NCC.h @ 930

Last change on this file since 930 was 930, checked in by Markus Ringnér, 15 years ago

Fixed support for MatrixLookup in NCC. See ticket:259

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 6.4 KB
Line 
1#ifndef _theplu_yat_classifier_ncc_
2#define _theplu_yat_classifier_ncc_
3
4// $Id$
5
6/*
7  Copyright (C) 2005 Markus Ringnér, Peter Johansson
8  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
9  Copyright (C) 2007 Peter Johansson
10
11  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
12
13  The yat library is free software; you can redistribute it and/or
14  modify it under the terms of the GNU General Public License as
15  published by the Free Software Foundation; either version 2 of the
16  License, or (at your option) any later version.
17
18  The yat library is distributed in the hope that it will be useful,
19  but WITHOUT ANY WARRANTY; without even the implied warranty of
20  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
21  General Public License for more details.
22
23  You should have received a copy of the GNU General Public License
24  along with this program; if not, write to the Free Software
25  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
26  02111-1307, USA.
27*/
28
#include "DataLookup1D.h"
#include "DataLookup2D.h"
#include "DataLookupWeighted1D.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"

#include "yat/statistics/vector_distance.h"

#include "yat/utility/Iterator.h"
#include "yat/utility/IteratorWeighted.h"
#include "yat/utility/matrix.h"
#include "yat/utility/vector.h"
#include "yat/utility/stl_utility.h"
#include "yat/utility/yat_assert.h"

#include <cmath>
#include <iostream>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>
50
51
52namespace theplu {
53namespace yat {
54namespace classifier { 
55
56
57  ///
58  /// @brief Class for Nearest Centroid Classification.
59  ///
60
61  template <typename Distance>
62  class NCC : public SupervisedClassifier
63  {
64 
65  public:
66    ///
67    /// Constructor taking the training data and the target vector as
68    /// input
69    ///
70    NCC(const MatrixLookup&, const Target&);
71   
72    ///
73    /// Constructor taking the training data with weights and the
74    /// target vector as input.
75    ///
76    NCC(const MatrixLookupWeighted&, const Target&);
77
78    virtual ~NCC();
79
80    ///
81    /// @return the centroids for each class as columns in a matrix.
82    ///
83    const utility::matrix& centroids(void) const;
84
85    const DataLookup2D& data(void) const;
86
87    SupervisedClassifier* make_classifier(const DataLookup2D&, 
88                                          const Target&) const;
89   
90    ///
91    /// Train the classifier using the training data. Centroids are
92    /// calculated for each class.
93    ///
94    /// @return true if training succedeed.
95    ///
96    bool train();
97
98   
99    ///
100    /// Calculate the distance to each centroid for test samples
101    ///
102    void predict(const DataLookup2D&, utility::matrix&) const;
103   
104   
105  private:
106
107    utility::matrix* centroids_;
108
109    // data_ has to be of type DataLookup2D to accomodate both
110    // MatrixLookup and MatrixLookupWeighted
111    const DataLookup2D& data_;
112
113  };
114
115  ///
116  /// The output operator for the NCC class.
117  ///
118  //  std::ostream& operator<< (std::ostream&, const NCC&);
119 
120
121  // templates
122
123  template <typename Distance>
124  NCC<Distance>::NCC(const MatrixLookup& data, const Target& target) 
125    : SupervisedClassifier(target), centroids_(0), data_(data) 
126  {
127  }
128
129  template <typename Distance>
130  NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
131    : SupervisedClassifier(target), centroids_(0), data_(data) 
132  {
133  }
134
135  template <typename Distance>
136  NCC<Distance>::~NCC()   
137  {
138    if(centroids_)
139      delete centroids_;
140  }
141
142  template <typename Distance>
143  const utility::matrix& NCC<Distance>::centroids(void) const
144  {
145    return *centroids_;
146  }
147 
148
149  template <typename Distance>
150  const DataLookup2D& NCC<Distance>::data(void) const
151  {
152    return data_;
153  }
154 
155  template <typename Distance>
156  SupervisedClassifier* 
157  NCC<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const 
158  {     
159    NCC* ncc=0;
160    if(data.weighted()) {
161      ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookupWeighted*>(&data),
162                  target);
163    }
164    else {
165      ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookup*>(&data),
166                  target);
167    }
168    ncc->centroids_=0;
169    return ncc;
170  }
171
172
173  template <typename Distance>
174  bool NCC<Distance>::train()
175  {   
176    if(centroids_) 
177      delete centroids_;
178    centroids_= new utility::matrix(data_.rows(), target_.nof_classes());
179    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
180    const MatrixLookupWeighted* weighted_data = 
181      dynamic_cast<const MatrixLookupWeighted*>(&data_);
182    bool weighted = weighted_data;
183
184    for(size_t i=0; i<data_.rows(); i++) {
185      for(size_t j=0; j<data_.columns(); j++) {
186        (*centroids_)(i,target_(j)) += data_(i,j);
187        if (weighted)
188          nof_in_class(i,target_(j))+= weighted_data->weight(i,j);
189        else
190          nof_in_class(i,target_(j))+=1.0;
191      }
192    }   
193    centroids_->div(nof_in_class);
194    trained_=true;
195    return trained_;
196  }
197
198  template <typename Distance>
199  void NCC<Distance>::predict(const DataLookup2D& input,                   
200                              utility::matrix& prediction) const
201  {   
202    prediction.clone(utility::matrix(centroids_->columns(), input.columns()));   
203   
204    // Weighted case
205    const MatrixLookupWeighted* testdata =
206      dynamic_cast<const MatrixLookupWeighted*>(&input);     
207    if (testdata) {
208      MatrixLookupWeighted weighted_centroids(*centroids_);
209      for(size_t j=0; j<input.columns();j++) {       
210        DataLookupWeighted1D in(*testdata,j,false);
211        for(size_t k=0; k<centroids_->columns();k++) {
212          DataLookupWeighted1D centroid(weighted_centroids,k,false);
213          yat_assert(in.size()==centroid.size());
214          prediction(k,j)=statistics::
215            vector_distance(in.begin(),in.end(),centroid.begin(),
216                            typename statistics::vector_distance_traits<Distance>::distance());
217        }
218      }
219    }
220    // Non-weighted case
221    else {
222      const MatrixLookup* testdata =
223        dynamic_cast<const MatrixLookup*>(&input);     
224      if (testdata) {
225        MatrixLookup unweighted_centroids(*centroids_);
226        for(size_t j=0; j<input.columns();j++) {       
227          DataLookup1D in(*testdata,j,false);
228          for(size_t k=0; k<centroids_->columns();k++) {
229            DataLookup1D centroid(unweighted_centroids,k,false);           
230            yat_assert(in.size()==centroid.size());
231            prediction(k,j)=statistics::
232              vector_distance(in.begin(),in.end(),centroid.begin(),
233                              typename statistics::vector_distance_traits<Distance>::distance());
234          }
235        }
236      }     
237      else {
238        std::string str;
239        str = "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
240        throw std::runtime_error(str);
241      }
242    }
243  }
244     
245}}} // of namespace classifier, yat, and theplu
246
247#endif
Note: See TracBrowser for help on using the repository browser.