Changeset 559


Ignore:
Timestamp:
Mar 11, 2006, 11:21:27 PM (16 years ago)
Author:
Peter
Message:

some changes in EB

Location:
trunk
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • trunk/lib/classifier/EnsembleBuilder.cc

    r531 r559  
    55#include <c++_tools/classifier/CrossSplitter.h>
    66#include <c++_tools/classifier/DataLookup2D.h>
     7#include <c++_tools/classifier/KernelLookup.h>
    78#include <c++_tools/classifier/SupervisedClassifier.h>
    89#include <c++_tools/classifier/Target.h>
     
    3132      const DataLookup2D& training=cross_splitter_.training_data();
    3233      const Target& targets=cross_splitter_.training_target();
     34
    3335      SupervisedClassifier* classifier=
    3436        mother_.make_classifier(training,targets);
     
    5254    size_t k=0;
    5355    gslapi::matrix prediction;   
    54     while(cross_splitter_.more()) {
    55       classifier(k++).predict(data,prediction);
     56   
    5657
    57       for(size_t i=0; i<prediction.rows();i++)
    58         for(size_t j=0; j<prediction.columns();j++)
    59           result[i][j].add(prediction(i,j));
    60 
    61       cross_splitter_.next();
     58    try {
     59      const KernelLookup& kernel = dynamic_cast<const KernelLookup&>(data);
     60      while(cross_splitter_.more()) {
     61        classifier(k++).predict(KernelLookup(kernel,
     62                                             cross_splitter_.training_index(),
     63                                             true),
     64                                prediction);
     65        for(size_t i=0; i<prediction.rows();i++)
     66          for(size_t j=0; j<prediction.columns();j++)
     67            result[i][j].add(prediction(i,j));
     68        cross_splitter_.next();
     69      }
    6270    }
    63 
     71    catch (std::bad_cast) {
     72      while(cross_splitter_.more()) {
     73        classifier(k++).predict(data,prediction);
     74        for(size_t i=0; i<prediction.rows();i++)
     75          for(size_t j=0; j<prediction.columns();j++)
     76            result[i][j].add(prediction(i,j));
     77       
     78        cross_splitter_.next();
     79      }
     80    }
    6481  }
    6582
  • trunk/lib/classifier/KernelLookup.cc

    r553 r559  
    5858 
    5959
     60  KernelLookup::KernelLookup(const KernelLookup& kl,
     61                             const std::vector<size_t>& index,
     62                             const bool row)
     63    : DataLookup2D(kl,index,row), kernel_(kl.kernel_)
     64  {
     65    // Checking that no index is out of range
     66    assert(row_index_.empty() ||
     67           *(max_element(row_index_.begin(), row_index_.end()))<kernel_->rows());
     68    assert(column_index_.empty() ||
     69           *(max_element(column_index_.begin(), column_index_.end()))<
     70           kernel_->columns());
     71
     72  }
     73 
     74
    6075  KernelLookup::~KernelLookup(void)
    6176  {
  • trunk/lib/classifier/KernelLookup.h

    r555 r559  
    6868                 const std::vector<size_t>& column);
    6969   
     70    ///
     71    /// Constructor taking the column (default) or row index vector as
     72    /// input. If @a row is false the created KernelLookup will have
     73    /// equally many rows as @a kernel.
     74    ///
     75    /// @note If underlying matrix goes out of scope or is deleted, the
     76    /// KernelLookup becomes invalid and the result of further use is
     77    /// undefined.
     78    ///
     79    KernelLookup(const KernelLookup& kernel, const std::vector<size_t>&,
     80                 const bool row=false);
     81
    7082    ///
    7183    /// @brief Destructor
  • trunk/lib/classifier/SVM.cc

    r552 r559  
    3535      tolerance_(0.00000001)
    3636  {
     37#ifndef NDEBUG
     38    for (size_t i=0; i<alpha_.size(); i++)
     39      for (size_t j=0; j<alpha_.size(); j++)
     40        assert(kernel(i,j)==kernel(j,i));
     41    for (size_t i=0; i<alpha_.size(); i++)
     42      for (size_t j=0; j<alpha_.size(); j++)
     43        assert((*kernel_)(i,j)==(*kernel_)(j,i));
     44    for (size_t i = 0; i<kernel_->rows(); i++)
     45      for (size_t j = 0; j<kernel_->columns(); j++)
     46        if (std::isnan((*kernel_)(i,j)))
     47          std::cerr << "SVM: Found nan in kernel: " << i << " "
     48                    << j << std::endl;
     49#endif       
     50
    3751  }
    3852
     
    8498                                             const Target& target) const
    8599  {
     100    assert(data.rows()==data.columns());
     101    assert(data.columns()==target.size());
    86102    // Peter, should check success of dynamic_cast
    87103    const KernelLookup& tmp = dynamic_cast<const KernelLookup&>(data);
     
    92108      sc = new SVM(tmp,target);
    93109
     110
    94111    //Copy those variables possible to modify from outside
    95112    return sc;
     
    99116  {
    100117    // Peter, should check success of dynamic_cast
    101     const KernelLookup* input_kernel= &dynamic_cast<const KernelLookup&>(input);
    102 
     118    const KernelLookup input_kernel = dynamic_cast<const KernelLookup&>(input);
     119
     120    const KernelLookup* kernel_pointer;
    103121    if (ranker_) {// feature selection
    104122      std::vector<size_t> index;
     
    106124      for (size_t i=0; i<nof_inputs_; i++)
    107125        index.push_back(ranker_->id(i));
    108       input_kernel = input_kernel->selected(index);
    109     }
     126      kernel_pointer = input_kernel.selected(index);
     127    }
     128    else
     129      kernel_pointer = &input_kernel;
    110130
    111131    assert(input.rows()==alpha_.size());
    112132    prediction = gslapi::matrix(2,input.columns(),0);
    113133    for (size_t i = 0; i<input.columns(); i++){
    114       for (size_t j = 0; j<input.rows(); j++)
    115         prediction(0,i) += target(j)*alpha_(i)*(*input_kernel)(j,i);
     134      for (size_t j = 0; j<input.rows(); j++){
     135        prediction(0,i) += target(j)*alpha_(j)*(*kernel_pointer)(j,i);
     136        assert(target(j));
     137      }
    116138      prediction(0,i) += bias_;
    117139    }
     
    119141    for (size_t i = 0; i<prediction.columns(); i++)
    120142      prediction(1,i) = -prediction(0,i);
     143   
     144    if (ranker_)
     145      delete kernel_pointer;
     146    assert(prediction(0,0));
    121147  }
    122148
     
    150176      for (size_t j=0; j<E.size(); j++)
    151177        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
    152       E(i)=E(i)-target(i);
     178      E(i)-=target(i);
    153179    }
    154180    assert(target_.size()==E.size());
     
    170196      double alpha_old1=alpha_(sample_.value_first());
    171197      double alpha_old2=alpha_(sample_.value_second());
    172 
    173198      alpha_new2 = ( alpha_(sample_.value_second()) +
    174199                     target(sample_.value_second())*
     
    179204      else if (alpha_new2<u)
    180205        alpha_new2 = u;
    181      
    182206     
    183207      // Updating the alphas
     
    254278      assert(alpha_(sample_.value_second())>tolerance_);
    255279
    256      
    257280      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
    258281        return true;
     
    279302    // if no support vectors - special case
    280303    else{
     304      // to avoid getting stuck we shuffle
     305      sample_.shuffle();
    281306      for (size_t i=0; i<sample_.n(); i++) {
    282307        if (target_.binary(sample_(i))){
  • trunk/test/kernel_lookup_test.cc

    r538 r559  
    128128        }
    129129  }
     130 
     131  KernelLookup k5(k1,index_even,index_even);
     132  std::vector<size_t> index5;
     133  index5.push_back(0);
     134  index5.push_back(2);
     135  const KernelLookup* k6 = k5.training_data(index5);
     136  for (size_t s=0; s<k6->rows(); s++)
     137    for (size_t t=0; t<k6->rows(); t++)
     138      ok = ok && ((*k6)(s,t)==(*k6)(t,s));
     139
     140
    130141  if (ok)
    131142    *error << "Ok." << std::endl;
Note: See TracChangeset for help on using the changeset viewer.