Changeset 1107 for trunk/yat/classifier


Ignore:
Timestamp:
Feb 19, 2008, 4:23:52 PM (16 years ago)
Author:
Markus Ringnér
Message:

Ticket #259 fixed for KNN

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/KNN.h

    r1098 r1107  
    9090   
    9191    ///
    92     /// Train the classifier using the training data. Centroids are
    93     /// calculated for each class.
     92    /// Train the classifier using the training data.
     93    /// This function does nothing but is required by the interface.
    9494    ///
    9595    /// @return true if training succedeed.
     
    9999   
    100100    ///
    101     /// Calculate the distance to each centroid for test samples
     101    /// For each sample, calculate the number of neighbours for each
     102    /// class.
     103    ///
    102104    ///
    103105    void predict(const DataLookup2D&, utility::matrix&) const;
     
    121123    ///
    122124    utility::matrix* calculate_distances(const DataLookup2D&) const;
     125    void calculate_unweighted(const MatrixLookup&,
     126                              const MatrixLookup&,
     127                              utility::matrix*) const;
     128    void calculate_weighted(const MatrixLookupWeighted&,
     129                            const MatrixLookupWeighted&,
     130                            utility::matrix*) const;
    123131  };
    124132 
     
    151159      new utility::matrix(data_.columns(),test.columns());
    152160   
     161   
    153162    // unweighted test data
    154163    if(const MatrixLookup* test_unweighted =
    155164       dynamic_cast<const MatrixLookup*>(&test)) {     
    156       for(size_t i=0; i<data_.columns(); i++) {
    157         for(size_t j=0; j<test.columns(); j++) {
    158           classifier::DataLookup1D test(*test_unweighted,j,false);
    159           classifier::DataLookup1D tmp(data_,i,false);
    160           (*distances)(i,j) = distance_(tmp.begin(), tmp.end(), test.begin());
    161           utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    162         }
    163       }
     165      // unweighted training data
     166      if(const MatrixLookup* training_unweighted =
     167         dynamic_cast<const MatrixLookup*>(&data_))
     168        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
     169      // weighted training data
     170      else if(const MatrixLookupWeighted* training_weighted =
     171              dynamic_cast<const MatrixLookupWeighted*>(&data_))
     172        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
     173                           distances);             
     174      // Training data can not be of incorrect type
    164175    }
    165176    // weighted test data
     177    else if (const MatrixLookupWeighted* test_weighted =
     178             dynamic_cast<const MatrixLookupWeighted*>(&test)) {     
     179      // unweighted training data
     180      if(const MatrixLookup* training_unweighted =
     181         dynamic_cast<const MatrixLookup*>(&data_)) {
     182        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
     183                           *test_weighted,distances);
     184      }
     185      // weighted training data
     186      else if(const MatrixLookupWeighted* training_weighted =
     187              dynamic_cast<const MatrixLookupWeighted*>(&data_))
     188        calculate_weighted(*training_weighted,*test_weighted,distances);             
     189      // Training data can not be of incorrect type
     190    }
    166191    else {
    167       const MatrixLookupWeighted* data_weighted =
    168         dynamic_cast<const MatrixLookupWeighted*>(&data_);
    169       const MatrixLookupWeighted* test_weighted =
    170         dynamic_cast<const MatrixLookupWeighted*>(&test);               
    171       if(data_weighted && test_weighted) {
    172         for(size_t i=0; i<data_.columns(); i++) {
    173           classifier::DataLookupWeighted1D training(*data_weighted,i,false);
    174           for(size_t j=0; j<test.columns(); j++) {
    175             classifier::DataLookupWeighted1D test(*test_weighted,j,false);
    176             utility::yat_assert<std::runtime_error>(training.size()==test.size());
    177             (*distances)(i,j) = distance_(training.begin(), training.end(),
    178                                           test.begin());
    179             utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    180           }
    181         }
    182       }
    183       else {
    184         std::string str;
    185         str = "Error in KNN::calculate_distances: Only support when training and test data both are either MatrixLookup or MatrixLookupWeighted";
    186         throw std::runtime_error(str);
    187       }
     192      std::string str;
     193      str = "Error in KNN::calculate_distances: test data has to be either MatrixLookup or MatrixLookupWeighted";
     194      throw std::runtime_error(str);
    188195    }
    189196    return distances;
    190197  }
     198
     199  template <typename Distance>
     200  void  KNN<Distance>:: calculate_unweighted(const MatrixLookup& training,
     201                                             const MatrixLookup& test,
     202                                             utility::matrix* distances) const
     203  {
     204    for(size_t i=0; i<training.columns(); i++) {
     205      classifier::DataLookup1D training1(training,i,false);
     206      for(size_t j=0; j<test.columns(); j++) {
     207        classifier::DataLookup1D test1(test,j,false);
     208        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
     209        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
     210      }
     211    }
     212  }
     213 
     214  template <typename Distance>
     215  void  KNN<Distance>:: calculate_weighted(const MatrixLookupWeighted& training,
     216                                           const MatrixLookupWeighted& test,
     217                                           utility::matrix* distances) const
     218  {
     219    for(size_t i=0; i<training.columns(); i++) {
     220      classifier::DataLookupWeighted1D training1(training,i,false);
     221      for(size_t j=0; j<test.columns(); j++) {
     222        classifier::DataLookupWeighted1D test1(test,j,false);
     223        (*distances)(i,j) = distance_(training1.begin(), training1.end(), test1.begin());
     224        utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
     225      }
     226    }
     227  }
     228
    191229 
    192230  template <typename Distance>
Note: See TracChangeset for help on using the changeset viewer.