Ignore:
Timestamp:
Feb 26, 2008, 4:29:50 PM (14 years ago)
Author:
Markus Ringnér
Message:

Fixes #333

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/KNN.h

    r1158 r1160  
    107107    /// class.
    108108    ///
    109     ///
    110     void predict(const DataLookup2D&, utility::Matrix&) const;
     109    void predict(const MatrixLookup&, utility::Matrix&) const;
     110
     111    ///
     112    /// For each sample, calculate the number of neighbors for each
     113    /// class.
     114    ///
     115    void predict(const MatrixLookupWeighted&, utility::Matrix&) const;
    111116
    112117
    113118  private:
    114119
    115     // data_ has to be of type DataLookup2D to accommodate both
    116     // MatrixLookup and MatrixLookupWeighted
    117     const DataLookup2D* data_;
     120    const MatrixLookup* data_ml_;
     121    const MatrixLookupWeighted* data_mlw_;
    118122    const Target* target_;
    119123
     
    124128
    125129    NeighborWeighting weighting_;
    126 
    127     ///
    128     /// Calculates the distances between a data set and the training
    129     /// data. The rows are training and the columns test samples,
    130     /// respectively. The returned distance matrix is dynamically
    131     /// generated and needs to be deleted by the caller.
    132     ///
    133     utility::Matrix* calculate_distances(const DataLookup2D&) const;
    134130
    135131    void calculate_unweighted(const MatrixLookup&,
     
    139135                            const MatrixLookupWeighted&,
    140136                            utility::Matrix*) const;
     137
     138    void predict_common(const utility::Matrix& distances,
     139                        utility::Matrix& prediction) const;
     140
    141141  };
    142142 
     
    146146  template <typename Distance, typename NeighborWeighting>
    147147  KNN<Distance, NeighborWeighting>::KNN()
    148     : SupervisedClassifier(),data_(0),target_(0),k_(3)
     148    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
    149149  {
    150150  }
     
    152152  template <typename Distance, typename NeighborWeighting>
    153153  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
    154     : SupervisedClassifier(),data_(0),target_(0),k_(3), distance_(dist)
     154    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
    155155  {
    156156  }
     
    162162  }
    163163 
    164   template <typename Distance, typename NeighborWeighting>
    165   utility::Matrix* KNN<Distance, NeighborWeighting>::calculate_distances
    166   (const DataLookup2D& test) const
    167   {
    168     // matrix with training samples as rows and test samples as columns
    169     utility::Matrix* distances =
    170       new utility::Matrix(data_->columns(),test.columns());
    171    
    172    
    173     // unweighted test data
    174     if(const MatrixLookup* test_unweighted =
    175        dynamic_cast<const MatrixLookup*>(&test)) {     
    176       // unweighted training data
    177       if(const MatrixLookup* training_unweighted =
    178          dynamic_cast<const MatrixLookup*>(data_))
    179         calculate_unweighted(*training_unweighted,*test_unweighted,distances);
    180       // weighted training data
    181       else if(const MatrixLookupWeighted* training_weighted =
    182               dynamic_cast<const MatrixLookupWeighted*>(data_))
    183         calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
    184                            distances);             
    185       // Training data can not be of incorrect type
    186     }
    187     // weighted test data
    188     else if (const MatrixLookupWeighted* test_weighted =
    189              dynamic_cast<const MatrixLookupWeighted*>(&test)) {     
    190       // unweighted training data
    191       if(const MatrixLookup* training_unweighted =
    192          dynamic_cast<const MatrixLookup*>(data_)) {
    193         calculate_weighted(MatrixLookupWeighted(*training_unweighted),
    194                            *test_weighted,distances);
    195       }
    196       // weighted training data
    197       else if(const MatrixLookupWeighted* training_weighted =
    198               dynamic_cast<const MatrixLookupWeighted*>(data_))
    199         calculate_weighted(*training_weighted,*test_weighted,distances);             
    200       // Training data can not be of incorrect type
    201     }
    202     else {
    203       std::string str;
    204       str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted";
    205       throw std::runtime_error(str);
    206     }
    207     return distances;
    208   }
    209164
    210165  template <typename Distance, typename NeighborWeighting>
     
    214169  {
    215170    for(size_t i=0; i<training.columns(); i++) {
    216       classifier::DataLookup1D training1(training,i,false);
    217171      for(size_t j=0; j<test.columns(); j++) {
    218         classifier::DataLookup1D test1(test,j,false);
    219         (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
     172        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i),
     173                                      test.begin_column(j));
    220174        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    221175      }
    222176    }
    223177  }
     178
    224179 
    225180  template <typename Distance, typename NeighborWeighting>
     
    229184   utility::Matrix* distances) const
    230185  {
    231     for(size_t i=0; i<training.columns(); i++) {
    232       classifier::DataLookupWeighted1D training1(training,i,false);
     186    for(size_t i=0; i<training.columns(); i++) {
    233187      for(size_t j=0; j<test.columns(); j++) {
    234         classifier::DataLookupWeighted1D test1(test,j,false);
    235         (*distances)(i,j) = distance_(training1.begin(), training1.end(),
    236                                       test1.begin());
     188        (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i),
     189                                      test.begin_column(j));
    237190        // If the distance is NaN (no common variables with non-zero weights),
    238191        // the distance is set to infinity to be sorted as a neighbor at the end
     
    277230    if(data.columns()<k_)
    278231      k_=data.columns();
    279     data_=&data;
     232    data_ml_=&data;
     233    data_mlw_=0;
    280234    target_=&target;
    281235    trained_=true;
     
    292246    if(data.columns()<k_)
    293247      k_=data.columns();
    294     data_=&data;
     248    data_ml_=0;
     249    data_mlw_=&data;
    295250    target_=&target;
    296251    trained_=true;
     
    299254
    300255  template <typename Distance, typename NeighborWeighting>
    301   void KNN<Distance, NeighborWeighting>::predict(const DataLookup2D& test,
     256  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
    302257                                                 utility::Matrix& prediction) const
    303258  {   
    304     utility::yat_assert<std::runtime_error>(data_->rows()==test.rows(),"KNN::predict different number of rows in training and test data");
    305 
    306     utility::Matrix* distances=calculate_distances(test);
    307    
     259    // matrix with training samples as rows and test samples as columns
     260    utility::Matrix* distances = 0;
     261    // unweighted training data
     262    if(data_ml_ && !data_mlw_) {
     263      utility::yat_assert<std::runtime_error>
     264        (data_ml_->rows()==test.rows(),
     265         "KNN::predict different number of rows in training and test data");     
     266      distances=new utility::Matrix(data_ml_->columns(),test.columns());
     267      calculate_unweighted(*data_ml_,test,distances);
     268    }
     269    else if (data_mlw_ && !data_ml_) {
     270      // weighted training data
     271      utility::yat_assert<std::runtime_error>
     272        (data_mlw_->rows()==test.rows(),
     273         "KNN::predict different number of rows in training and test data");           
     274      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
     275      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
     276                         distances);             
     277    }
     278    else {
     279      std::runtime_error("KNN::predict no training data");
     280    }
     281
    308282    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    309     for(size_t sample=0;sample<distances->columns();sample++) {
     283    predict_common(*distances,prediction);
     284    if(distances)
     285      delete distances;
     286  }
     287
     288  template <typename Distance, typename NeighborWeighting>
     289  void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
     290                                                 utility::Matrix& prediction) const
     291  {   
     292    // matrix with training samples as rows and test samples as columns
     293    utility::Matrix* distances=0;
     294    // unweighted training data
     295    if(data_ml_ && !data_mlw_) {
     296      utility::yat_assert<std::runtime_error>
     297        (data_ml_->rows()==test.rows(),
     298         "KNN::predict different number of rows in training and test data");   
     299      distances=new utility::Matrix(data_ml_->columns(),test.columns());
     300      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);   
     301    }
     302    // weighted training data
     303    else if (data_mlw_ && !data_ml_) {
     304      utility::yat_assert<std::runtime_error>
     305        (data_mlw_->rows()==test.rows(),
     306         "KNN::predict different number of rows in training and test data");   
     307      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
     308      calculate_weighted(*data_mlw_,test,distances);             
     309    }
     310    else {
     311      std::runtime_error("KNN::predict no training data");
     312    }
     313
     314    prediction.resize(target_->nof_classes(),test.columns(),0.0);
     315    predict_common(*distances,prediction);
     316   
     317    if(distances)
     318      delete distances;
     319  }
     320 
     321  template <typename Distance, typename NeighborWeighting>
     322  void KNN<Distance, NeighborWeighting>::predict_common
     323  (const utility::Matrix& distances, utility::Matrix& prediction) const
     324  {   
     325    for(size_t sample=0;sample<distances.columns();sample++) {
    310326      std::vector<size_t> k_index;
    311       utility::VectorConstView dist=distances->column_const_view(sample);
     327      utility::VectorConstView dist=distances.column_const_view(sample);
    312328      utility::sort_smallest_index(k_index,k_,dist);
    313329      utility::VectorView pred=prediction.column_view(sample);
    314330      weighting_(dist,k_index,*target_,pred);
    315331    }
    316     delete distances;
    317 
     332   
    318333    // classes for which there are no training samples should be set
    319334    // to nan in the predictions
     
    324339  }
    325340
     341
    326342}}} // of namespace classifier, yat, and theplu
    327343
Note: See TracChangeset for help on using the changeset viewer.