Changeset 925 for trunk/yat


Ignore:
Timestamp:
Oct 2, 2007, 4:02:08 PM (14 years ago)
Author:
Markus Ringnér
Message:

NCC and IGP have been changed to templates on Distance

Location:
trunk/yat/classifier
Files:
2 deleted
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/IGP.h

    r865 r925  
    2525*/
    2626
     27#include "DataLookup1D.h"
     28#include "MatrixLookup.h"
     29#include "Target.h"
    2730#include "yat/utility/vector.h"
     31#include "yat/utility/yat_assert.h"
     32#include "yat/statistics/vector_distance.h"
     33
     34#include <cmath>
     35#include <limits>
    2836
    2937namespace theplu {
    3038namespace yat {
    31 
    32   namespace statistics {
    33     class Distance;
    34   }
    35 
    3639namespace classifier { 
    3740
     
    4346  /// See Kapp and Tibshirani, Biostatistics (2006).
    4447  ///
     48  template <typename Distance>
    4549  class IGP
    4650  {
     
    5155    /// the distance measure as input.
    5256    ///
    53     IGP(const MatrixLookup&, const Target&, const statistics::Distance&);
     57    IGP(const MatrixLookup&, const Target&);
    5458
    5559    ///
     
    6165    /// @return the IGP score for each class as elements in a vector.
    6266    ///
    63     const utility::vector& score(void) const {return igp_;}
     67    const utility::vector& score(void) const;
    6468
    6569
     
    6771    utility::vector igp_;
    6872
    69     const statistics::Distance& distance_;
    7073    const MatrixLookup& matrix_;
    7174    const Target& target_;
    7275  }; 
     76
     77 
     78  // templates
     79
     80  template <typename Distance>
     81  IGP<Distance>::IGP(const MatrixLookup& data, const Target& target)
     82    : matrix_(data), target_(target)
     83  {   
     84    yat_assert(target_.size()==matrix_.columns());
     85   
     86    // Calculate IGP for each class
     87    igp_.clone(utility::vector(target_.nof_classes()));
     88   
     89    for(u_int i=0; i<target_.size(); i++) {
     90      u_int neighbor=i;
     91      double mindist=std::numeric_limits<double>::max();
     92      const DataLookup1D a(matrix_,i,false);
     93      for(u_int j=0; j<target_.size(); j++) {           
     94        DataLookup1D b(matrix_,j,false);
     95        double dist=statistics::
     96          vector_distance(a.begin,a.end(),b.begin(),
     97                          statistics::vector_distance_traits<Distance>::distace());
     98        if(j!=i && dist<mindist) {
     99          mindist=dist;
     100          neighbor=j;
     101        }
     102      }
     103      if(target_(i)==target_(neighbor))
     104        igp_(target_(i))++;
     105     
     106    }
     107    for(u_int i=0; i<target_.nof_classes(); i++) {
     108      igp_(i)/=static_cast<double>(target_.size(i));
     109    }
     110  }
     111 
     112  template <typename Distance>
     113  IGP<Distance>::~IGP()   
     114  {
     115  }
     116
     117 
     118  template <typename Distance>
     119  const utility::vector& IGP<Distance>::score(void) const
     120  {
     121    return igp_;
     122  }
    73123 
    74124}}} // of namespace classifier, yat, and theplu
  • trunk/yat/classifier/Makefile.am

    r901 r925  
    3838  FeatureSelectorRandom.cc \
    3939  GaussianKernelFunction.cc \
    40   IGP.cc \
    4140  InputRanker.cc \
    4241  Kernel.cc \
     
    4746  MatrixLookupWeighted.cc \
    4847  NBC.cc \
    49   NCC.cc \
    5048  PolynomialKernelFunction.cc \
    5149  Sampler.cc \
     
    8482  MatrixLookupWeighted.h \
    8583  NBC.h \
    86   NCC.h \
    8784  PolynomialKernelFunction.h \
    8885  SVM.h \
  • trunk/yat/classifier/NCC.h

    r909 r925  
    2727*/
    2828
     29#include "DataLookup1D.h"
     30#include "DataLookup2D.h"
     31#include "DataLookupWeighted1D.h"
     32#include "MatrixLookup.h"
     33#include "MatrixLookupWeighted.h"
     34#include "SupervisedClassifier.h"
     35#include "Target.h"
     36
     37#include "yat/statistics/vector_distance.h"
     38
     39#include "yat/utility/Iterator.h"
     40#include "yat/utility/IteratorWeighted.h"
    2941#include "yat/utility/matrix.h"
    30 #include "yat/statistics/vector_distance_ptr.h"
    31 #include "SupervisedClassifier.h"
    32 
     42#include "yat/utility/vector.h"
     43#include "yat/utility/stl_utility.h"
     44#include "yat/utility/yat_assert.h"
     45
     46#include<iostream>
     47#include<iterator>
    3348#include <map>
     49#include <cmath>
     50
    3451
    3552namespace theplu {
    3653namespace yat {
    37 
    38   namespace utlitity {
    39     class vector;
    40   }
    41 
    4254namespace classifier { 
    4355
    44   class DataLookup1D;
    45   class DataLookup2D;
    46   class MatrixLookup;
    47   class MatrixLookupWeighted;
    48   class Target;
    4956
    5057  ///
     
    5259  ///
    5360
     61  template <typename Distance>
    5462  class NCC : public SupervisedClassifier
    5563  {
     
    5765  public:
    5866    ///
    59     /// Constructor taking the training data, the target vector, and
    60     /// the distance measure tag as input.
     67    /// Constructor taking the training data and the target vector as
     68    /// input
    6169    ///
    62     NCC(const MatrixLookup&, const Target&,
    63         const statistics::vector_distance_lookup_weighted_ptr);
    64    
    65     ///
    66     /// Constructor taking the training data with weights, the target
    67     /// vector, the distance measure tag.
     70    NCC(const MatrixLookup&, const Target&);
     71   
     72    ///
     73    /// Constructor taking the training data with weights and the
     74    /// target vector as input.
    6875    ///
    69     NCC(const MatrixLookupWeighted&, const Target&,
    70         const statistics::vector_distance_lookup_weighted_ptr);
     76    NCC(const MatrixLookupWeighted&, const Target&);
    7177
    7278    virtual ~NCC();
     
    98104   
    99105  private:
     106
    100107    utility::matrix centroids_;
    101     const statistics::vector_distance_lookup_weighted_ptr distance_;
    102108
    103109    // data_ has to be of type DataLookup2D to accomodate both
     
    112118  //  std::ostream& operator<< (std::ostream&, const NCC&);
    113119 
     120
     121  // templates
     122
     123  template <typename Distance>
     124  NCC<Distance>::NCC(const MatrixLookup& data, const Target& target)
     125    : SupervisedClassifier(target), data_(data)
     126  {
     127  }
     128
     129  template <typename Distance>
     130  NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
     131    : SupervisedClassifier(target), data_(data)
     132  {
     133  }
     134
     135  template <typename Distance>
     136  NCC<Distance>::~NCC()   
     137  {
     138  }
     139
     140
     141  template <typename Distance>
     142  const utility::matrix& NCC<Distance>::centroids(void) const
     143  {
     144    return centroids_;
     145  }
    114146 
     147
     148  template <typename Distance>
     149  const DataLookup2D& NCC<Distance>::data(void) const
     150  {
     151    return data_;
     152  }
     153 
     154  template <typename Distance>
     155  SupervisedClassifier*
     156  NCC<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const
     157  {     
     158    NCC* ncc=0;
     159    if(data.weighted()) {
     160      ncc=new NCC<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
     161                  target);
     162    }
     163    else {
     164      ncc=new NCC<Distance>(dynamic_cast<const MatrixLookup&>(data),
     165                  target);
     166    }
     167    return ncc;
     168  }
     169
     170
     171  template <typename Distance>
     172  bool NCC<Distance>::train()
     173  {   
     174    centroids_.clone(utility::matrix(data_.rows(), target_.nof_classes()));
     175    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
     176    const MatrixLookupWeighted* weighted_data =
     177      dynamic_cast<const MatrixLookupWeighted*>(&data_);
     178    bool weighted = weighted_data;
     179
     180    for(size_t i=0; i<data_.rows(); i++) {
     181      for(size_t j=0; j<data_.columns(); j++) {
     182        centroids_(i,target_(j)) += data_(i,j);
     183        if (weighted)
     184          nof_in_class(i,target_(j))+= weighted_data->weight(i,j);
     185        else
     186          nof_in_class(i,target_(j))+=1.0;
     187      }
     188    }   
     189    centroids_.div(nof_in_class);
     190    trained_=true;
     191    return trained_;
     192  }
     193
     194  template <typename Distance>
     195  void NCC<Distance>::predict(const DataLookup2D& input,                   
     196                    utility::matrix& prediction) const
     197  {   
     198    prediction.clone(utility::matrix(centroids_.columns(), input.columns()));   
     199
     200    // Weighted case
     201    const MatrixLookupWeighted* testdata =
     202      dynamic_cast<const MatrixLookupWeighted*>(&input);     
     203    if (testdata) {
     204      MatrixLookupWeighted weighted_centroids(centroids_);
     205      for(size_t j=0; j<input.columns();j++) {       
     206        DataLookupWeighted1D in(*testdata,j,false);
     207        for(size_t k=0; k<centroids_.columns();k++) {
     208          DataLookupWeighted1D centroid(weighted_centroids,k,false);
     209
     210          yat_assert(in.size()==centroid.size());
     211          prediction(k,j)=statistics::
     212            vector_distance(in.begin(),in.end(),centroid.begin(),
     213                             typename statistics::vector_distance_traits<Distance>::distance());
     214        }
     215      }
     216    }
     217    else {
     218      std::string str;
     219      str = "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
     220      throw std::runtime_error(str);
     221    }
     222  }
     223     
    115224}}} // of namespace classifier, yat, and theplu
    116225
Note: See TracChangeset for help on using the changeset viewer.