Ignore:
Timestamp:
Mar 11, 2006, 11:21:27 PM (17 years ago)
Author:
Peter
Message:

some changes in EB

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/lib/classifier/SVM.cc

    r552 r559  
    3535      tolerance_(0.00000001)
    3636  {
     37#ifndef NDEBUG
     38    for (size_t i=0; i<alpha_.size(); i++)
     39      for (size_t j=0; j<alpha_.size(); j++)
     40        assert(kernel(i,j)==kernel(j,i));
     41    for (size_t i=0; i<alpha_.size(); i++)
     42      for (size_t j=0; j<alpha_.size(); j++)
     43        assert((*kernel_)(i,j)==(*kernel_)(j,i));
     44    for (size_t i = 0; i<kernel_->rows(); i++)
     45      for (size_t j = 0; j<kernel_->columns(); j++)
     46        if (std::isnan((*kernel_)(i,j)))
     47          std::cerr << "SVM: Found nan in kernel: " << i << " "
     48                    << j << std::endl;
     49#endif       
     50
    3751  }
    3852
     
    8498                                             const Target& target) const
    8599  {
     100    assert(data.rows()==data.columns());
     101    assert(data.columns()==target.size());
    86102    // Peter, should check success of dynamic_cast
    87103    const KernelLookup& tmp = dynamic_cast<const KernelLookup&>(data);
     
    92108      sc = new SVM(tmp,target);
    93109
     110
    94111    //Copy those variables possible to modify from outside
    95112    return sc;
     
    99116  {
    100117    // Peter, should check success of dynamic_cast
    101     const KernelLookup* input_kernel= &dynamic_cast<const KernelLookup&>(input);
    102 
     118    const KernelLookup input_kernel = dynamic_cast<const KernelLookup&>(input);
     119
     120    const KernelLookup* kernel_pointer;
    103121    if (ranker_) {// feature selection
    104122      std::vector<size_t> index;
     
    106124      for (size_t i=0; i<nof_inputs_; i++)
    107125        index.push_back(ranker_->id(i));
    108       input_kernel = input_kernel->selected(index);
    109     }
     126      kernel_pointer = input_kernel.selected(index);
     127    }
     128    else
     129      kernel_pointer = &input_kernel;
    110130
    111131    assert(input.rows()==alpha_.size());
    112132    prediction = gslapi::matrix(2,input.columns(),0);
    113133    for (size_t i = 0; i<input.columns(); i++){
    114       for (size_t j = 0; j<input.rows(); j++)
    115         prediction(0,i) += target(j)*alpha_(i)*(*input_kernel)(j,i);
     134      for (size_t j = 0; j<input.rows(); j++){
     135        prediction(0,i) += target(j)*alpha_(j)*(*kernel_pointer)(j,i);
     136        assert(target(j));
     137      }
    116138      prediction(0,i) += bias_;
    117139    }
     
    119141    for (size_t i = 0; i<prediction.columns(); i++)
    120142      prediction(1,i) = -prediction(0,i);
     143   
     144    if (ranker_)
     145      delete kernel_pointer;
     146    assert(prediction(0,0));
    121147  }
    122148
     
    150176      for (size_t j=0; j<E.size(); j++)
    151177        E(i) += kernel_mod(i,j)*target(j)*alpha_(j);
    152       E(i)=E(i)-target(i);
     178      E(i)-=target(i);
    153179    }
    154180    assert(target_.size()==E.size());
     
    170196      double alpha_old1=alpha_(sample_.value_first());
    171197      double alpha_old2=alpha_(sample_.value_second());
    172 
    173198      alpha_new2 = ( alpha_(sample_.value_second()) +
    174199                     target(sample_.value_second())*
     
    179204      else if (alpha_new2<u)
    180205        alpha_new2 = u;
    181      
    182206     
    183207      // Updating the alphas
     
    254278      assert(alpha_(sample_.value_second())>tolerance_);
    255279
    256      
    257280      if (E(sample_.value_second()) - E(sample_.value_first()) > 2*tolerance_){
    258281        return true;
     
    279302    // if no support vectors - special case
    280303    else{
     304      // to avoid getting stuck we shuffle
     305      sample_.shuffle();
    281306      for (size_t i=0; i<sample_.n(); i++) {
    282307        if (target_.binary(sample_(i))){
Note: See TracChangeset for help on using the changeset viewer.