source: trunk/yat/classifier/NCC.cc @ 898

Last change on this file since 898 was 898, checked in by Markus Ringnér, 15 years ago

A first suggestion for how to address #250. Also removed contamination of namespace std (see #251).

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date ID
File size: 4.0 KB
Line 
1// $Id$
2
3/*
4  Copyright (C) 2005 Markus Ringnér, Peter Johansson
5  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
6  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
7
8  This file is part of the yat library, http://trac.thep.lu.se/trac/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "NCC.h"
27#include "DataLookup1D.h"
28#include "DataLookup2D.h"
29#include "DataLookupWeighted1D.h"
30#include "MatrixLookup.h"
31#include "MatrixLookupWeighted.h"
32#include "Target.h"
33#include "yat/statistics/vector_distance.h"
34#include "yat/statistics/euclidean_vector_distance.h"
35#include "yat/utility/Iterator.h"
36#include "yat/utility/IteratorWeighted.h"
37#include "yat/utility/matrix.h"
38#include "yat/utility/vector.h"
39#include "yat/utility/stl_utility.h"
40
#include <cmath>
#include <iostream>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>
45
46namespace theplu {
47namespace yat {
48namespace classifier {
49
50  NCC::NCC(const MatrixLookup& data, const Target& target, 
51           const statistics::vector_distance_lookup_weighted_ptr distance) 
52    : SupervisedClassifier(target), distance_(distance), data_(data)
53  {
54  }
55
56  NCC::NCC(const MatrixLookupWeighted& data, const Target& target, 
57           const statistics::vector_distance_lookup_weighted_ptr distance) 
58    : SupervisedClassifier(target), distance_(distance), data_(data)
59  {
60  }
61
62  NCC::~NCC()   
63  {
64  }
65
66
67  const utility::matrix& NCC::centroids(void) const
68  {
69    return centroids_;
70  }
71 
72
73  const DataLookup2D& NCC::data(void) const
74  {
75    return data_;
76  }
77 
78  SupervisedClassifier* 
79  NCC::make_classifier(const DataLookup2D& data, const Target& target) const 
80  {     
81    NCC* ncc=0;
82    if(data.weighted()) {
83      ncc=new NCC(dynamic_cast<const MatrixLookupWeighted&>(data),
84                  target,this->distance_);
85    }
86    else {
87      ncc=new NCC(dynamic_cast<const MatrixLookup&>(data),
88                  target,this->distance_);
89    }
90    return ncc;
91  }
92
93
94  bool NCC::train()
95  {   
96    centroids_.clone(utility::matrix(data_.rows(), target_.nof_classes()));
97    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
98    const MatrixLookupWeighted* weighted_data = 
99      dynamic_cast<const MatrixLookupWeighted*>(&data_);
100    bool weighted = weighted_data;
101
102    for(size_t i=0; i<data_.rows(); i++) {
103      for(size_t j=0; j<data_.columns(); j++) {
104        centroids_(i,target_(j)) += data_(i,j);
105        if (weighted)
106          nof_in_class(i,target_(j))+= weighted_data->weight(i,j);
107        else
108          nof_in_class(i,target_(j))+=1.0;
109      }
110    }   
111    centroids_.div(nof_in_class);
112    trained_=true;
113    return trained_;
114  }
115
116  void NCC::predict(const DataLookup2D& input,                   
117                    utility::matrix& prediction) const
118  {   
119    prediction.clone(utility::matrix(centroids_.columns(), input.columns()));   
120
121    // Weighted case
122    const MatrixLookupWeighted* testdata =
123      dynamic_cast<const MatrixLookupWeighted*>(&input);     
124    if (testdata) {
125      utility::matrix centroid_weights;
126      utility::nan(centroids_,centroid_weights);
127      MatrixLookupWeighted weighted_centroids(centroids_,centroid_weights);
128      for(size_t j=0; j<input.columns();j++) {       
129        DataLookupWeighted1D in(*testdata,j,false);
130        for(size_t k=0; k<centroids_.columns();k++) {
131          DataLookupWeighted1D centroid(weighted_centroids,k,false);
132          prediction(k,j)=(*distance_)(in.begin(),in.end(),centroid.begin());
133        }
134      }
135    }
136    else {
137      std::string str;
138      str = "Error in NCC::predict: DataLookup2D of unexpected class.";
139      throw std::runtime_error(str);
140    }
141  }
142   
143}}} // of namespace classifier, yat, and theplu
Note: See TracBrowser for help on using the repository browser.