Changeset 898 for trunk/yat/classifier


Ignore:
Timestamp:
Sep 26, 2007, 3:44:19 PM (14 years ago)
Author:
Markus Ringnér
Message:

A first suggestion for how to adress #250. Also removed contamination of namespace std (see #251).

Location:
trunk/yat/classifier
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NCC.cc

    r874 r898  
    2727#include "DataLookup1D.h"
    2828#include "DataLookup2D.h"
     29#include "DataLookupWeighted1D.h"
    2930#include "MatrixLookup.h"
    3031#include "MatrixLookupWeighted.h"
    3132#include "Target.h"
     33#include "yat/statistics/vector_distance.h"
     34#include "yat/statistics/euclidean_vector_distance.h"
     35#include "yat/utility/Iterator.h"
     36#include "yat/utility/IteratorWeighted.h"
    3237#include "yat/utility/matrix.h"
    3338#include "yat/utility/vector.h"
    34 #include "yat/statistics/Distance.h"
    3539#include "yat/utility/stl_utility.h"
    3640
     
    4549
    4650  NCC::NCC(const MatrixLookup& data, const Target& target,
    47            const statistics::Distance& distance)
     51           const statistics::vector_distance_lookup_weighted_ptr distance)
    4852    : SupervisedClassifier(target), distance_(distance), data_(data)
    4953  {
     
    5155
    5256  NCC::NCC(const MatrixLookupWeighted& data, const Target& target,
    53            const statistics::Distance& distance)
     57           const statistics::vector_distance_lookup_weighted_ptr distance)
    5458    : SupervisedClassifier(target), distance_(distance), data_(data)
    5559  {
     
    6569    return centroids_;
    6670  }
     71 
    6772
    68     const DataLookup2D& NCC::data(void) const
    69     {
     73  const DataLookup2D& NCC::data(void) const
     74  {
    7075    return data_;
    71     }
    72 
     76  }
     77 
    7378  SupervisedClassifier*
    7479  NCC::make_classifier(const DataLookup2D& data, const Target& target) const
     
    109114  }
    110115
    111 
    112   void NCC::predict(const utility::vector& input, const utility::vector& weights,
    113                     utility::vector& prediction) const
    114   {
    115     prediction.clone(utility::vector(centroids_.columns()));
    116    
    117     // take care of nan's in centroids
    118     for(size_t j=0; j<centroids_.columns(); j++) {
    119       const utility::vector centroid(utility::vector(centroids_,j,false));
    120       utility::vector wc(centroid.size(),0);
    121       for(size_t i=0; i<centroid.size(); i++)  {
    122         if(!std::isnan(centroid(i)))
    123           wc(i)=1.0;
    124       }
    125       prediction(j)=distance_(input,centroid,weights,wc);   
    126     }
    127   }
    128 
    129 
    130116  void NCC::predict(const DataLookup2D& input,                   
    131117                    utility::matrix& prediction) const
    132118  {   
    133     prediction.clone(utility::matrix(centroids_.columns(), input.columns()));
    134     // weighted case
    135     const MatrixLookupWeighted* data =
    136       dynamic_cast<const MatrixLookupWeighted*>(&input); 
    137     if (data) {
    138       for(size_t j=0; j<input.columns();j++) {     
    139         utility::vector in(input.rows(),0);
    140         for(size_t i=0; i<in.size();i++)
    141           in(i)=data->data(i,j);
    142         utility::vector weights(in.size(),0);
    143         for(size_t i=0; i<in.size();i++)
    144           weights(i)=data->weight(i,j);
    145         utility::vector out;
    146         predict(in,weights,out);
    147         prediction.column(j,out);
     119    prediction.clone(utility::matrix(centroids_.columns(), input.columns()));   
     120
     121    // Weighted case
     122    const MatrixLookupWeighted* testdata =
     123      dynamic_cast<const MatrixLookupWeighted*>(&input);     
     124    if (testdata) {
     125      utility::matrix centroid_weights;
     126      utility::nan(centroids_,centroid_weights);
     127      MatrixLookupWeighted weighted_centroids(centroids_,centroid_weights);
     128      for(size_t j=0; j<input.columns();j++) {       
     129        DataLookupWeighted1D in(*testdata,j,false);
     130        for(size_t k=0; k<centroids_.columns();k++) {
     131          DataLookupWeighted1D centroid(weighted_centroids,k,false);
     132          prediction(k,j)=(*distance_)(in.begin(),in.end(),centroid.begin());
     133        }
    148134      }
    149       return;
    150135    }
    151     // non-weighted case
    152     const MatrixLookup* x = dynamic_cast<const MatrixLookup*>(&input);
    153     if (!x){
     136    else {
    154137      std::string str;
    155138      str = "Error in NCC::predict: DataLookup2D of unexpected class.";
    156139      throw std::runtime_error(str);
    157140    }
    158     for(size_t j=0; j<input.columns();j++) {     
    159       utility::vector in(input.rows(),0);
    160       for(size_t i=0; i<in.size();i++)
    161         in(i)=(*data)(i,j);
    162       utility::vector weights(in.size(),1.0);
    163       utility::vector out;
    164       predict(in,weights,out);
    165       prediction.column(j,out);
    166     }
    167141  }
    168 
    169  
    170   // additional operators
    171 
    172 //  std::ostream& operator<< (std::ostream& s, const NCC& ncc) {
    173 //    std::copy(ncc.classes().begin(), ncc.classes().end(),
    174 //              std::ostream_iterator<std::map<double, u_int>::value_type>
    175 //              (s, "\n"));
    176 //    s << "\n" << ncc.centroids() << "\n";
    177 //    return s;
    178 //  }
    179 
     142   
    180143}}} // of namespace classifier, yat, and theplu
  • trunk/yat/classifier/NCC.h

    r874 r898  
    2828
    2929#include "yat/utility/matrix.h"
     30#include "yat/statistics/vector_distance_ptr.h"
    3031#include "SupervisedClassifier.h"
    3132
     
    3738  namespace utlitity {
    3839    class vector;
    39   }
    40 
    41   namespace statistics {
    42     class Distance;
    4340  }
    4441
     
    6158    ///
    6259    /// Constructor taking the training data, the target vector, and
    63     /// the distance measure as input.
     60    /// the distance measure tag as input.
    6461    ///
    65     NCC(const MatrixLookup&, const Target&, const statistics::Distance&);
     62    NCC(const MatrixLookup&, const Target&,
     63        const statistics::vector_distance_lookup_weighted_ptr);
    6664   
    6765    ///
    6866    /// Constructor taking the training data with weights, the target
    69     /// vector, the distance measure, and a weight matrix for the
    70     /// training data as input.
     67    /// vector, the distance measure tag.
    7168    ///
    72     NCC(const MatrixLookupWeighted&, const Target&, const statistics::Distance&);
     69    NCC(const MatrixLookupWeighted&, const Target&,
     70        const statistics::vector_distance_lookup_weighted_ptr);
    7371
    7472    virtual ~NCC();
     
    9795    ///
    9896    void predict(const DataLookup2D&, utility::matrix&) const;
    99 
    100 
     97   
     98   
    10199  private:
    102100    utility::matrix centroids_;
    103     const statistics::Distance& distance_;                 
     101    const statistics::vector_distance_lookup_weighted_ptr distance_;                 
    104102
    105103    // data_ has to be of type DataLookup2D to accomodate both
    106104    // MatrixLookup and MatrixLookupWeighted
    107105    const DataLookup2D& data_;
    108 
    109     ///
    110     /// Calculate the distance to each centroid for a test sample
    111     ///
    112     void predict(const utility::vector&, const utility::vector&,
    113                  utility::vector&) const;
    114106
    115107  };
Note: See TracChangeset for help on using the changeset viewer.