Ignore:
Timestamp:
Sep 26, 2007, 3:44:19 PM (14 years ago)
Author:
Markus Ringnér
Message:

A first suggestion for how to adress #250. Also removed contamination of namespace std (see #251).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NCC.cc

    r874 r898  
    2727#include "DataLookup1D.h"
    2828#include "DataLookup2D.h"
     29#include "DataLookupWeighted1D.h"
    2930#include "MatrixLookup.h"
    3031#include "MatrixLookupWeighted.h"
    3132#include "Target.h"
     33#include "yat/statistics/vector_distance.h"
     34#include "yat/statistics/euclidean_vector_distance.h"
     35#include "yat/utility/Iterator.h"
     36#include "yat/utility/IteratorWeighted.h"
    3237#include "yat/utility/matrix.h"
    3338#include "yat/utility/vector.h"
    34 #include "yat/statistics/Distance.h"
    3539#include "yat/utility/stl_utility.h"
    3640
     
    4549
    4650  NCC::NCC(const MatrixLookup& data, const Target& target,
    47            const statistics::Distance& distance)
     51           const statistics::vector_distance_lookup_weighted_ptr distance)
    4852    : SupervisedClassifier(target), distance_(distance), data_(data)
    4953  {
     
    5155
    5256  NCC::NCC(const MatrixLookupWeighted& data, const Target& target,
    53            const statistics::Distance& distance)
     57           const statistics::vector_distance_lookup_weighted_ptr distance)
    5458    : SupervisedClassifier(target), distance_(distance), data_(data)
    5559  {
     
    6569    return centroids_;
    6670  }
     71 
    6772
    68     const DataLookup2D& NCC::data(void) const
    69     {
     73  const DataLookup2D& NCC::data(void) const
     74  {
    7075    return data_;
    71     }
    72 
     76  }
     77 
    7378  SupervisedClassifier*
    7479  NCC::make_classifier(const DataLookup2D& data, const Target& target) const
     
    109114  }
    110115
    111 
    112   void NCC::predict(const utility::vector& input, const utility::vector& weights,
    113                     utility::vector& prediction) const
    114   {
    115     prediction.clone(utility::vector(centroids_.columns()));
    116    
    117     // take care of nan's in centroids
    118     for(size_t j=0; j<centroids_.columns(); j++) {
    119       const utility::vector centroid(utility::vector(centroids_,j,false));
    120       utility::vector wc(centroid.size(),0);
    121       for(size_t i=0; i<centroid.size(); i++)  {
    122         if(!std::isnan(centroid(i)))
    123           wc(i)=1.0;
    124       }
    125       prediction(j)=distance_(input,centroid,weights,wc);   
    126     }
    127   }
    128 
    129 
    130116  void NCC::predict(const DataLookup2D& input,                   
    131117                    utility::matrix& prediction) const
    132118  {   
    133     prediction.clone(utility::matrix(centroids_.columns(), input.columns()));
    134     // weighted case
    135     const MatrixLookupWeighted* data =
    136       dynamic_cast<const MatrixLookupWeighted*>(&input); 
    137     if (data) {
    138       for(size_t j=0; j<input.columns();j++) {     
    139         utility::vector in(input.rows(),0);
    140         for(size_t i=0; i<in.size();i++)
    141           in(i)=data->data(i,j);
    142         utility::vector weights(in.size(),0);
    143         for(size_t i=0; i<in.size();i++)
    144           weights(i)=data->weight(i,j);
    145         utility::vector out;
    146         predict(in,weights,out);
    147         prediction.column(j,out);
     119    prediction.clone(utility::matrix(centroids_.columns(), input.columns()));   
     120
     121    // Weighted case
     122    const MatrixLookupWeighted* testdata =
     123      dynamic_cast<const MatrixLookupWeighted*>(&input);     
     124    if (testdata) {
     125      utility::matrix centroid_weights;
     126      utility::nan(centroids_,centroid_weights);
     127      MatrixLookupWeighted weighted_centroids(centroids_,centroid_weights);
     128      for(size_t j=0; j<input.columns();j++) {       
     129        DataLookupWeighted1D in(*testdata,j,false);
     130        for(size_t k=0; k<centroids_.columns();k++) {
     131          DataLookupWeighted1D centroid(weighted_centroids,k,false);
     132          prediction(k,j)=(*distance_)(in.begin(),in.end(),centroid.begin());
     133        }
    148134      }
    149       return;
    150135    }
    151     // non-weighted case
    152     const MatrixLookup* x = dynamic_cast<const MatrixLookup*>(&input);
    153     if (!x){
     136    else {
    154137      std::string str;
    155138      str = "Error in NCC::predict: DataLookup2D of unexpected class.";
    156139      throw std::runtime_error(str);
    157140    }
    158     for(size_t j=0; j<input.columns();j++) {     
    159       utility::vector in(input.rows(),0);
    160       for(size_t i=0; i<in.size();i++)
    161         in(i)=(*data)(i,j);
    162       utility::vector weights(in.size(),1.0);
    163       utility::vector out;
    164       predict(in,weights,out);
    165       prediction.column(j,out);
    166     }
    167141  }
    168 
    169  
    170   // additional operators
    171 
    172 //  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
    173 //    std::copy(ncc.classes().begin(), ncc.classes().end(),
    174 //              std::ostream_iterator<std::map<double, u_int>::value_type>
    175 //              (s, "\n"));
    176 //    s << "\n" << ncc.centroids() << "\n";
    177 //    return s;
    178 //  }
    179 
     142   
    180143}}} // of namespace classifier, yat, and theplu
Note: See TracChangeset for help on using the changeset viewer.