Changeset 1007 for trunk/yat/classifier


Ignore:
Timestamp:
Jan 29, 2008, 10:53:23 AM (14 years ago)
Author:
Markus Ringnér
Message:

Restructuringpredict in NCC. refs #259

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NCC.h

    r1000 r1007  
    220220
    221221  template <typename Distance>
    222   void NCC<Distance>::predict(const DataLookup2D& input,                   
     222  void NCC<Distance>::predict(const DataLookup2D& test,                     
    223223                              utility::matrix& prediction) const
    224224  {   
    225     prediction.clone(utility::matrix(centroids_->columns(), input.columns()));       
    226     // If both training and test are unweighted: unweighted
    227     // calculations are used
    228     const MatrixLookup* test_unweighted =
    229       dynamic_cast<const MatrixLookup*>(&input);     
    230     if (test_unweighted && !data_.weighted()) {
     225    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
     226    utility::yat_assert<std::runtime_error>(test.rows()==centroids_->rows());
     227   
     228    prediction.clone(utility::matrix(centroids_->columns(), test.columns()));       
     229
     230    // unweighted test data
     231    if (const MatrixLookup* test_unweighted =
     232        dynamic_cast<const MatrixLookup*>(&test)) {
    231233      MatrixLookup unweighted_centroids(*centroids_);
    232       for(size_t j=0; j<input.columns();j++) {       
     234      for(size_t j=0; j<test.columns();j++) {       
    233235        DataLookup1D in(*test_unweighted,j,false);
    234236        for(size_t k=0; k<centroids_->columns();k++) {
     
    241243      }
    242244    }
    243     // if either training or test is weighted: weighted
    244     // calculations are used
    245     else {
    246       const MatrixLookupWeighted* test_weighted =
    247         dynamic_cast<const MatrixLookupWeighted*>(&input);     
    248       if(test_weighted) {
    249         MatrixLookupWeighted weighted_centroids(*centroids_);
    250         for(size_t j=0; j<input.columns();j++) {       
    251           DataLookupWeighted1D in(*test_weighted,j,false);
    252           for(size_t k=0; k<centroids_->columns();k++) {
    253             DataLookupWeighted1D centroid(weighted_centroids,k,false);
    254             utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
    255             prediction(k,j)=statistics::
    256               vector_distance(in.begin(),in.end(),centroid.begin(),
    257                               typename statistics::vector_distance_traits<Distance>::distance());
    258           }
    259         }
    260       }
    261       else if(data_.weighted() && test_unweighted) {
    262         std::string str =  "Error in NCC<Distance>::predict:";
    263         str += " predicting unweighted data when NCC";
    264         str += " is trained on weighted data is not yet supported";
    265         throw std::runtime_error(str);       
    266       }
    267       else {
    268         std::string str =
    269           "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
    270         throw std::runtime_error(str);
    271       }
     245    // weighted test data
     246    else if (const MatrixLookupWeighted* test_weighted =
     247            dynamic_cast<const MatrixLookupWeighted*>(&test)) {
     248      MatrixLookupWeighted weighted_centroids(*centroids_);
     249      for(size_t j=0; j<test.columns();j++) {       
     250        DataLookupWeighted1D in(*test_weighted,j,false);
     251        for(size_t k=0; k<centroids_->columns();k++) {
     252          DataLookupWeighted1D centroid(weighted_centroids,k,false);
     253          utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
     254          prediction(k,j)=statistics::
     255            vector_distance(in.begin(),in.end(),centroid.begin(),
     256                            typename statistics::vector_distance_traits<Distance>::distance());
     257        }
     258      }
     259    }
     260    else {
     261      std::string str =
     262        "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
     263      throw std::runtime_error(str);
    272264    }
    273265  }
Note: See TracChangeset for help on using the changeset viewer.