Changeset 476


Ignore:
Timestamp:
Dec 22, 2005, 5:37:27 PM (16 years ago)
Author:
Markus Ringnér
Message:

Added interface class for supervised classifiers. Fixed NCC to comply with Peters last check-in. Fixed bugs in AveragerPairWeighted?

Location:
trunk/lib
Files:
2 added
6 edited

Legend:

Unmodified
Added
Removed
  • trunk/lib/classifier/Makefile.am

    r475 r476  
    2121  NCC.cc \
    2222  PolynomialKernelFunction.cc \
     23  SupervisedClassifier.cc \
    2324  SVM.cc \
    2425  Target.cc
     26
    2527
    2628
     
    4143  NCC.h \
    4244  PolynomialKernelFunction.h \
     45  SupervisedClassifier.h \
    4346  SVM.h \
    4447  Target.h \
  • trunk/lib/classifier/NCC.cc

    r475 r476  
    1818namespace classifier {
    1919
    20   NCC::NCC(const MatrixLookup& data, const Target& target)
     20  NCC::NCC(const DataLookup2D& data, const Target& target)
     21    : SupervisedClassifier(data,target)
    2122  {
    22     Target sorted_target(target);
     23  }
     24
     25  SupervisedClassifier*
     26  NCC::make_classifier(const DataLookup2D& data,
     27                       const Target& target) const
     28  {
     29    SupervisedClassifier* sc= new NCC(data,target);
     30    // Here all classifier parameters should be copied from this to sc
     31    return sc;
     32  }
     33
     34
     35  bool NCC::train()
     36  {
     37    Target sorted_target(target_);
    2338    sorted_target.sort();
    2439   
     
    3550   
    3651    // Calculate the centroids for each class
    37     centroids_=gslapi::matrix(data.rows(),classes_.size());
    38     gslapi::matrix nof_in_class(data.rows(),classes_.size());
    39     for(size_t i=0; i<data.rows(); i++) {
    40       for(size_t j=0; j<data.columns(); j++) {
    41         if(!std::isnan(data(i,j))) {
    42           centroids_(i,classes_[target(j)]) += data(i,j);
    43           nof_in_class(i,classes_[target(j)])++;
     52    centroids_=gslapi::matrix(matrix_.rows(),classes_.size());
     53    gslapi::matrix nof_in_class(matrix_.rows(),classes_.size());
     54    for(size_t i=0; i<matrix_.rows(); i++) {
     55      for(size_t j=0; j<matrix_.columns(); j++) {
     56        if(!std::isnan(matrix_(i,j))) {
     57          centroids_(i,classes_[target_(j)]) += matrix_(i,j);
     58          nof_in_class(i,classes_[target_(j)])++;
    4459        }
    4560      }
    4661    }
    4762    centroids_.div_elements(nof_in_class);
     63    trained_=true;
     64    return trained_;
    4865  }
    4966
    5067  void NCC::predict(const gslapi::vector& input,
    5168                    statistics::Score& score,
    52                     gslapi::vector& prediction)
     69                    gslapi::vector& prediction) const
    5370  {
    54     prediction=gslapi::vector(classes_.size());   
     71    prediction=gslapi::vector(centroids_.columns());   
    5572    gslapi::vector w(input.size(),0);
    5673    for(size_t i=0; i<input.size(); i++)  // take care of missing values
     
    6279      prediction(j)=ap.correlation();
    6380      ap.reset();
     81    }
     82  }
     83
     84
     85  void NCC::predict(const gslapi::matrix& input,
     86                    statistics::Score& score,
     87                    gslapi::matrix& prediction) const
     88  {
     89    prediction=gslapi::matrix(centroids_.columns(), input.columns());   
     90    for(size_t j=0; j<input.columns();j++) {     
     91      gslapi::vector in(input,j,false);
     92      gslapi::vector out;
     93      predict(in,score,out);
     94      prediction.set_column(j,out);
    6495    }
    6596  }
  • trunk/lib/classifier/NCC.h

    r475 r476  
    55
    66#include <c++_tools/gslapi/matrix.h>
     7#include <c++_tools/classifier/SupervisedClassifier.h>
    78
    89#include <map>
    910
    1011namespace theplu {
    11 namespace gslapi {
    12   class vector;
    13 }
    14 namespace statistics {
    15   class Score;
    16 }
     12
     13  // forward declarations
     14  namespace statistics {
     15    class Score;
     16  }
    1717
    1818namespace classifier { 
    1919
    20   class MatrixLookup;
    21   class Score;
    2220  class Target;
    23   class DataLookup1D;
     21  class DataLookup2D;
    2422
    2523  ///
     
    2725  ///
    2826
    29   class NCC
     27  class NCC : public SupervisedClassifier
    3028  {
    3129 
    3230  public:
    33     ///
    34     /// Default constructor (not implemented)
    35     ///
    36     NCC();
    37 
    3831    ///
    3932    /// Constructor taking the training data and the target vector as
    4033    /// input. Performs the training of the NCC.
    4134    ///
    42     NCC(const MatrixLookup&, const Target&);
     35    NCC(const DataLookup2D&, const Target&);
    4336
    44     ///
    45     /// @todo Copy constructor.
    46     ///
    47     NCC(const NCC&);
    48 
    49     ///
    50     /// @todo The istream constructor.
    51     ///
    52     NCC(std::istream&);
    53          
    5437    const gslapi::matrix& centroids(void) const {return centroids_;}
    5538
    5639    const std::map<double,u_int>& classes(void) const {return classes_;}
     40
     41    inline SupervisedClassifier*
     42    make_classifier(const DataLookup2D&, const Target&) const;
    5743   
     44    bool train();
     45
    5846
    5947    ///
     
    6149    ///
    6250    void predict(const gslapi::vector&, statistics::Score&,
    63                  gslapi::vector&);
     51                 gslapi::vector&) const;
    6452   
     53    ///
     54    /// Calculate the scores to each centroid for test samples
     55    ///
     56    void predict(const gslapi::matrix&, statistics::Score&,
     57                 gslapi::matrix&) const;
     58
     59
    6560  private:
    6661    gslapi::matrix centroids_;
  • trunk/lib/statistics/AveragerPairWeighted.cc

    r475 r476  
    1515                                  const double wx, const double wy)
    1616  {
     17    if(wx==0.0 || wy==0.0) {
     18      return;
     19    }
    1720    double w=sqrt(wx*wy);
    1821    x_.add(x,w);
    1922    y_.add(y,w);
    20     wxy_ += w*x*y;
     23    wxy_ += w*x*y;   
     24    w_ +=w;
    2125  }
    2226
  • trunk/lib/statistics/AveragerPairWeighted.h

    r475 r476  
    2121  public:
    2222
     23    AveragerPairWeighted()
     24      : wxy_(0), w_(0)
     25    {
     26    }
     27
    2328    ///
    2429    /// Adding a pair of data points with value \a x and \a y, and
     
    4045    ///
    4146    inline double correlation(void) const
    42     { return (wxy_ - x_.mean()*y_.mean() )/ sqrt(x_.variance()*y_.variance()); }
     47    { return (wxy_/w_ - x_.mean()*y_.mean() )/ sqrt(x_.variance()*y_.variance()); }
    4348 
    4449    ///
    4550    /// @reset
    4651    ///
    47     inline void reset(void) { x_.reset(); y_.reset(); wxy_=0; }
     52    inline void reset(void) { x_.reset(); y_.reset(); wxy_=0; w_=0; }
    4853
    4954  private:
     
    5156    AveragerWeighted y_;
    5257    double wxy_;
     58    double w_;
    5359
    5460  };
  • trunk/lib/statistics/AveragerWeighted.h

    r475 r476  
    4040    ///
    4141    inline void add(const double d,const double w=1)
    42     { w_.add(w); wx_.add(w*d); wwx_+=w*w*d; wxx_+=w*d*d; }
     42    { if(w==0) return; w_.add(w); wx_.add(w*d); wwx_+=w*w*d; wxx_+=w*d*d; }
    4343
    4444    ///
Note: See TracChangeset for help on using the changeset viewer.