Ignore:
Timestamp:
Feb 26, 2008, 2:25:19 PM (14 years ago)
Author:
Markus Ringnér
Message:

Refs #318

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/KNN.h

    r1156 r1157  
    5757  public:
    5858    ///
    59     /// Constructor taking the training data and the target   
    60     /// as input.
    61     ///
    62     KNN(const MatrixLookup&, const Target&);
    63 
    64 
    65     ///
    66     /// Constructor taking the training data with weights and the
    67     /// target as input.
    68     ///
    69     KNN(const MatrixLookupWeighted&, const Target&);
    70 
     59    /// @brief Constructor
     60    ///
     61    KNN(void);
     62
     63
     64    ///
     65    /// @brief Destructor
     66    ///
    7167    virtual ~KNN();
    7268   
    73     //
    74     // @return the training data
    75     //
    76     const DataLookup2D& data(void) const;
    77 
    7869
    7970    ///
     
    8576
    8677    ///
    87     /// @brief sets the number of neighbors, k. If the number of
    88     /// training samples set is smaller than \a k_in, k is set to the number of
    89     /// training samples.
     78    /// @brief sets the number of neighbors, k.
    9079    ///
    9180    void k(u_int k_in);
    9281
    9382
    94     KNN<Distance,NeighborWeighting>* make_classifier(const DataLookup2D&,
    95                          const Target&) const;
    96    
    97     ///
    98     /// Train the classifier using the training data.
    99     /// This function does nothing but is required by the interface.
    100     ///
    101     void train();
     83    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
     84   
     85    ///
     86    /// Train the classifier using training data and target.
     87    ///
     88    /// If the number of training samples set is smaller than \a k_in,
     89    /// k is set to the number of training samples.
     90    ///
     91    void train(const MatrixLookup&, const Target&);
     92
     93    ///
     94    /// Train the classifier using weighted training data and target.
     95    ///
     96    void train(const MatrixLookupWeighted&, const Target&);
    10297
    10398   
     
    114109    // data_ has to be of type DataLookup2D to accommodate both
    115110    // MatrixLookup and MatrixLookupWeighted
    116     const DataLookup2D& data_;
     111    const DataLookup2D* data_;
     112    const Target* target_;
    117113
    118114    // The number of neighbors
     
    143139 
    144140  template <typename Distance, typename NeighborWeighting>
    145   KNN<Distance, NeighborWeighting>::KNN(const MatrixLookup& data, const Target& target)
    146     : SupervisedClassifier(target), data_(data),k_(3)
    147   {
    148     utility::yat_assert<std::runtime_error>
    149       (data.columns()==target.size(),
    150        "KNN::KNN called with different sizes of target and data");
    151     // k has to be at most the number of training samples.
    152     if(data_.columns()>k_)
    153       k_=data_.columns();
    154   }
    155 
    156 
    157   template <typename Distance, typename NeighborWeighting>
    158   KNN<Distance, NeighborWeighting>::KNN
    159   (const MatrixLookupWeighted& data, const Target& target)
    160     : SupervisedClassifier(target), data_(data),k_(3)
    161   {
    162     utility::yat_assert<std::runtime_error>
    163       (data.columns()==target.size(),
    164        "KNN::KNN called with different sizes of target and data");
    165     if(data_.columns()>k_)
    166       k_=data_.columns();
    167   }
     141  KNN<Distance, NeighborWeighting>::KNN()
     142    : SupervisedClassifier(),data_(0),target_(0),k_(3)
     143  {
     144  }
     145
    168146 
    169147  template <typename Distance, typename NeighborWeighting>
     
    178156    // matrix with training samples as rows and test samples as columns
    179157    utility::Matrix* distances =
    180       new utility::Matrix(data_.columns(),test.columns());
     158      new utility::Matrix(data_->columns(),test.columns());
    181159   
    182160   
     
    186164      // unweighted training data
    187165      if(const MatrixLookup* training_unweighted =
    188          dynamic_cast<const MatrixLookup*>(&data_))
     166         dynamic_cast<const MatrixLookup*>(data_))
    189167        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
    190168      // weighted training data
    191169      else if(const MatrixLookupWeighted* training_weighted =
    192               dynamic_cast<const MatrixLookupWeighted*>(&data_))
     170              dynamic_cast<const MatrixLookupWeighted*>(data_))
    193171        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
    194172                           distances);             
     
    200178      // unweighted training data
    201179      if(const MatrixLookup* training_unweighted =
    202          dynamic_cast<const MatrixLookup*>(&data_)) {
     180         dynamic_cast<const MatrixLookup*>(data_)) {
    203181        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
    204182                           *test_weighted,distances);
     
    206184      // weighted training data
    207185      else if(const MatrixLookupWeighted* training_weighted =
    208               dynamic_cast<const MatrixLookupWeighted*>(&data_))
     186              dynamic_cast<const MatrixLookupWeighted*>(data_))
    209187        calculate_weighted(*training_weighted,*test_weighted,distances);             
    210188      // Training data cannot be of incorrect type
     
    252230    }
    253231  }
    254 
    255  
    256   template <typename Distance, typename NeighborWeighting>
    257   const DataLookup2D& KNN<Distance, NeighborWeighting>::data(void) const
    258   {
    259     return data_;
    260   }
    261232 
    262233 
     
    271242  {
    272243    k_=k;
    273     if(k_>data_.columns())
    274       k_=data_.columns();
    275244  }
    276245
     
    278247  template <typename Distance, typename NeighborWeighting>
    279248  KNN<Distance, NeighborWeighting>*
    280   KNN<Distance, NeighborWeighting>::make_classifier(const DataLookup2D& data,
    281                                                     const Target& target) const
     249  KNN<Distance, NeighborWeighting>::make_classifier() const
    282250  {     
    283     KNN* knn=0;
    284     try {
    285       if(data.weighted()) {
    286         knn=new KNN<Distance, NeighborWeighting>
    287           (dynamic_cast<const MatrixLookupWeighted&>(data),target);
    288       } 
    289       else {
    290         knn=new KNN<Distance, NeighborWeighting>
    291           (dynamic_cast<const MatrixLookup&>(data),target);
    292       }
    293       knn->k(this->k());
    294     }
    295     catch (std::bad_cast) {
    296       std::string str = "Error in KNN<Distance, NeighborWeighting>";
    297       str += "::make_classifier: DataLookup2D of unexpected class.";
    298       throw std::runtime_error(str);
    299     }
     251    KNN* knn=new KNN<Distance, NeighborWeighting>();
     252    knn->k(this->k());
    300253    return knn;
    301254  }
     
    303256 
    304257  template <typename Distance, typename NeighborWeighting>
    305   void KNN<Distance, NeighborWeighting>::train()
     258  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
     259                                               const Target& target)
    306260  {   
     261    utility::yat_assert<std::runtime_error>
     262      (data.columns()==target.size(),
     263       "KNN::train called with different sizes of target and data");
     264    // k has to be at most the number of training samples.
     265    if(data.columns()<k_)
     266      k_=data.columns();
     267    data_=&data;
     268    target_=&target;
     269    trained_=true;
     270  }
     271
     272  template <typename Distance, typename NeighborWeighting>
     273  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
     274                                               const Target& target)
     275  {   
     276    utility::yat_assert<std::runtime_error>
     277      (data.columns()==target.size(),
     278       "KNN::train called with different sizes of target and data");
     279    // k has to be at most the number of training samples.
     280    if(data.columns()<k_)
     281      k_=data.columns();
     282    data_=&data;
     283    target_=&target;
    307284    trained_=true;
    308285  }
     
    313290                                                 utility::Matrix& prediction) const
    314291  {   
    315     utility::yat_assert<std::runtime_error>(data_.rows()==test.rows(),"KNN::predict different number of rows in training and test data");
     292    utility::yat_assert<std::runtime_error>(data_->rows()==test.rows(),"KNN::predict different number of rows in training and test data");
    316293
    317294    utility::Matrix* distances=calculate_distances(test);
    318295   
    319     prediction.resize(target_.nof_classes(),test.columns(),0.0);
     296    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    320297    for(size_t sample=0;sample<distances->columns();sample++) {
    321298      std::vector<size_t> k_index;
     
    323300      utility::sort_smallest_index(k_index,k_,dist);
    324301      utility::VectorView pred=prediction.column_view(sample);
    325       weighting_(dist,k_index,target_,pred);
     302      weighting_(dist,k_index,*target_,pred);
    326303    }
    327304    delete distances;
     
    329306    // classes for which there are no training samples should be set
    330307    // to nan in the predictions
    331     for(size_t c=0;c<target_.nof_classes(); c++)
    332       if(!target_.size(c))
     308    for(size_t c=0;c<target_->nof_classes(); c++)
     309      if(!target_->size(c))
    333310        for(size_t j=0;j<prediction.columns();j++)
    334311          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
Note: See TracChangeset for help on using the changeset viewer.