Changeset 1033 for trunk/yat/classifier


Ignore:
Timestamp:
Feb 5, 2008, 12:12:12 PM (16 years ago)
Author:
Markus Ringnér
Message:

Working on #259

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NCC.h

    r1031 r1033  
    107107  private:
    108108
     109    void predict_unweighted(const MatrixLookup&, utility::matrix&) const;
     110    void predict_weighted(const MatrixLookupWeighted&, utility::matrix&) const;   
     111
    109112    utility::matrix* centroids_;
     113    bool centroids_nan_;
    110114
    111115    // data_ has to be of type DataLookup2D to accomodate both
    112116    // MatrixLookup and MatrixLookupWeighted
    113117    const DataLookup2D& data_;
    114     bool centroids_nan_;
    115118  };
    116119
     
    125128  template <typename Distance>
    126129  NCC<Distance>::NCC(const MatrixLookup& data, const Target& target)
    127     : SupervisedClassifier(target), centroids_(0), data_(data), centroids_nan_(false)
     130    : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data)
    128131  {
    129132  }
     
    131134  template <typename Distance>
    132135  NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
    133     : SupervisedClassifier(target), centroids_(0), data_(data), centroids_nan_(false) 
     136    : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data)
    134137  {
    135138  }
     
    231234    prediction.clone(utility::matrix(centroids_->columns(), test.columns()));       
    232235
    233     // unweighted test data and no nan's in centroids
    234     // Markus: Should test centroid_nan_ here!!!
     236    // unweighted test data
    235237    if (const MatrixLookup* test_unweighted =
    236238        dynamic_cast<const MatrixLookup*>(&test)) {
    237       MatrixLookup unweighted_centroids(*centroids_);
    238       for(size_t j=0; j<test.columns();j++) {       
    239         DataLookup1D in(*test_unweighted,j,false);
    240         for(size_t k=0; k<centroids_->columns();k++) {
    241           DataLookup1D centroid(unweighted_centroids,k,false);           
    242           utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
    243           prediction(k,j)=statistics::
    244             distance(in.begin(),in.end(),centroid.begin(),
    245                             typename statistics::distance_traits<Distance>::distance());
    246         }
    247       }
    248     }
    249     // weighted test data
     239      // If weighted training data resulting in NaN in centroids: weighted calculations
     240      if(centroids_nan_) {
     241        //        predict_weighted(MatrixLookupWeighted(*test_unweighted),prediction);
     242        std::string str =
     243        "Error in NCC<Distance>::predict: weighted training unweighted test not implemented yet";
     244      throw std::runtime_error(str);
     245      }
     246      // If unweighted training data: unweighted calculations
     247      else {
     248        predict_unweighted(*test_unweighted,prediction);
     249      }
     250    }
     251    // weighted test data: weighted calculations
    250252    else if (const MatrixLookupWeighted* test_weighted =
    251             dynamic_cast<const MatrixLookupWeighted*>(&test)) {
    252       MatrixLookupWeighted weighted_centroids(*centroids_);
    253       for(size_t j=0; j<test.columns();j++) {       
    254         DataLookupWeighted1D in(*test_weighted,j,false);
    255         for(size_t k=0; k<centroids_->columns();k++) {
    256           DataLookupWeighted1D centroid(weighted_centroids,k,false);
    257           utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
    258           prediction(k,j)=statistics::
    259             distance(in.begin(),in.end(),centroid.begin(),
    260                             typename statistics::distance_traits<Distance>::distance());
    261         }
    262       }
     253             dynamic_cast<const MatrixLookupWeighted*>(&test)) {
     254      predict_weighted(*test_weighted,prediction);
    263255    }
    264256    else {
     
    268260    }
    269261  }
     262 
     263  template <typename Distance>
     264  void NCC<Distance>::predict_unweighted(const MatrixLookup& test,
     265                                         utility::matrix& prediction) const
     266  {
     267    MatrixLookup unweighted_centroids(*centroids_);
     268    for(size_t j=0; j<test.columns();j++) {       
     269      DataLookup1D in(test,j,false);
     270      for(size_t k=0; k<centroids_->columns();k++) {
     271        DataLookup1D centroid(unweighted_centroids,k,false);           
     272        utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
     273        prediction(k,j)=statistics::
     274          distance(in.begin(),in.end(),centroid.begin(),
     275                   typename statistics::distance_traits<Distance>::distance());
     276      }
     277    }
     278  }
     279
     280  template <typename Distance>
     281  void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test,
     282                                          utility::matrix& prediction) const
     283  {
     284    MatrixLookupWeighted weighted_centroids(*centroids_);
     285    for(size_t j=0; j<test.columns();j++) {       
     286      DataLookupWeighted1D in(test,j,false);
     287      for(size_t k=0; k<centroids_->columns();k++) {
     288        DataLookupWeighted1D centroid(weighted_centroids,k,false);
     289        utility::yat_assert<std::runtime_error>(in.size()==centroid.size());
     290        prediction(k,j)=statistics::
     291          distance(in.begin(),in.end(),centroid.begin(),
     292                   typename statistics::distance_traits<Distance>::distance());
     293      }
     294    }
     295  }
     296
    270297     
    271298}}} // of namespace classifier, yat, and theplu
Note: See TracChangeset for help on using the changeset viewer.