source: trunk/yat/classifier/NCC.cc @ 789

Last change on this file since 789 was 789, checked in by Jari Häkkinen, 16 years ago

Addresses #193. vector now works as outlined here. Added some
functionality. Added a clone function that facilitates resizing of
vectors. clone is needed since assignment operator functionality is
changed.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 4.7 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) The authors contributing to this file.
5
6  This file is part of the yat library, http://lev.thep.lu.se/trac/yat
7
8  The yat library is free software; you can redistribute it and/or
9  modify it under the terms of the GNU General Public License as
10  published by the Free Software Foundation; either version 2 of the
11  License, or (at your option) any later version.
12
13  The yat library is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16  General Public License for more details.
17
18  You should have received a copy of the GNU General Public License
19  along with this program; if not, write to the Free Software
20  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
21  02111-1307, USA.
22*/
23
24#include "NCC.h"
25#include "DataLookup1D.h"
26#include "DataLookup2D.h"
27#include "MatrixLookup.h"
28#include "MatrixLookupWeighted.h"
29#include "Target.h"
30#include "yat/utility/matrix.h"
31#include "yat/utility/vector.h"
32#include "yat/statistics/Distance.h"
33#include "yat/utility/stl_utility.h"
34
35#include<iostream>
36#include<iterator>
37#include <map>
38#include <cmath>
39
40namespace theplu {
41namespace yat {
42namespace classifier {
43
44  NCC::NCC(const MatrixLookup& data, const Target& target, 
45           const statistics::Distance& distance) 
46    : SupervisedClassifier(target), distance_(distance), data_(data)
47  {
48  }
49
50  NCC::NCC(const MatrixLookupWeighted& data, const Target& target, 
51           const statistics::Distance& distance) 
52    : SupervisedClassifier(target), distance_(distance), data_(data)
53  {
54  }
55
56  NCC::~NCC()   
57  {
58  }
59
60
61  const utility::matrix& NCC::centroids(void) const
62  {
63    return centroids_;
64  }
65
66    const DataLookup2D& NCC::data(void) const
67    {
68    return data_;
69    }
70
71  SupervisedClassifier* 
72  NCC::make_classifier(const DataLookup2D& data, const Target& target) const 
73  {     
74    NCC* ncc=0;
75    if(data.weighted()) {
76      ncc=new NCC(dynamic_cast<const MatrixLookupWeighted&>(data),
77                  target,this->distance_);
78    }
79    else {
80      ncc=new NCC(dynamic_cast<const MatrixLookup&>(data),
81                  target,this->distance_);
82    }
83    return ncc;
84  }
85
86
87  bool NCC::train()
88  {   
89    centroids_=utility::matrix(data_.rows(), target_.nof_classes());
90    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
91    for(size_t i=0; i<data_.rows(); i++) {
92      for(size_t j=0; j<data_.columns(); j++) {
93        centroids_(i,target_(j)) += data_(i,j);
94        try {
95          nof_in_class(i,target_(j))+=
96            dynamic_cast<const MatrixLookupWeighted&>(data_).weight(i,j);
97        }
98        catch (std::bad_cast) {
99          nof_in_class(i,target_(j))+=1.0;
100        }
101      }
102    }   
103    centroids_.div_elements(nof_in_class);
104    trained_=true;
105    return trained_;
106  }
107
108
109  void NCC::predict(const DataLookup1D& input, const utility::vector& weights,
110                    utility::vector& prediction) const
111  {
112    prediction.clone(utility::vector(centroids_.columns()));
113
114    utility::vector value(input.size(),0);
115    for(size_t i=0; i<input.size(); i++)
116      value(i)=input(i);
117   
118    // take care of nan's in centroids
119    for(size_t j=0; j<centroids_.columns(); j++) {
120      const utility::vector centroid=utility::vector(centroids_,j,false);
121      utility::vector wc(centroid.size(),0);
122      for(size_t i=0; i<centroid.size(); i++)  { 
123        if(!std::isnan(centroid(i)))
124          wc(i)=1.0;
125      }
126      prediction(j)=distance_(value,centroid,weights,wc);   
127    }
128  }
129
130
131  void NCC::predict(const DataLookup2D& input,                   
132                    utility::matrix& prediction) const
133  {   
134    prediction=utility::matrix(centroids_.columns(), input.columns());   
135    try {   
136      const MatrixLookupWeighted& data=
137        dynamic_cast<const MatrixLookupWeighted&>(input);     
138      for(size_t j=0; j<input.columns();j++) {     
139        DataLookup1D in(input,j,false);
140        utility::vector weights(in.size(),0);
141        for(size_t i=0; i<in.size();i++) 
142          weights(i)=data.weight(i,j);
143        utility::vector out;
144        predict(in,weights,out);
145        prediction.set_column(j,out);
146      }
147    }
148    catch (std::bad_cast) {
149      try {
150        dynamic_cast<const MatrixLookup&>(input);
151        for(size_t j=0; j<input.columns();j++) {     
152          DataLookup1D in(input,j,false);
153          utility::vector weights(in.size(),1.0);
154          utility::vector out;
155          predict(in,weights,out);
156          prediction.set_column(j,out);
157        }
158      }
159      catch (std::bad_cast e) {       
160        std::cerr << "Error in NCC::predict: DataLookup2D of unexpected class. "
161                  <<  "bad_cast: " << e.what() << std::endl;
162      }
163    }
164  }
165
166 
167  // additional operators
168
169//  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
170//    std::copy(ncc.classes().begin(), ncc.classes().end(),
171//              std::ostream_iterator<std::map<double, u_int>::value_type>
172//              (s, "\n"));
173//    s << "\n" << ncc.centroids() << "\n";
174//    return s;
175//  }
176
177}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.