source: trunk/yat/classifier/NCC.cc @ 685

Last change on this file since 685 was 685, checked in by Jari Häkkinen, 15 years ago

Changing a vector view to be a const view as it should be.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 4.5 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "NCC.h"
25#include "DataLookup1D.h"
26#include "DataLookup2D.h"
27#include "MatrixLookup.h"
28#include "MatrixLookupWeighted.h"
29#include "Target.h"
30#include "yat/utility/matrix.h"
31#include "yat/utility/vector.h"
32#include "yat/statistics/Distance.h"
33#include "yat/utility/stl_utility.h"
34
35#include<iostream>
36#include<iterator>
37#include <map>
38#include <cmath>
39
40namespace theplu {
41namespace yat {
42namespace classifier {
43
44  NCC::NCC(const MatrixLookup& data, const Target& target, 
45           const statistics::Distance& distance) 
46    : SupervisedClassifier(target), distance_(distance), data_(data)
47  {
48  }
49
50  NCC::NCC(const MatrixLookupWeighted& data, const Target& target, 
51           const statistics::Distance& distance) 
52    : SupervisedClassifier(target), distance_(distance), data_(data)
53  {
54  }
55
56  NCC::~NCC()   
57  {
58  }
59
60
61  SupervisedClassifier* 
62  NCC::make_classifier(const DataLookup2D& data, const Target& target) const 
63  {     
64    NCC* ncc=0;
65    if(data.weighted()) {
66      ncc=new NCC(dynamic_cast<const MatrixLookupWeighted&>(data),
67                  target,this->distance_);
68    }
69    else {
70      ncc=new NCC(dynamic_cast<const MatrixLookup&>(data),
71                  target,this->distance_);
72    }
73    return ncc;
74  }
75
76
77  bool NCC::train()
78  {   
79    centroids_=utility::matrix(data_.rows(), target_.nof_classes());
80    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
81    for(size_t i=0; i<data_.rows(); i++) {
82      for(size_t j=0; j<data_.columns(); j++) {
83        centroids_(i,target_(j)) += data_(i,j);
84        try {
85          nof_in_class(i,target_(j))+=
86            dynamic_cast<const MatrixLookupWeighted&>(data_).weight(i,j);
87        }
88        catch (std::bad_cast) {
89          nof_in_class(i,target_(j))+=1.0;
90        }
91      }
92    }   
93    centroids_.div_elements(nof_in_class);
94    trained_=true;
95    return trained_;
96  }
97
98
99  void NCC::predict(const DataLookup1D& input, const utility::vector& weights,
100                    utility::vector& prediction) const
101  {
102    prediction=utility::vector(centroids_.columns());   
103
104    utility::vector value(input.size(),0);
105    for(size_t i=0; i<input.size(); i++)
106      value(i)=input(i);
107   
108    // take care of nan's in centroids
109    for(size_t j=0; j<centroids_.columns(); j++) {
110      const utility::vector centroid=utility::vector(centroids_,j,false);
111      utility::vector wc(centroid.size(),0);
112      for(size_t i=0; i<centroid.size(); i++)  { 
113        if(!std::isnan(centroid(i)))
114          wc(i)=1.0;
115      }
116      prediction(j)=distance_(value,centroid,weights,wc);   
117    }
118  }
119
120
121  void NCC::predict(const DataLookup2D& input,                   
122                    utility::matrix& prediction) const
123  {   
124    prediction=utility::matrix(centroids_.columns(), input.columns());   
125    try {   
126      const MatrixLookupWeighted& data=
127        dynamic_cast<const MatrixLookupWeighted&>(input);     
128      for(size_t j=0; j<input.columns();j++) {     
129        DataLookup1D in(input,j,false);
130        utility::vector weights(in.size(),0);
131        for(size_t i=0; i<in.size();i++) 
132          weights(i)=data.weight(i,j);
133        utility::vector out;
134        predict(in,weights,out);
135        prediction.set_column(j,out);
136      }
137    }
138    catch (std::bad_cast) {
139      try {
140        dynamic_cast<const MatrixLookup&>(input);
141        for(size_t j=0; j<input.columns();j++) {     
142          DataLookup1D in(input,j,false);
143          utility::vector weights(in.size(),1.0);
144          utility::vector out;
145          predict(in,weights,out);
146          prediction.set_column(j,out);
147        }
148      }
149      catch (std::bad_cast e) {       
150        std::cerr << "Error in NCC::predict: DataLookup2D of unexpected class. "
151                  <<  "bad_cast: " << e.what() << std::endl;
152      }
153    }
154  }
155
156 
157  // additional operators
158
159//  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
160//    std::copy(ncc.classes().begin(), ncc.classes().end(),
161//              std::ostream_iterator<std::map<double, u_int>::value_type>
162//              (s, "\n"));
163//    s << "\n" << ncc.centroids() << "\n";
164//    return s;
165//  }
166
167}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.