Changeset 1100 for trunk/yat/classifier


Ignore:
Timestamp:
Feb 18, 2008, 5:37:50 AM (15 years ago)
Author:
Peter
Message:

fixes #313 - SVM constructor is void and passing kernel and target in train function instead

Location:
trunk/yat/classifier
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/SVM.cc

    r1098 r1100  
    3939#include <sstream>
    4040#include <stdexcept>
     41#include <string>
    4142#include <utility>
    4243#include <vector>
     
    4647namespace classifier { 
    4748
    48   SVM::SVM(const KernelLookup& kernel, const Target& target)
    49     : alpha_(target.size(),0),
    50       bias_(0),
     49  SVM::SVM(void)
     50    : bias_(0),
    5151      C_inverse_(0),
    52       kernel_(&kernel),
     52      kernel_(NULL),
    5353      margin_(0),
    5454      max_epochs_(100000),
    55       output_(target.size(),0),
    56       owner_(false),
    57       sample_(target.size()),
    58       target_(target),
    59       trained_(false),
    60       tolerance_(0.00000001)
    61   {
    62 #ifndef NDEBUG
    63     assert(kernel.columns()==kernel.rows());
    64     assert(kernel.columns()==alpha_.size());
    65     for (size_t i=0; i<alpha_.size(); i++)
    66       for (size_t j=0; j<alpha_.size(); j++)
    67         assert(kernel(i,j)==kernel(j,i));
    68     for (size_t i=0; i<alpha_.size(); i++)
    69       for (size_t j=0; j<alpha_.size(); j++)
    70         assert((*kernel_)(i,j)==(*kernel_)(j,i));
    71     for (size_t i = 0; i<kernel_->rows(); i++)
    72       for (size_t j = 0; j<kernel_->columns(); j++)
    73         if (std::isnan((*kernel_)(i,j)))
    74           std::cerr << "SVM: Found nan in kernel: " << i << " "
    75                     << j << std::endl;
    76 #endif       
    77   }
     55      tolerance_(0.00000001),
     56      trained_(false)
     57  {
     58  }
     59
    7860
    7961  SVM::~SVM()
    8062  {
    81     if (owner_)
     63    if (kernel_)
    8264      delete kernel_;
    8365  }
    8466
     67
    8568  const utility::vector& SVM::alpha(void) const
    8669  {
     
    8871  }
    8972
     73
    9074  double SVM::C(void) const
    9175  {
    9276    return 1.0/C_inverse_;
    9377  }
     78
    9479
    9580  void SVM::calculate_margin(void)
     
    10388  }
    10489
     90
    10591  const DataLookup2D& SVM::data(void) const
    10692  {
     
    11197  double SVM::kernel_mod(const size_t i, const size_t j) const
    11298  {
     99    assert(kernel_);
     100    assert(i<kernel_->rows());
     101    assert(i<kernel_->columns());
    113102    return i!=j ? (*kernel_)(i,j) : (*kernel_)(i,j) + C_inverse_;
    114103  }
    115104
    116   SVM* SVM::make_classifier(const DataLookup2D& data,
    117                             const Target& target) const
    118   {
    119     SVM* sc=0;
    120     try {
    121       const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
    122       assert(data.rows()==data.columns());
    123       assert(data.columns()==target.size());
    124       sc = new SVM(kernel,target);
    125       //Copy those variables possible to modify from outside
    126       sc->set_C(this->C());
    127       sc->max_epochs(max_epochs());
    128     }
    129     catch (std::bad_cast) {
    130       std::string str =
    131         "Error in SVM::make_classifier: DataLookup2D of unexpected class.";
    132       throw std::runtime_error(str);
    133     }
    134  
    135     return sc;
    136   }
     105
     106  SVM* SVM::make_classifier(void) const
     107  {
     108    return new SVM(*this);
     109  }
     110
    137111
    138112  long int SVM::max_epochs(void) const
     
    196170  }
    197171
    198   void SVM::reset(void)
    199   {
    200     trained_=false;
    201     alpha_ = utility::vector(target_.size(), 0);
    202   }
    203 
    204172  int SVM::target(size_t i) const
    205173  {
     174    assert(i<target_.size());
    206175    return target_.binary(i) ? 1 : -1;
    207176  }
    208177
    209   void SVM::train(void)
    210   {
     178  void SVM::train(const KernelLookup& kernel, const Target& targ)
     179  {
     180    if (kernel_)
     181      delete kernel_;
     182    kernel_ = new KernelLookup(kernel);
     183    target_ = targ;
     184   
     185    alpha_ = utility::vector(targ.size(), 0.0);
     186    output_ = utility::vector(targ.size(), 0.0);
    211187    // initializing variables for optimization
    212188    assert(target_.size()==kernel_->rows());
  • trunk/yat/classifier/SVM.h

    r1087 r1100  
    5757  public:
    5858    ///
    59     /// Constructor taking the kernel and the target vector as
    60     /// input.
    61     ///
    62     /// @note if the @a target or @a kernel
    63     /// is destroyed the behaviour is undefined.
    64     ///
    65     SVM(const KernelLookup& kernel, const Target& target);
     59    /// \brief Constructor
     60    ///
     61    SVM(void);
    6662
    6763    ///
     
    7369
    7470    ///
    75     /// If DataLookup2D is not a KernelLookup a bad_cast exception is thrown.
    76     ///
    77     SVM* make_classifier(const DataLookup2D&, const Target&) const;
     71    ///
     72    ///
     73    SVM* make_classifier(void) const;
    7874
    7975    ///
     
    145141    ///
    146142    double predict(const DataLookupWeighted1D& input) const;
    147 
    148     ///
    149     /// @brief Function sets \f$ \alpha=0 \f$ and makes SVM untrained.
    150     ///
    151     void reset(void);
    152143
    153144    ///
     
    187178       @return true if succesful
    188179    */
    189     void train();
     180    void train(const KernelLookup& kernel, const Target& target);
    190181
    191182       
     
    243234    unsigned long int max_epochs_;
    244235    utility::vector output_;
    245     bool owner_;
    246236    SVindex sample_;
    247237    Target target_;
    248     bool trained_;
    249238    double tolerance_;
     239    bool trained_;
    250240
    251241  };
  • trunk/yat/classifier/SVindex.cc

    r1004 r1100  
    6969    nof_sv_=0;
    7070    size_t nof_nsv=0;
     71    vec_.resize(alpha.size());
    7172    for (size_t i=0; i<alpha.size(); i++)
    7273      if (alpha(i)<tol){
  • trunk/yat/classifier/Target.cc

    r1004 r1100  
    4141namespace yat {
    4242namespace classifier {
     43
     44  Target::Target(void)
     45  {
     46  }
     47 
    4348
    4449  Target::Target(const std::vector<std::string>& label)
  • trunk/yat/classifier/Target.h

    r1000 r1100  
    4242  /// @brief Class for containing sample labels.
    4343  ///
    44 
    4544  class Target
    4645  {
    4746 
    4847  public:
     48    /**
     49      \brief default constructor
     50    */
     51    Target(void);
     52
    4953    ///
    5054    /// @brief Constructor creating target with @a labels
Note: See TracChangeset for help on using the changeset viewer.