Changeset 1102 for trunk/yat/classifier


Ignore:
Timestamp:
Feb 18, 2008, 6:00:34 AM (13 years ago)
Author:
Peter
Message:

fixes #314

Location:
trunk/yat/classifier
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/SVM.cc

    r1100 r1102  
    2525
    2626#include "SVM.h"
    27 #include "DataLookup2D.h"
     27#include "KernelLookup.h"
    2828#include "Target.h"
    2929#include "yat/random/random.h"
     
    8989
    9090
     91  /*
    9192  const DataLookup2D& SVM::data(void) const
    9293  {
    9394    return *kernel_;
    9495  }
     96  */
    9597
    9698
     
    127129  }
    128130
    129   void SVM::predict(const DataLookup2D& input, utility::matrix& prediction) const
    130   {
    131     try {
    132       const KernelLookup& input_kernel =dynamic_cast<const KernelLookup&>(input);
    133       assert(input.rows()==alpha_.size());
    134       prediction.resize(2,input.columns(),0);
    135       for (size_t i = 0; i<input.columns(); i++){
    136         for (size_t j = 0; j<input.rows(); j++){
    137           prediction(0,i) += target(j)*alpha_(j)*input_kernel(j,i);
    138           assert(target(j));
    139         }
    140         prediction(0,i) = margin_ * (prediction(0,i) + bias_);
    141       }
    142      
    143       for (size_t i = 0; i<prediction.columns(); i++)
    144         prediction(1,i) = -prediction(0,i);
    145     }
    146     catch (std::bad_cast) {
    147       std::string str =
    148         "Error in SVM::predict: DataLookup2D of unexpected class.";
    149       throw std::runtime_error(str);
     131  void SVM::predict(const KernelLookup& input, utility::matrix& prediction) const
     132  {
     133    assert(input.rows()==alpha_.size());
     134    prediction.resize(2,input.columns(),0);
     135    for (size_t i = 0; i<input.columns(); i++){
     136      for (size_t j = 0; j<input.rows(); j++){
     137        prediction(0,i) += target(j)*alpha_(j)*input(j,i);
     138        assert(target(j));
     139      }
     140      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
    150141    }
    151142   
     143    for (size_t i = 0; i<prediction.columns(); i++)
     144      prediction(1,i) = -prediction(0,i);
    152145  }
    153146
  • trunk/yat/classifier/SVM.h

    r1100 r1102  
    2727*/
    2828
    29 #include "KernelLookup.h"
    3029#include "SVindex.h"
    3130#include "Target.h"
     
    3938namespace classifier { 
    4039
    41   class DataLookup2D;
    42   class Target;
     40  class DataLookup1D;
     41  class DataLookupWeighted1D;
     42  class KernelLookup;
    4343
    4444  ///
     
    6666    virtual ~SVM();
    6767
    68     const DataLookup2D& data(void) const;
     68    //const DataLookup2D& data(void) const;
    6969
    7070    ///
     
    128128       for training.
    129129    */
    130     void predict(const DataLookup2D& input, utility::matrix& predict) const;
     130    void predict(const KernelLookup& input, utility::matrix& predict) const;
    131131
    132132    ///
Note: See TracChangeset for help on using the changeset viewer.