Changeset 828


Ignore:
Timestamp:
Mar 19, 2007, 11:15:43 PM (15 years ago)
Author:
Peter
Message:

Generalized ConsenusInputRanker?, Fixes #151

Location:
trunk
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/test/consensus_inputranker_test.cc

    r820 r828  
    2828#include "yat/classifier/CrossValidationSampler.h"
    2929#include "yat/classifier/IRRank.h"
     30#include "yat/statistics/VectorFunction.h"
    3031
    3132#include <cstdlib>
     
    6364  *error << "Building Consensus_Inputranker" << std::endl;
    6465  theplu::yat::classifier::IRRank retrieve;
    65   theplu::yat::classifier::ConsensusInputRanker cir(sampler,data,roc,retrieve);
    66   *error << "Done" << std::endl;
     66  theplu::yat::statistics::Median median;
     67  theplu::yat::classifier::ConsensusInputRanker cir(retrieve,median);
     68  cir.add(sampler,data,roc);
    6769
     70  *error << "test ids... ";
    6871  if (cir.id(0)!=2 || cir.id(1)!=0 || cir.id(2)!=1){
    69     *error << "incorrect id" << endl;
    70     ok = false;
    71   }
    72  
    73   if (cir.rank(0)!=1 || cir.rank(1)!=2 || cir.rank(2)!=0){
    74     *error << "incorrect rank" << endl;
     72    *error << "\nincorrect id for weighted" << endl;
    7573    ok=false;
    7674  }
     75  else
     76    *error << "ok." << std::endl;
     77
     78  *error << "test ranks... ";
     79  if (cir.rank(0)!=1 || cir.rank(1)!=2 || cir.rank(2)!=0){
     80    *error << "\nincorrect rank for weighted" << endl;
     81    ok=false;
     82  }
     83  else
     84    *error << "ok." << std::endl;
    7785
    7886  theplu::yat::utility::matrix flag(data.rows(),data.columns(),1);
    7987  // Peter, fix weighted version instead
    80   theplu::yat::classifier::ConsensusInputRanker cir2(sampler,data,roc,retrieve);
     88  theplu::yat::classifier::ConsensusInputRanker cir2(retrieve,median);
     89  cir2.add(sampler,data,roc);
    8190
     91  *error << "test ids... ";
    8292  if (cir2.id(0)!=2 || cir2.id(1)!=0 || cir2.id(2)!=1){
    83     *error << "incorrect id for weighted" << endl;
     93    *error << "\nincorrect id for weighted" << endl;
    8494    ok=false;
    8595  }
     96  else
     97    *error << "ok." << std::endl;
    8698 
     99  *error << "test ranks... ";
    87100  if (cir2.rank(0)!=1 || cir2.rank(1)!=2 || cir2.rank(2)!=0){
    88     *error << "incorrect rank for weighted" << endl;
     101    *error << "\nincorrect rank for weighted" << endl;
    89102    ok=false;
    90103  }
     104  else
     105    *error << "ok." << std::endl;
    91106
    92107  if (error!=&std::cerr)
  • trunk/yat/classifier/ConsensusInputRanker.cc

    r817 r828  
    3131#include "yat/statistics/Score.h"
    3232#include "yat/statistics/utility.h"
     33#include "yat/statistics/VectorFunction.h"
    3334#include "yat/utility/stl_utility.h"
    3435
     
    4445namespace classifier { 
    4546
    46   ConsensusInputRanker::ConsensusInputRanker(const IRRetrieve& retriever)
    47     : retriever_(retriever)
     47  ConsensusInputRanker::ConsensusInputRanker(const IRRetrieve& retriever,
     48                                             const statistics::VectorFunction&
     49                                             vf)
     50    : retriever_(retriever), vec_func_(vf)
    4851  {
    4952  }
    5053
    5154
    52   ConsensusInputRanker::ConsensusInputRanker(const Sampler& sampler,
    53                                              const MatrixLookup& data,
    54                                              statistics::Score& score,
    55                                              const IRRetrieve& retriever)
    56     : retriever_(retriever)
     55  void ConsensusInputRanker::add(const Sampler& sampler,
     56                                 const MatrixLookup& data,
     57                                 statistics::Score& score)
    5758  {
    5859    assert(sampler.size());
     60    assert(id_.empty() || id_.size()==data.rows());
     61    input_rankers_.reserve(sampler.size()+input_rankers_.size());
    5962    id_.resize(data.rows());
    6063    rank_.resize(data.rows());
     
    6770  }
    6871
    69   ConsensusInputRanker::ConsensusInputRanker(const Sampler& sampler,
    70                                              const MatrixLookupWeighted& data,
    71                                              statistics::Score& score,
    72                                              const IRRetrieve& retriever)
    73     : retriever_(retriever)
     72  void ConsensusInputRanker::add(const Sampler& sampler,
     73                                 const MatrixLookupWeighted& data,
     74                                 statistics::Score& score)
    7475  {
    75     assert(sampler.size());
     76    assert(id_.empty() || id_.size()==data.rows());
    7677    id_.resize(data.rows());
    7778    rank_.resize(data.rows());
    78  
    7979    for (size_t i=0; i<sampler.size(); ++i){
    8080      input_rankers_.push_back(InputRanker(MatrixLookupWeighted(data,sampler.training_index(i), false),
     
    8787  void ConsensusInputRanker::add(const InputRanker& ir)
    8888  {
     89    assert(id_.empty() || id_.size()==ir.id().size());
    8990    input_rankers_.push_back(ir);
    9091  }
     
    9293  size_t ConsensusInputRanker::id(size_t i) const
    9394  {
     95    assert(i<id_.size());
    9496    return id_[i];
    9597  }
     
    106108  }
    107109
     110
     111  void ConsensusInputRanker::reserve(size_t n)
     112  {
     113    input_rankers_.reserve(n);
     114  }
     115
     116
    108117  void ConsensusInputRanker::update(void)
    109118  {
    110 
    111     // Sorting with respect to median info (from retriever_)
    112     std::vector<std::pair<double,size_t> > medians(id_.size());
     119    // Sorting with respect to VectorFunction(info) where info is a
     120    // vector and each element contains infomation retrieved with
     121    // retriever_ from each InputRanker
     122    std::vector<std::pair<double,size_t> > cons_rank;
     123    cons_rank.reserve(id_.size());
    113124    for (size_t i=0; i<id_.size(); i++){
    114125      std::vector<double> scores;
     
    117128        scores.push_back(retriever_(input_rankers_[j],i));
    118129      }
    119       medians[i].first = statistics::median(scores);
    120       medians[i].second = i;
     130      cons_rank.push_back(std::make_pair(vec_func_(scores), i));
    121131    }
    122132   
    123     //sort medians and assign id_ and rank_
    124     sort(medians.begin(), medians.end(),
     133    //sort cons_rank and assign id_ and rank_
     134    sort(cons_rank.begin(), cons_rank.end(),
    125135         std::greater<std::pair<double, size_t> >());
    126136         
    127     for (size_t i=0; i<medians.size(); i++){
    128       id_[i]=medians[i].second;
     137    for (size_t i=0; i<cons_rank.size(); i++){
     138      assert(i<id_.size());
     139      id_[i]=cons_rank[i].second;
     140      assert(id_[i]<rank_.size());
    129141      rank_[id_[i]]=i;
    130142    }
  • trunk/yat/classifier/ConsensusInputRanker.h

    r817 r828  
    2727#include "InputRanker.h"
    2828
     29#include <vector>
     30
    2931namespace theplu {
    3032namespace yat {
    31 
    32   class statistics::Score;
    33 
     33namespace statistics {
     34  class Score;
     35  class VectorFunction;
     36}
    3437namespace classifier { 
    3538
     
    4649  /// could be different because they are based upon different
    4750  /// sub-sets of the data, or the different lists could be different
    48   /// because they have are generated using different criteria. Having
     51  /// because they have been generated using different criteria. Having
    4952  /// \f$ N \f$ lists means each row in the data matrix has \f$ N \f$
    50   /// ranks (each corresponding to one list) and a consensus ranked
    51   /// list is created by sorting the data rows with respect to their
    52   /// median rank.
     53  /// ranks (each corresponding to one list). A
     54  /// statistics::VectorFunction is used to boil down these ranks to
     55  /// one consensus rank, and a ranked list is created by sorting the
     56  /// data rows with respect to this consensus rank.
    5357  ///
    5458  /// For the time being there are two ways to build a
    5559  /// ConsensusInputRanker. 1) Sending a Sampler and a MatrixLookup to
    56   /// the constructor will create one ranked list for each of the
     60  /// the add function will create one ranked list for each of the
    5761  /// partitions defined in the Sampler. 2) You can generate
    5862  /// your ranked list outside, using your favourite method, and
     
    7175    /// Truly does nothing but creates a few empty member vectors.
    7276    ///
    73     ConsensusInputRanker(const IRRetrieve&);
     77    ConsensusInputRanker(const IRRetrieve&, const statistics::VectorFunction&);
    7478   
    7579    ///
     
    7983    /// the median rank (i.e. update() is called).
    8084    ///
    81     ConsensusInputRanker(const Sampler& sampler, const MatrixLookup&,
    82                          statistics::Score& s, const IRRetrieve&);
     85    void add(const Sampler& sampler, const MatrixLookup&, statistics::Score& s);
    8386   
    8487    ///
     88    /// @brief Add a set of InputRankers
     89    ///
    8590    /// Iterating through @a sampler creating subsets of @a data, and
    8691    /// for each subset is an InputRanker is created using the @a
     
    8893    /// the median rank (i.e. update() is called).
    8994    ///
    90     ConsensusInputRanker(const Sampler& sampler,
    91                          const MatrixLookupWeighted& data,
    92                          statistics::Score& score, const IRRetrieve&);
     95    void add(const Sampler& sampler, const MatrixLookupWeighted& data,
     96             statistics::Score& score);
    9397   
    9498    ///
    95     /// @brief add an InputRanker
     99    /// @brief Add an InputRanker
    96100    ///
    97101    /// @note update() must be called to make the added InputRanker to
     
    119123    size_t rank(size_t i) const;
    120124   
     125    /**
     126       \brief \brief reserve memory for internal vector of InputRankers
     127
     128       This function is recommended before adding using add(const
     129       InputRanker&) to avoid re-allocations.
     130    */
     131    void reserve(size_t n);
     132
     133
    121134    ///
    122135    /// update ids and ranks
     
    131144    std::vector<size_t> rank_;
    132145    const IRRetrieve& retriever_;
    133 
     146    const statistics::VectorFunction& vec_func_;
    134147  };
    135148
  • trunk/yat/statistics/VectorFunction.h

    r827 r828  
    4141
    4242
     43  ///
     44  /// @brief Larget element
     45  ///
    4346  struct Max : public VectorFunction
    4447  {
     
    5053
    5154
     55  ///
     56  /// @brief Median element
     57  ///
    5258  struct Median : public VectorFunction
    5359  {
     60    ///
     61    /// \see statistics::median(std::vector<double>, bool)
    5462    ///
    5563    /// \return Median
     
    5866  };
    5967
    60 
     68  ///
     69  /// \brief Mean element
     70  ///
    6171  struct Mean : public VectorFunction
    6272  {
     
    6878
    6979
     80  ///
     81  /// \brief Smallest element
     82  ///
    7083  struct Min : public VectorFunction
    7184  {
Note: See TracChangeset for help on using the changeset viewer.