Changeset 931 for trunk/yat/classifier


Ignore:
Timestamp:
Oct 5, 2007, 5:42:25 PM (16 years ago)
Author:
Markus Ringnér
Message:

Working on ticket:259. Removed old Distance see ticket:250

Location:
trunk/yat/classifier
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/KNN.h

    r916 r931  
    168168  {     
    169169    KNN* knn=0;
    170     if(data.weighted()) {
    171       knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
    172                             target);
    173     }
    174     knn->k(this->k());
     170    try {
     171      if(data.weighted()) {
     172        knn=new KNN<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
     173                              target);
     174      } 
     175      knn->k(this->k());
     176    }
     177    catch (std::bad_cast) {
     178      std::string str = "Error in KNN<Distance>::make_classifier: DataLookup2D of unexpected class.";
     179      throw std::runtime_error(str);
     180    }
    175181    return knn;
    176182  }
  • trunk/yat/classifier/NCC.h

    r930 r931  
    3535#include "Target.h"
    3636
     37#include "yat/statistics/Averager.h"
     38#include "yat/statistics/AveragerWeighted.h"
    3739#include "yat/statistics/vector_distance.h"
    3840
     
    158160  {     
    159161    NCC* ncc=0;
    160     if(data.weighted()) {
    161       ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookupWeighted*>(&data),
    162                   target);
    163     }
    164     else {
    165       ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookup*>(&data),
    166                   target);
    167     }
    168     ncc->centroids_=0;
     162    try {
     163      if(data.weighted()) {
     164        ncc=new NCC<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
     165                              target);
     166      }
     167      else {
     168        ncc=new NCC<Distance>(dynamic_cast<const MatrixLookup&>(data),
     169                              target);
     170      }
     171      ncc->centroids_=0;
     172    }
     173    catch (std::bad_cast) {
     174      std::string str = "Error in NCC<Distance>::make_classifier: DataLookup2D of unexpected class.";
     175      throw std::runtime_error(str);
     176    }
    169177    return ncc;
    170178  }
     
    177185      delete centroids_;
    178186    centroids_= new utility::matrix(data_.rows(), target_.nof_classes());
    179     utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
    180     const MatrixLookupWeighted* weighted_data =
    181       dynamic_cast<const MatrixLookupWeighted*>(&data_);
    182     bool weighted = weighted_data;
    183 
    184     for(size_t i=0; i<data_.rows(); i++) {
    185       for(size_t j=0; j<data_.columns(); j++) {
    186         (*centroids_)(i,target_(j)) += data_(i,j);
    187         if (weighted)
    188           nof_in_class(i,target_(j))+= weighted_data->weight(i,j);
    189         else
    190           nof_in_class(i,target_(j))+=1.0;
    191       }
    192     }   
    193     centroids_->div(nof_in_class);
     187    // data_ is a MatrixLookup or a MatrixLookupWeighted
     188    if(data_.weighted()) {
     189      const MatrixLookupWeighted* weighted_data =
     190        dynamic_cast<const MatrixLookupWeighted*>(&data_);     
     191      for(size_t i=0; i<data_.rows(); i++) {
     192        std::vector<statistics::AveragerWeighted> class_averager;
     193        class_averager.resize(target_.nof_classes());
     194        for(size_t j=0; j<data_.columns(); j++) {
     195          class_averager[target_(j)].add((*weighted_data)(i,j),
     196                                         weighted_data->weight(i,j));
     197        }
     198        for(size_t c=0;c<target_.nof_classes();c++) {
     199          (*centroids_)(i,c) = class_averager[c].mean();
     200        }
     201      }
     202    }
     203    else {
     204      const MatrixLookup* unweighted_data =
     205        dynamic_cast<const MatrixLookup*>(&data_);     
     206      for(size_t i=0; i<data_.rows(); i++) {
     207        std::vector<statistics::Averager> class_averager;
     208        class_averager.resize(target_.nof_classes());
     209        for(size_t j=0; j<data_.columns(); j++) {
     210          class_averager[target_(j)].add((*unweighted_data)(i,j));
     211        }
     212        for(size_t c=0;c<target_.nof_classes();c++) {
     213          (*centroids_)(i,c) = class_averager[c].mean();
     214        }
     215      }
     216    }
    194217    trained_=true;
    195218    return trained_;
     
    200223                              utility::matrix& prediction) const
    201224  {   
    202     prediction.clone(utility::matrix(centroids_->columns(), input.columns()));   
    203    
    204     // Weighted case
    205     const MatrixLookupWeighted* testdata =
    206       dynamic_cast<const MatrixLookupWeighted*>(&input);     
    207     if (testdata) {
    208       MatrixLookupWeighted weighted_centroids(*centroids_);
     225    prediction.clone(utility::matrix(centroids_->columns(), input.columns()));       
     226    // If both training and test are unweighted: unweighted
     227    // calculations are used
     228    const MatrixLookup* test_unweighted =
     229      dynamic_cast<const MatrixLookup*>(&input);     
     230    if (test_unweighted && !data_.weighted()) {
     231      MatrixLookup unweighted_centroids(*centroids_);
    209232      for(size_t j=0; j<input.columns();j++) {       
    210         DataLookupWeighted1D in(*testdata,j,false);
     233        DataLookup1D in(*test_unweighted,j,false);
    211234        for(size_t k=0; k<centroids_->columns();k++) {
    212           DataLookupWeighted1D centroid(weighted_centroids,k,false);
     235          DataLookup1D centroid(unweighted_centroids,k,false);           
    213236          yat_assert(in.size()==centroid.size());
    214237          prediction(k,j)=statistics::
     
    218241      }
    219242    }
    220     // Non-weighted case
    221     else {
    222       const MatrixLookup* testdata =
    223         dynamic_cast<const MatrixLookup*>(&input);     
    224       if (testdata) {
    225         MatrixLookup unweighted_centroids(*centroids_);
     243    // if either training or test is weighted: weighted
     244    // calculations are used
     245    else {
     246      const MatrixLookupWeighted* test_weighted =
     247        dynamic_cast<const MatrixLookupWeighted*>(&input);     
     248      MatrixLookupWeighted weighted_centroids(*centroids_);
     249      if(test_weighted) {
    226250        for(size_t j=0; j<input.columns();j++) {       
    227           DataLookup1D in(*testdata,j,false);
     251          DataLookupWeighted1D in(*test_weighted,j,false);
    228252          for(size_t k=0; k<centroids_->columns();k++) {
    229             DataLookup1D centroid(unweighted_centroids,k,false);           
     253            DataLookupWeighted1D centroid(weighted_centroids,k,false);
    230254            yat_assert(in.size()==centroid.size());
    231255            prediction(k,j)=statistics::
     
    234258          }
    235259        }
    236       }     
     260      }
     261      else if(data_.weighted() && test_unweighted) {
     262        //        MatrixLookupWeighted test2weighted(*test_unweighted);
     263        // Need to convert MatrixLookup to MatrixLookupWeighted here
     264        // and use it in the code below
     265        for(size_t j=0; j<input.columns();j++) {       
     266          DataLookupWeighted1D in(*test_weighted,j,false);
     267          for(size_t k=0; k<centroids_->columns();k++) {
     268            DataLookupWeighted1D centroid(weighted_centroids,k,false);
     269            yat_assert(in.size()==centroid.size());
     270            prediction(k,j)=statistics::
     271              vector_distance(in.begin(),in.end(),centroid.begin(),
     272                              typename statistics::vector_distance_traits<Distance>::distance());
     273          }
     274        }
     275      }
    237276      else {
    238277        std::string str;
Note: See TracChangeset for help on using the changeset viewer.