Changeset 569


Ignore:
Timestamp:
Mar 23, 2006, 10:38:35 AM (16 years ago)
Author:
Peter
Message:

modified prediction from SVM

Location:
trunk/lib/classifier
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/lib/classifier/SVM.cc

    r568 r569  
    1313#include <algorithm>
    1414#include <cassert>
     15#include <cctype>
    1516#include <cmath>
    1617#include <limits>
     
    2829      C_inverse_(0),
    2930      kernel_(&kernel),
     31      margin_(0),
    3032      max_epochs_(100000),
    3133      output_(target.size(),0),
     
    5759      bias_(0),
    5860      C_inverse_(0),
     61      margin_(0),
    5962      max_epochs_(10000000),
    6063      output_(target.size(),0),
     
    97100
    98101
     102  void SVM::calculate_margin(void)
     103  {
     104    margin_ = 0;
     105    for(size_t i = 0; i<alpha_.size(); ++i){
     106      margin_ += alpha_(i)*target(i)*kernel_mod(i,i)*alpha_(i)*target(i);
     107      for(size_t j = i+1; j<alpha_.size(); ++j)
     108        margin_ += 2*alpha_(i)*target(i)*kernel_mod(i,j)*alpha_(j)*target(j);
     109    }
     110  }
     111
     112
    99113  SupervisedClassifier* SVM::make_classifier(const DataLookup2D& data,
    100114                                             const Target& target) const
     
    138152        assert(target(j));
    139153      }
    140       prediction(0,i) += bias_;
     154      prediction(0,i) = margin_ * (prediction(0,i) + bias_);
    141155    }
    142156
     
    155169      y += alpha_(i)*target_(i)*kernel_->element(x,i);
    156170
    157     return y+bias_;
     171    return margin_*(y+bias_);
    158172  }
    159173
     
    164178      y += alpha_(i)*target_(i)*kernel_->element(x,w,i);
    165179
    166     return y+bias_;
     180    return margin_*(y+bias_);
    167181  }
    168182
     
    247261        std::cerr << "WARNING: SVM: maximal number of epochs reached.\n";
    248262        calculate_bias();
     263        calculate_margin();
    249264        return false;
    250265      }
    251266    }
    252    
     267    calculate_margin();
    253268    trained_ = calculate_bias();
    254269    return trained_;
  • trunk/lib/classifier/SVM.h

    r568 r569  
    171171
    172172    ///
    173     /// Generate prediction @a output from @a input. The prediction is
    174     /// returned in @a output. The output has 2 rows. The first row is
    175     /// for binary target true, and the second is for binary target
    176     /// false. The second row is superfluous because it the first row
     173    /// Generate prediction @a predict from @a input. The prediction
     174    /// is calculated as the output times the margin, i.e., geometric
     175    /// distance from decision hyperplane: \f$ \frac{ \sum \alpha_j
     176    /// t_j K_{ij} + bias}{w} \f$ The output has 2 rows. The first row
     177    /// is for binary target true, and the second is for binary target
     178    /// false. The second row is superfluous as it is the first row
    177179    /// negated. It exist just to be aligned with multi-class
    178180    /// SupervisedClassifiers. Each column in @a input and @a output
     
    182184    /// for training.
    183185    ///
    184     void predict(const DataLookup2D& input, gslapi::matrix& output) const;
    185 
    186     ///
    187     /// @return output from data @a input
     186    /// @note
     187    ///
     188    void predict(const DataLookup2D& input, gslapi::matrix& predict) const;
     189
     190    ///
     191    /// @return output times margin (i.e. geometric distance from
     192    /// decision hyperplane) from data @a input
    188193    ///
    189194    double predict(const DataLookup1D& input) const;
    190195
    191196    ///
    192     /// @return output from data @a input with corresponding @a weight
     197    /// @return output times margin from data @a input with
     198    /// corresponding @a weight
    193199    ///
    194200    double predict(const DataLookup1D& input, const DataLookup1D& weight) const;
     
    230236
    231237    ///
     238    /// Calculate margin that is inverse of w
     239    ///
     240    void calculate_margin(void);
     241
     242    ///
    232243    ///   Private function choosing which two elements that should be
    233244    ///   updated. First checking for the biggest violation (output - target =
     
    254265    double C_inverse_;
    255266    const KernelLookup* kernel_;
     267    double margin_;
    256268    unsigned long int max_epochs_;
    257269    gslapi::vector output_;
     
    260272    bool trained_;
    261273    double tolerance_;
    262    
    263274
    264275  };
Note: See TracChangeset for help on using the changeset viewer.