Changeset 1142 for trunk/yat/classifier


Ignore:
Timestamp:
Feb 25, 2008, 3:32:35 PM (13 years ago)
Author:
Markus Ringnér
Message:

Refs #335, fixed for NCC, working on KNN

Location:
trunk/yat/classifier
Files:
6 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/KNN.h

    r1124 r1142  
    8585
    8686    ///
    87     /// @brief sets the number of neighbors, k.
    88     ///
    89     void k(u_int);
     87    /// @brief sets the number of neighbors, k. If the number of
     88    /// training samples set is smaller than \a k_in, k is set to the number of
     89    /// training samples.
     90    ///
     91    void k(u_int k_in);
    9092
    9193
     
    144146    : SupervisedClassifier(target), data_(data),k_(3)
    145147  {
     148    utility::yat_assert<std::runtime_error>
     149      (data.columns()==target.size(),
     150       "KNN::KNN called with different sizes of target and data");
     151    // k has to be at most the number of training samples.
     152    if(data_.columns()>k_)
     153      k_=data_.columns();
    146154  }
    147155
     
    152160    : SupervisedClassifier(target), data_(data),k_(3)
    153161  {
     162    utility::yat_assert<std::runtime_error>
     163      (data.columns()==target.size(),
     164       "KNN::KNN called with different sizes of target and data");
     165    if(data_.columns()>k_)
     166      k_=data_.columns();
    154167  }
    155168 
     
    232245        (*distances)(i,j) = distance_(training1.begin(), training1.end(),
    233246                                      test1.begin());
    234         utility::yat_assert<std::runtime_error>(!std::isnan((*distances)(i,j)));
    235247      }
    236248    }
     
    255267  {
    256268    k_=k;
     269    if(k_>data_.columns())
     270      k_=data_.columns();
    257271  }
    258272
     
    295309                                                 utility::Matrix& prediction) const
    296310  {   
    297     utility::yat_assert<std::runtime_error>(data_.rows()==test.rows());
     311    utility::yat_assert<std::runtime_error>(data_.rows()==test.rows(),"KNN::predict different number of rows in training and test data");
    298312
    299313    utility::Matrix* distances=calculate_distances(test);
     
    308322    }
    309323    delete distances;
     324
     325    // classes for which there are no training samples should be set
     326    // to nan in the predictions
     327    for(size_t c=0;c<target_.nof_classes(); c++)
     328      if(!target_.size(c))
     329        for(size_t j=0;j<prediction.columns();j++)
     330          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
    310331  }
    311332
  • trunk/yat/classifier/KNN_ReciprocalDistance.h

    r1112 r1142  
    33
    44// $Id$
     5
     6/*
     7  Copyright (C) 2008 Markus Ringnér
     8
     9  This file is part of the yat library, http://trac.thep.lu.se/yat
     10
     11  The yat library is free software; you can redistribute it and/or
     12  modify it under the terms of the GNU General Public License as
     13  published by the Free Software Foundation; either version 2 of the
     14  License, or (at your option) any later version.
     15
     16  The yat library is distributed in the hope that it will be useful,
     17  but WITHOUT ANY WARRANTY; without even the implied warranty of
     18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
     19  General Public License for more details.
     20
     21  You should have received a copy of the GNU General Public License
     22  along with this program; if not, write to the Free Software
     23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
     24  02111-1307, USA.
     25*/
     26
    527
    628#include <vector>
  • trunk/yat/classifier/KNN_ReciprocalRank.h

    r1112 r1142  
    33
    44// $Id$
     5
     6/*
     7  Copyright (C) 2008 Markus Ringnér
     8
     9  This file is part of the yat library, http://trac.thep.lu.se/yat
     10
     11  The yat library is free software; you can redistribute it and/or
     12  modify it under the terms of the GNU General Public License as
     13  published by the Free Software Foundation; either version 2 of the
     14  License, or (at your option) any later version.
     15
     16  The yat library is distributed in the hope that it will be useful,
     17  but WITHOUT ANY WARRANTY; without even the implied warranty of
     18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
     19  General Public License for more details.
     20
     21  You should have received a copy of the GNU General Public License
     22  along with this program; if not, write to the Free Software
     23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
     24  02111-1307, USA.
     25*/
    526
    627
  • trunk/yat/classifier/KNN_Uniform.cc

    r1112 r1142  
    1818                               utility::VectorMutable& prediction) const
    1919  {           
    20     std::vector<size_t>::size_type k=k_sorted.size();
    21     for(size_t j=0;j<k;j++)
     20    for(size_t j=0;j<k_sorted.size();j++)
    2221      prediction(target(k_sorted[j]))+=1.0;           
    2322  }
  • trunk/yat/classifier/KNN_Uniform.h

    r1112 r1142  
    33
    44// $Id$
     5
     6/*
     7  Copyright (C) 2008 Markus Ringnér
     8
     9  This file is part of the yat library, http://trac.thep.lu.se/yat
     10
     11  The yat library is free software; you can redistribute it and/or
     12  modify it under the terms of the GNU General Public License as
     13  published by the Free Software Foundation; either version 2 of the
     14  License, or (at your option) any later version.
     15
     16  The yat library is distributed in the hope that it will be useful,
     17  but WITHOUT ANY WARRANTY; without even the implied warranty of
     18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
     19  General Public License for more details.
     20
     21  You should have received a copy of the GNU General Public License
     22  along with this program; if not, write to the Free Software
     23  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
     24  02111-1307, USA.
     25*/
     26
    527
    628#include <vector>
  • trunk/yat/classifier/NCC.h

    r1124 r1142  
    198198          if(class_averager[c].sum_w()==0) {
    199199            centroids_nan_=true;
    200             (*centroids_)(i,c) = std::numeric_limits<double>::quiet_NaN();
    201200          }
    202           else {
    203             (*centroids_)(i,c) = class_averager[c].mean();
    204           }
     201          (*centroids_)(i,c) = class_averager[c].mean();
    205202        }
    206203      }
     
    262259                                         utility::Matrix& prediction) const
    263260  {
    264     MatrixLookup unweighted_centroids(*centroids_);
    265     for(size_t j=0; j<test.columns();j++) {       
    266       DataLookup1D in(test,j,false);
    267       for(size_t k=0; k<centroids_->columns();k++) {
    268         DataLookup1D centroid(unweighted_centroids,k,false);           
    269         utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
    270         prediction(k,j) = distance_(in.begin(), in.end(), centroid.begin());
    271       }
    272     }
    273   }
    274 
     261    MatrixLookup centroids(*centroids_);
     262    for(size_t j=0; j<test.columns();j++)
     263      for(size_t k=0; k<centroids_->columns();k++)
     264        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j),
     265                                    centroids.begin_column(k));
     266  }
     267 
    275268  template <typename Distance>
    276269  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test,
     
    278271  {
    279272    MatrixLookupWeighted weighted_centroids(*centroids_);
    280     for(size_t j=0; j<test.columns();j++) {       
    281       DataLookupWeighted1D in(test,j,false);
    282       for(size_t k=0; k<centroids_->columns();k++) {
    283         DataLookupWeighted1D centroid(weighted_centroids,k,false);
    284         utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
    285         prediction(k,j) = distance_(in.begin(), in.end(), centroid.begin());
    286       }
    287     }
     273    for(size_t j=0; j<test.columns();j++)
     274      for(size_t k=0; k<centroids_->columns();k++)
     275        prediction(k,j) = distance_(test.begin_column(j), test.end_column(j),
     276                                    weighted_centroids.begin_column(k));
    288277  }
    289278
Note: See TracChangeset for help on using the changeset viewer.