Changeset 1100


Ignore:
Timestamp:
Feb 18, 2008, 5:37:50 AM (16 years ago)
Author:
Peter
Message:

fixes #313 - SVM constructor is void and passing kernel and target in train function instead

Location:
trunk
Files:
8 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/ensemble_test.cc

    r1087 r1100  
    3636#include "yat/classifier/SVM.h"
    3737#include "yat/statistics/AUC.h"
     38#include "yat/statistics/EuclideanDistance.h"
    3839
    3940#include <cassert>
     
    7778  assert(data.columns()==target.size());
    7879
     80  {
     81    *error << "create ensemble of ncc" << std::endl;
     82    classifier::NCC<statistics::EuclideanDistance> ncc(data, target);
     83    classifier::CrossValidationSampler sampler(target,3,3);
     84    classifier::SubsetGenerator<classifier::MatrixLookup> subdata(sampler,data);
     85    classifier::EnsembleBuilder<classifier::SupervisedClassifier,
     86      classifier::MatrixLookup> ensemble(ncc, data, sampler);
     87    *error << "build ensemble" << std::endl;
     88    ensemble.build();
     89  }
     90
    7991  *error << "create KernelLookup" << std::endl;
    8092  classifier::KernelLookup kernel_lookup(kernel);
    8193  *error << "create svm" << std::endl;
    82   classifier::SVM svm(kernel_lookup, target);
     94  classifier::SVM svm();
    8395  *error << "create Subsets" << std::endl;
    8496  classifier::CrossValidationSampler sampler(target,3,3);
    8597  classifier::SubsetGenerator<classifier::KernelLookup> cv(sampler,
    8698                                                           kernel_lookup);
     99  /* Peter, temporarily removed because of redesign
    87100  *error << "create ensemble" << std::endl;
    88101  classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup>
     
    96109  statistics::AUC roc;
    97110  *error << roc.score(target,out) << std::endl;
    98 
     111  */
    99112  delete kf;
    100113
  • trunk/test/svm_test.cc

    r1042 r1100  
    8282  *error << "testing with linear kernel" << std::endl;
    8383  assert(kv2.rows()==target2.size());
    84   classifier::SVM classifier2(kv2, target2);
     84  classifier::SVM classifier2;
    8585  *error << "training...";
    86   classifier2.train();
     86  classifier2.train(kv2, target2);
    8787  *error << " done!" << std::endl;
    8888
     
    130130
    131131  classifier::KernelLookup kv(kernel);
    132   theplu::yat::classifier::SVM svm(kv, target);
    133   svm.train();
     132  theplu::yat::classifier::SVM svm;
     133  svm.train(kv, target);
    134134
    135135  theplu::yat::utility::vector alpha = svm.alpha();
  • trunk/yat/classifier/SVM.cc

    r1098 r1100  
    3939#include <sstream>
    4040#include <stdexcept>
     41#include <string>
    4142#include <utility>
    4243#include <vector>
     
    4647namespace classifier { 
    4748
    48   SVM::SVM(const KernelLookup& kernel, const Target& target)
    49     : alpha_(target.size(),0),
    50       bias_(0),
     49  SVM::SVM(void)
     50    : bias_(0),
    5151      C_inverse_(0),
    52       kernel_(&kernel),
     52      kernel_(NULL),
    5353      margin_(0),
    5454      max_epochs_(100000),
    55       output_(target.size(),0),
    56       owner_(false),
    57       sample_(target.size()),
    58       target_(target),
    59       trained_(false),
    60       tolerance_(0.00000001)
    61   {
    62 #ifndef NDEBUG
    63     assert(kernel.columns()==kernel.rows());
    64     assert(kernel.columns()==alpha_.size());
    65     for (size_t i=0; i<alpha_.size(); i++)
    66       for (size_t j=0; j<alpha_.size(); j++)
    67         assert(kernel(i,j)==kernel(j,i));
    68     for (size_t i=0; i<alpha_.size(); i++)
    69       for (size_t j=0; j<alpha_.size(); j++)
    70         assert((*kernel_)(i,j)==(*kernel_)(j,i));
    71     for (size_t i = 0; i<kernel_->rows(); i++)
    72       for (size_t j = 0; j<kernel_->columns(); j++)
    73         if (std::isnan((*kernel_)(i,j)))
    74           std::cerr << "SVM: Found nan in kernel: " << i << " "
    75                     << j << std::endl;
    76 #endif       
    77   }
     55      tolerance_(0.00000001),
     56      trained_(false)
     57  {
     58  }
     59
    7860
    7961  SVM::~SVM()
    8062  {
    81     if (owner_)
     63    if (kernel_)
    8264      delete kernel_;
    8365  }
    8466
     67
    8568  const utility::vector& SVM::alpha(void) const
    8669  {
     
    8871  }
    8972
     73
    9074  double SVM::C(void) const
    9175  {
    9276    return 1.0/C_inverse_;
    9377  }
     78
    9479
    9580  void SVM::calculate_margin(void)
     
    10388  }
    10489
     90
    10591  const DataLookup2D& SVM::data(void) const
    10692  {
     
    11197  double SVM::kernel_mod(const size_t i, const size_t j) const
    11298  {
     99    assert(kernel_);
     100    assert(i<kernel_->rows());
     101    assert(i<kernel_->columns());
    113102    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
    114103  }
    115104
    116   SVM* SVM::make_classifier(const DataLookup2D& data,
    117                             const Target& target) const
    118   {
    119     SVM* sc=0;
    120     try {
    121       const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
    122       assert(data.rows()==data.columns());
    123       assert(data.columns()==target.size());
    124       sc = new SVM(kernel,target);
    125       //Copy those variables possible to modify from outside
    126       sc->set_C(this->C());
    127       sc->max_epochs(max_epochs());
    128     }
    129     catch (std::bad_cast) {
    130       std::string str =
    131         "Error in SVM::make_classifier: DataLookup2D of unexpected class.";
    132       throw std::runtime_error(str);
    133     }
    134  
    135     return sc;
    136   }
     105
     106  SVM* SVM::make_classifier(void) const
     107  {
     108    return new SVM(*this);
     109  }
     110
    137111
    138112  long int SVM::max_epochs(void) const
     
    196170  }
    197171
    198   void SVM::reset(void)
    199   {
    200     trained_=false;
    201     alpha_ = utility::vector(target_.size(), 0);
    202   }
    203 
    204172  int SVM::target(size_t i) const
    205173  {
     174    assert(i<target_.size());
    206175    return target_.binary(i) ? 1 : -1;
    207176  }
    208177
    209   void SVM::train(void)
    210   {
     178  void SVM::train(const KernelLookup& kernel, const Target& targ)
     179  {
     180    if (kernel_)
     181      delete kernel_;
     182    kernel_ = new KernelLookup(kernel);
     183    target_ = targ;
     184   
     185    alpha_ = utility::vector(targ.size(), 0.0);
     186    output_ = utility::vector(targ.size(), 0.0);
    211187    // initializing variables for optimization
    212188    assert(target_.size()==kernel_->rows());
  • trunk/yat/classifier/SVM.h

    r1087 r1100  
    5757  public:
    5858    ///
    59     /// Constructor taking the kernel and the target vector as
    60     /// input.
    61     ///
    62     /// @note if the @a target or @a kernel
    63     /// is destroyed the behaviour is undefined.
    64     ///
    65     SVM(const KernelLookup& kernel, const Target& target);
     59    /// \brief Constructor
     60    ///
     61    SVM(void);
    6662
    6763    ///
     
    7369
    7470    ///
    75     /// If DataLookup2D is not a KernelLookup a bad_cast exception is thrown.
    76     ///
    77     SVM* make_classifier(const DataLookup2D&, const Target&) const;
     71    ///
     72    ///
     73    SVM* make_classifier(void) const;
    7874
    7975    ///
     
    145141    ///
    146142    double predict(const DataLookupWeighted1D& input) const;
    147 
    148     ///
    149     /// @brief Function sets \f$ \alpha=0 \f$ and makes SVM untrained.
    150     ///
    151     void reset(void);
    152143
    153144    ///
     
    187178       @return true if succesful
    188179    */
    189     void train();
     180    void train(const KernelLookup& kernel, const Target& target);
    190181
    191182       
     
    243234    unsigned long int max_epochs_;
    244235    utility::vector output_;
    245     bool owner_;
    246236    SVindex sample_;
    247237    Target target_;
    248     bool trained_;
    249238    double tolerance_;
     239    bool trained_;
    250240
    251241  };
  • trunk/yat/classifier/SVindex.cc

    r1004 r1100  
    6969    nof_sv_=0;
    7070    size_t nof_nsv=0;
     71    vec_.resize(alpha.size());
    7172    for (size_t i=0; i<alpha.size(); i++)
    7273      if (alpha(i)<tol){
  • trunk/yat/classifier/Target.cc

    r1004 r1100  
    4141namespace yat {
    4242namespace classifier {
     43
     44  Target::Target(void)
     45  {
     46  }
     47 
    4348
    4449  Target::Target(const std::vector<std::string>& label)
  • trunk/yat/classifier/Target.h

    r1000 r1100  
    4242  /// @brief Class for containing sample labels.
    4343  ///
    44 
    4544  class Target
    4645  {
    4746 
    4847  public:
     48    /**
     49      \brief default constructor
     50    */
     51    Target(void);
     52
    4953    ///
    5054    /// @brief Constructor creating target with @a labels
  • trunk/yat/utility/yat_assert.h

    r1000 r1100  
    3838  template<class X> inline void yat_assert(bool assertion, std::string msg="")
    3939#ifdef YAT_DEBUG
    40   { if (YAT_DEBUG && !assertion) throw X(msg); }
     40  { if (!assertion) throw X(std::string("yat_assert:")+msg); }
    4141#else
    4242  { }
Note: See TracChangeset for help on using the changeset viewer.