Ignore:
Timestamp:
Feb 25, 2008, 3:32:35 PM (14 years ago)
Author:
Markus Ringnér
Message:

Refs #335, fixed for NCC, working on KNN

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/KNN.h

    r1124 r1142  
    8585
    8686    ///
    87     /// @brief sets the number of neighbors, k.
    88     ///
    89     void k(u_int);
     87    /// @brief sets the number of neighbors, k. If the number of
     88    /// training samples set is smaller than \a k_in, k is set to the number of
     89    /// training samples.
     90    ///
     91    void k(u_int k_in);
    9092
    9193
     
    144146    : SupervisedClassifier(target), data_(data),k_(3)
    145147  {
     148    utility::yat_assert<std::runtime_error>
     149      (data.columns()==target.size(),
     150       "KNN::KNN called with different sizes of target and data");
     151    // k has to be at most the number of training samples.
     152    if(data_.columns()>k_)
     153      k_=data_.columns();
    146154  }
    147155
     
    152160    : SupervisedClassifier(target), data_(data),k_(3)
    153161  {
     162    utility::yat_assert<std::runtime_error>
     163      (data.columns()==target.size(),
     164       "KNN::KNN called with different sizes of target and data");
     165    if(data_.columns()>k_)
     166      k_=data_.columns();
    154167  }
    155168 
     
    232245        (*distances)(i,j) = distance_(training1.begin(), training1.end(),
    233246                                      test1.begin());
    234         utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    235247      }
    236248    }
     
    255267  {
    256268    k_=k;
     269    if(k_>data_.columns())
     270      k_=data_.columns();
    257271  }
    258272
     
    295309                                                 utility::Matrix& prediction) const
    296310  {   
    297     utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
     311    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows(),"KNN::predict different number of rows in training and test data");
    298312
    299313    utility::Matrix* distances=calculate_distances(test);
     
    308322    }
    309323    delete distances;
     324
     325    // classes for which there are no training samples should be set
     326    // to nan in the predictions
     327    for(size_t c=0;c<target_.nof_classes(); c++)
     328      if(!target_.size(c))
     329        for(size_t j=0;j<prediction.columns();j++)
     330          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
    310331  }
    311332
Note: See TracChangeset for help on using the changeset viewer.