Ignore:
Timestamp:
Feb 26, 2008, 4:29:50 PM (14 years ago)
Author:
Markus Ringnér
Message:

Fixes #333

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NCC.h

    r1158 r1160  
    2727*/
    2828
    29 #include "DataLookup1D.h"
    30 #include "DataLookup2D.h"
    31 #include "DataLookupWeighted1D.h"
    3229#include "MatrixLookup.h"
    3330#include "MatrixLookupWeighted.h"
     
    104101    /// Calculate the distance to each centroid for test samples
    105102    ///
    106     void predict(const DataLookup2D&, utility::Matrix&) const;
    107    
     103    void predict(const MatrixLookup&, utility::Matrix&) const;
     104   
     105    ///
     106    /// Calculate the distance to each centroid for weighted test samples
     107    ///
     108    void predict(const MatrixLookupWeighted&, utility::Matrix&) const;
     109
    108110   
    109111  private:
     
    203205
    204206  template <typename Distance>
    205   void NCC<Distance>::predict(const DataLookup2D& test,                     
     207  void NCC<Distance>::predict(const MatrixLookup& test,                     
    206208                              utility::Matrix& prediction) const
    207209  {   
     
    214216    prediction.resize(centroids_->columns(), test.columns());
    215217
    216     // unweighted test data
    217     if (const MatrixLookup* test_unweighted =
    218         dynamic_cast<const MatrixLookup*>(&test)) {
    219       // If weighted training data has resulted in NaN in centroids: weighted calculations
    220       if(centroids_nan_) {
    221         predict_weighted(MatrixLookupWeighted(*test_unweighted),prediction);
    222       }
    223       // If unweighted training data: unweighted calculations
    224       else {
    225         predict_unweighted(*test_unweighted,prediction);
    226       }
    227     }
    228     // weighted test data: weighted calculations
    229     else if (const MatrixLookupWeighted* test_weighted =
    230              dynamic_cast<const MatrixLookupWeighted*>(&test)) {
    231       predict_weighted(*test_weighted,prediction);
    232     }
     218    // If weighted training data has resulted in NaN in centroids: weighted calculations
     219    if(centroids_nan_) {
     220      predict_weighted(MatrixLookupWeighted(test),prediction);
     221    }
     222    // If unweighted training data: unweighted calculations
    233223    else {
    234       std::string str =
    235         "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
    236       throw std::runtime_error(str);
    237     }
    238   }
     224      predict_unweighted(test,prediction);
     225    }
     226  }
     227
     228  template <typename Distance>
     229  void NCC<Distance>::predict(const MatrixLookupWeighted& test,                     
     230                              utility::Matrix& prediction) const
     231  {   
     232    utility::yat_assert<std::runtime_error>
     233      (centroids_,"NCC::predict called for untrained classifier");
     234    utility::yat_assert<std::runtime_error>
     235      (centroids_->rows()==test.rows(),
     236       "NCC::predict test data with incorrect number of rows");
     237   
     238    prediction.resize(centroids_->columns(), test.columns());
     239    predict_weighted(test,prediction);
     240  }
     241
    239242 
    240243  template <typename Distance>
Note: See TracChangeset for help on using the changeset viewer.