Changeset 1157 for trunk/yat/classifier


Ignore:
Timestamp:
Feb 26, 2008, 2:25:19 PM (16 years ago)
Author:
Markus Ringnér
Message:

Refs #318

Location:
trunk/yat/classifier
Files:
7 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/EnsembleBuilder.h

    r1125 r1157  
    153153  {
    154154    for(u_long i=0; i<subset_->size();++i) {
    155       C* classifier = mother_.make_classifier(subset_->training_data(i),
    156                                               subset_->training_target(i));
    157       classifier->train();
     155      C* classifier = mother_.make_classifier();
     156      classifier->train(subset_->training_data(i),
     157                        subset_->training_target(i));
    158158      classifier_.push_back(classifier);
    159159    }   
  • trunk/yat/classifier/KNN.h

    r1156 r1157  
    5757  public:
    5858    ///
    59     /// Constructor taking the training data and the target   
    60     /// as input.
    61     ///
    62     KNN(const MatrixLookup&, const Target&);
    63 
    64 
    65     ///
    66     /// Constructor taking the training data with weights and the
    67     /// target as input.
    68     ///
    69     KNN(const MatrixLookupWeighted&, const Target&);
    70 
     59    /// @brief Constructor
     60    ///
     61    KNN(void);
     62
     63
     64    ///
     65    /// @brief Destructor
     66    ///
    7167    virtual ~KNN();
    7268   
    73     //
    74     // @return the training data
    75     //
    76     const DataLookup2D& data(void) const;
    77 
    7869
    7970    ///
     
    8576
    8677    ///
    87     /// @brief sets the number of neighbors, k. If the number of
    88     /// training samples set is smaller than \a k_in, k is set to the number of
    89     /// training samples.
     78    /// @brief sets the number of neighbors, k.
    9079    ///
    9180    void k(u_int k_in);
    9281
    9382
    94     KNN<Distance,NeighborWeighting>* make_classifier(const DataLookup2D&,
    95                          const Target&) const;
    96    
    97     ///
    98     /// Train the classifier using the training data.
    99     /// This function does nothing but is required by the interface.
    100     ///
    101     void train();
     83    KNN<Distance,NeighborWeighting>* make_classifier(void) const;
     84   
     85    ///
     86    /// Train the classifier using training data and target.
     87    ///
     88    /// If the number of training samples set is smaller than \a k_in,
     89    /// k is set to the number of training samples.
     90    ///
     91    void train(const MatrixLookup&, const Target&);
     92
     93    ///
     94    /// Train the classifier using weighted training data and target.
     95    ///
     96    void train(const MatrixLookupWeighted&, const Target&);
    10297
    10398   
     
    114109    // data_ has to be of type DataLookup2D to accomodate both
    115110    // MatrixLookup and MatrixLookupWeighted
    116     const DataLookup2D& data_;
     111    const DataLookup2D* data_;
     112    const Target* target_;
    117113
    118114    // The number of neighbors
     
    143139 
    144140  template <typename Distance, typename NeighborWeighting>
    145   KNN<Distance, NeighborWeighting>::KNN(const MatrixLookup& data, const Target& target)
    146     : SupervisedClassifier(target), data_(data),k_(3)
    147   {
    148     utility::yat_assert<std::runtime_error>
    149       (data.columns()==target.size(),
    150        "KNN::KNN called with different sizes of target and data");
    151     // k has to be at most the number of training samples.
    152     if(data_.columns()>k_)
    153       k_=data_.columns();
    154   }
    155 
    156 
    157   template <typename Distance, typename NeighborWeighting>
    158   KNN<Distance, NeighborWeighting>::KNN
    159   (const MatrixLookupWeighted& data, const Target& target)
    160     : SupervisedClassifier(target), data_(data),k_(3)
    161   {
    162     utility::yat_assert<std::runtime_error>
    163       (data.columns()==target.size(),
    164        "KNN::KNN called with different sizes of target and data");
    165     if(data_.columns()>k_)
    166       k_=data_.columns();
    167   }
     141  KNN<Distance, NeighborWeighting>::KNN()
     142    : SupervisedClassifier(),data_(0),target_(0),k_(3)
     143  {
     144  }
     145
    168146 
    169147  template <typename Distance, typename NeighborWeighting>
     
    178156    // matrix with training samples as rows and test samples as columns
    179157    utility::Matrix* distances =
    180       new utility::Matrix(data_.columns(),test.columns());
     158      new utility::Matrix(data_->columns(),test.columns());
    181159   
    182160   
     
    186164      // unweighted training data
    187165      if(const MatrixLookup* training_unweighted =
    188          dynamic_cast<const MatrixLookup*>(&data_))
     166         dynamic_cast<const MatrixLookup*>(data_))
    189167        calculate_unweighted(*training_unweighted,*test_unweighted,distances);
    190168      // weighted training data
    191169      else if(const MatrixLookupWeighted* training_weighted =
    192               dynamic_cast<const MatrixLookupWeighted*>(&data_))
     170              dynamic_cast<const MatrixLookupWeighted*>(data_))
    193171        calculate_weighted(*training_weighted,MatrixLookupWeighted(*test_unweighted),
    194172                           distances);             
     
    200178      // unweighted training data
    201179      if(const MatrixLookup* training_unweighted =
    202          dynamic_cast<const MatrixLookup*>(&data_)) {
     180         dynamic_cast<const MatrixLookup*>(data_)) {
    203181        calculate_weighted(MatrixLookupWeighted(*training_unweighted),
    204182                           *test_weighted,distances);
     
    206184      // weighted training data
    207185      else if(const MatrixLookupWeighted* training_weighted =
    208               dynamic_cast<const MatrixLookupWeighted*>(&data_))
     186              dynamic_cast<const MatrixLookupWeighted*>(data_))
    209187        calculate_weighted(*training_weighted,*test_weighted,distances);             
    210188      // Training data can not be of incorrect type
     
    252230    }
    253231  }
    254 
    255  
    256   template <typename Distance, typename NeighborWeighting>
    257   const DataLookup2D& KNN<Distance, NeighborWeighting>::data(void) const
    258   {
    259     return data_;
    260   }
    261232 
    262233 
     
    271242  {
    272243    k_=k;
    273     if(k_>data_.columns())
    274       k_=data_.columns();
    275244  }
    276245
     
    278247  template <typename Distance, typename NeighborWeighting>
    279248  KNN<Distance, NeighborWeighting>*
    280   KNN<Distance, NeighborWeighting>::make_classifier(const DataLookup2D& data,
    281                                                     const Target& target) const
     249  KNN<Distance, NeighborWeighting>::make_classifier() const
    282250  {     
    283     KNN* knn=0;
    284     try {
    285       if(data.weighted()) {
    286         knn=new KNN<Distance, NeighborWeighting>
    287           (dynamic_cast<const MatrixLookupWeighted&>(data),target);
    288       } 
    289       else {
    290         knn=new KNN<Distance, NeighborWeighting>
    291           (dynamic_cast<const MatrixLookup&>(data),target);
    292       }
    293       knn->k(this->k());
    294     }
    295     catch (std::bad_cast) {
    296       std::string str = "Error in KNN<Distance, NeighborWeighting>";
    297       str += "::make_classifier: DataLookup2D of unexpected class.";
    298       throw std::runtime_error(str);
    299     }
     251    KNN* knn=new KNN<Distance, NeighborWeighting>();
     252    knn->k(this->k());
    300253    return knn;
    301254  }
     
    303256 
    304257  template <typename Distance, typename NeighborWeighting>
    305   void KNN<Distance, NeighborWeighting>::train()
     258  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
     259                                               const Target& target)
    306260  {   
     261    utility::yat_assert<std::runtime_error>
     262      (data.columns()==target.size(),
     263       "KNN::train called with different sizes of target and data");
     264    // k has to be at most the number of training samples.
     265    if(data.columns()<k_)
     266      k_=data.columns();
     267    data_=&data;
     268    target_=&target;
     269    trained_=true;
     270  }
     271
     272  template <typename Distance, typename NeighborWeighting>
     273  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
     274                                               const Target& target)
     275  {   
     276    utility::yat_assert<std::runtime_error>
     277      (data.columns()==target.size(),
     278       "KNN::train called with different sizes of target and data");
     279    // k has to be at most the number of training samples.
     280    if(data.columns()<k_)
     281      k_=data.columns();
     282    data_=&data;
     283    target_=&target;
    307284    trained_=true;
    308285  }
     
    313290                                                 utility::Matrix& prediction) const
    314291  {   
    315     utility::yat_assert<std::runtime_error>(data_.rows()==test.rows(),"KNN::predict different number of rows in training and test data");
     292    utility::yat_assert<std::runtime_error>(data_->rows()==test.rows(),"KNN::predict different number of rows in training and test data");
    316293
    317294    utility::Matrix* distances=calculate_distances(test);
    318295   
    319     prediction.resize(target_.nof_classes(),test.columns(),0.0);
     296    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    320297    for(size_t sample=0;sample<distances->columns();sample++) {
    321298      std::vector<size_t> k_index;
     
    323300      utility::sort_smallest_index(k_index,k_,dist);
    324301      utility::VectorView pred=prediction.column_view(sample);
    325       weighting_(dist,k_index,target_,pred);
     302      weighting_(dist,k_index,*target_,pred);
    326303    }
    327304    delete distances;
     
    329306    // classes for which there are no training samples should be set
    330307    // to nan in the predictions
    331     for(size_t c=0;c<target_.nof_classes(); c++)
    332       if(!target_.size(c))
     308    for(size_t c=0;c<target_->nof_classes(); c++)
     309      if(!target_->size(c))
    333310        for(size_t j=0;j<prediction.columns();j++)
    334311          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
  • trunk/yat/classifier/NBC.cc

    r1144 r1157  
    4040namespace classifier {
    4141
    42   NBC::NBC(const MatrixLookup& data, const Target& target)
    43     : SupervisedClassifier(target), data_(data)
     42  NBC::NBC()
     43    : SupervisedClassifier()
    4444  {
    4545  }
    4646
    47   NBC::NBC(const MatrixLookupWeighted& data, const Target& target)
    48     : SupervisedClassifier(target), data_(data)
    49   {
    50   }
    5147
    5248  NBC::~NBC()   
     
    5551
    5652
    57   const DataLookup2D& NBC::data(void) const
    58   {
    59     return data_;
    60   }
    61 
    62 
    63   NBC*
    64   NBC::make_classifier(const DataLookup2D& data, const Target& target) const
     53  NBC* NBC::make_classifier() const
    6554  {     
    66     NBC* nbc=0;
    67     try {
    68       if(data.weighted()) {
    69         nbc=new NBC(dynamic_cast<const MatrixLookupWeighted&>(data),target);
    70       }
    71       else {
    72         nbc=new NBC(dynamic_cast<const MatrixLookup&>(data),target);
    73       }     
    74     }
    75     catch (std::bad_cast) {
    76       std::string str =
    77         "Error in NBC::make_classifier: DataLookup2D of unexpected class.";
    78       throw std::runtime_error(str);
    79     }
    80     return nbc;
    81   }
    82 
    83 
    84   void NBC::train()
     55    return new NBC();
     56  }
     57
     58
     59  void NBC::train(const MatrixLookup& data, const Target& target)
    8560  {   
    86     sigma2_.resize(data_.rows(), target_.nof_classes());
    87     centroids_.resize(data_.rows(), target_.nof_classes());
    88     utility::Matrix nof_in_class(data_.rows(), target_.nof_classes());
     61    sigma2_.resize(data.rows(), target.nof_classes());
     62    centroids_.resize(data.rows(), target.nof_classes());
     63    utility::Matrix nof_in_class(data.rows(), target.nof_classes());
    8964   
    90     // unweighted
    91     if (data_.weighted()){
    92       const MatrixLookupWeighted& data =
    93         dynamic_cast<const MatrixLookupWeighted&>(data_);
    94       for(size_t i=0; i<data_.rows(); ++i) {
    95         std::vector<statistics::AveragerWeighted> aver(target_.nof_classes());
    96         for(size_t j=0; j<data_.columns(); ++j)
    97           aver[target_(j)].add(data.data(i,j), data.weight(i,j));
    98 
    99         assert(centroids_.columns()==target_.nof_classes());
    100         for (size_t j=0; j<target_.nof_classes(); ++j){
    101           assert(i<centroids_.rows());
    102           assert(j<centroids_.columns());
    103           assert(i<sigma2_.rows());
    104           assert(j<sigma2_.columns());
    105           if (aver[j].n()>1){
    106             sigma2_(i,j) = aver[j].variance();
    107             centroids_(i,j) = aver[j].mean();
    108           }
     65    for(size_t i=0; i<data.rows(); ++i) {
     66      std::vector<statistics::Averager> aver(target.nof_classes());
     67      for(size_t j=0; j<data.columns(); ++j)
     68        aver[target(j)].add(data(i,j));
     69     
     70      assert(centroids_.columns()==target.nof_classes());
     71      for (size_t j=0; j<target.nof_classes(); ++j){
     72        assert(i<centroids_.rows());
     73        assert(j<centroids_.columns());
     74        centroids_(i,j) = aver[j].mean();
     75        assert(i<sigma2_.rows());
     76        assert(j<sigma2_.columns());
     77        if (aver[j].n()>1){
     78          sigma2_(i,j) = aver[j].variance();
     79          centroids_(i,j) = aver[j].mean();
     80        }
    10981          else {
    11082            sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
    11183            centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
    11284          }
    113         }
    114       }
    115     }
    116     else { 
    117       const MatrixLookup& data = dynamic_cast<const MatrixLookup&>(data_);
    118       for(size_t i=0; i<data_.rows(); ++i) {
    119         std::vector<statistics::Averager> aver(target_.nof_classes());
    120         for(size_t j=0; j<data_.columns(); ++j)
    121           aver[target_(j)].add(data(i,j));
    122 
    123         assert(centroids_.columns()==target_.nof_classes());
    124         for (size_t j=0; j<target_.nof_classes(); ++j){
    125           assert(i<centroids_.rows());
    126           assert(j<centroids_.columns());
     85      }
     86    }
     87    trained_=true;
     88  }   
     89
     90
     91  void NBC::train(const MatrixLookupWeighted& data, const Target& target)
     92  {   
     93    sigma2_.resize(data.rows(), target.nof_classes());
     94    centroids_.resize(data.rows(), target.nof_classes());
     95    utility::Matrix nof_in_class(data.rows(), target.nof_classes());
     96
     97    for(size_t i=0; i<data.rows(); ++i) {
     98      std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
     99      for(size_t j=0; j<data.columns(); ++j)
     100        aver[target(j)].add(data.data(i,j), data.weight(i,j));
     101     
     102      assert(centroids_.columns()==target.nof_classes());
     103      for (size_t j=0; j<target.nof_classes(); ++j) {
     104        assert(i<centroids_.rows());
     105        assert(j<centroids_.columns());
     106        assert(i<sigma2_.rows());
     107        assert(j<sigma2_.columns());
     108        if (aver[j].n()>1){
     109          sigma2_(i,j) = aver[j].variance();
    127110          centroids_(i,j) = aver[j].mean();
    128           assert(i<sigma2_.rows());
    129           assert(j<sigma2_.columns());
    130           if (aver[j].n()>1){
    131             sigma2_(i,j) = aver[j].variance();
    132             centroids_(i,j) = aver[j].mean();
    133           }
    134           else {
    135             sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
    136             centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
    137           }
    138         }
    139       }
    140     }   
     111        }
     112        else {
     113          sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
     114          centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
     115        }
     116      }
     117    }
    141118    trained_=true;
    142119  }
     
    146123                    utility::Matrix& prediction) const
    147124  {   
    148     assert(data_.rows()==x.rows());
    149125    assert(x.rows()==sigma2_.rows());
    150126    assert(x.rows()==centroids_.rows());
  • trunk/yat/classifier/NBC.h

    r1152 r1157  
    5252  public:
    5353    ///
    54     /// Constructor taking the training data, the target vector.
     54    /// @brief Constructor
    5555    ///
    56     NBC(const MatrixLookup&, const Target&);
     56    NBC(void);
     57   
     58
     59    ///
     60    /// @brief Destructor
     61    ///
     62    virtual ~NBC();
     63
     64
     65    NBC* make_classifier(void) const;
    5766   
    5867    ///
    59     /// Constructor taking the training data with weights, the target
    60     /// vector, the distance measure, and a weight matrix for the
    61     /// training data as input.
    62     ///
    63     NBC(const MatrixLookupWeighted&, const Target&);
    64 
    65     virtual ~NBC();
    66 
    67     const DataLookup2D& data(void) const;
    68 
    69 
    70     NBC* make_classifier(const DataLookup2D&,
    71                          const Target&) const;
    72    
    73     ///
    74     /// Train the classifier using the training data.
     68    /// Train the classifier using training data and targets.
    7569    ///
    7670    /// For each class mean and variance are estimated for each
     
    8175    /// specific label.
    8276    ///
    83     void train();
     77    void train(const MatrixLookup&, const Target&);
     78
     79    ///
     80    /// Train the classifier using weighted training data and targets.
     81    ///
     82    void train(const MatrixLookupWeighted&, const Target&);
     83
    8484
    8585   
     
    106106    utility::Matrix centroids_;
    107107    utility::Matrix sigma2_;
    108     const DataLookup2D& data_;
    109108
    110109    double sum_logsigma(size_t i) const;
  • trunk/yat/classifier/NCC.h

    r1144 r1157  
    6565  public:
    6666    ///
    67     /// Constructor taking the training data and the target vector as
    68     /// input
    69     ///
    70     NCC(const MatrixLookup&, const Target&);
    71    
    72     ///
    73     /// Constructor taking the training data with weights and the
    74     /// target vector as input.
    75     ///
    76     NCC(const MatrixLookupWeighted&, const Target&);
    77 
    78     virtual ~NCC();
     67    /// @brief Constructor
     68    ///
     69    NCC(void);
     70   
     71
     72    ///
     73    /// @brief Destructor
     74    ///
     75    virtual ~NCC(void);
    7976
    8077    ///
     
    8380    const utility::Matrix& centroids(void) const;
    8481
    85     const DataLookup2D& data(void) const;
    86 
    87     NCC<Distance>* make_classifier(const DataLookup2D&,
    88                                    const Target&) const;
    89    
    90     ///
    91     /// Train the classifier using the training data. Centroids are
    92     /// calculated for each class.
    93     ///
    94     void train();
     82    NCC<Distance>* make_classifier(void) const;
     83   
     84    ///
     85    /// Train the classifier with a training data set and
     86    /// targets. Centroids are calculated for each class.
     87    ///
     88    void train(const MatrixLookup&, const Target&);
     89
     90
     91    ///
     92    /// Train the classifier with a weighted training data set and
     93    /// targets. Centroids are calculated for each class.
     94    ///
     95    void train(const MatrixLookupWeighted&, const Target&);
    9596
    9697   
     
    109110    bool centroids_nan_;
    110111    Distance distance_;
    111 
    112     // data_ has to be of type DataLookup2D to accomodate both
    113     // MatrixLookup and MatrixLookupWeighted
    114     const DataLookup2D& data_;
    115112  };
    116113
     
    124121
    125122  template <typename Distance>
    126   NCC<Distance>::NCC(const MatrixLookup& data, const Target& target)
    127     : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data)
    128   {
    129   }
    130 
    131   template <typename Distance>
    132   NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
    133     : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data)
    134   {
    135   }
     123  NCC<Distance>::NCC()
     124    : SupervisedClassifier(), centroids_(0), centroids_nan_(false)
     125  {
     126  }
     127
    136128
    137129  template <typename Distance>
     
    142134  }
    143135
     136
    144137  template <typename Distance>
    145138  const utility::Matrix& NCC<Distance>::centroids(void) const
     
    150143
    151144  template <typename Distance>
    152   const DataLookup2D& NCC<Distance>::data(void) const
    153   {
    154     return data_;
    155   }
    156  
    157   template <typename Distance>
    158145  NCC<Distance>*
    159   NCC<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const
     146  NCC<Distance>::make_classifier() const
    160147  {     
    161     NCC* ncc=0;
    162     try {
    163       if(data.weighted()) {
    164         ncc=new NCC<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
    165                               target);
    166       }
    167       else {
    168         ncc=new NCC<Distance>(dynamic_cast<const MatrixLookup&>(data),
    169                               target);
    170       }
    171     }
    172     catch (std::bad_cast) {
    173       std::string str = "Error in NCC<Distance>::make_classifier: DataLookup2D of unexpected class.";
    174       throw std::runtime_error(str);
    175     }
    176     return ncc;
    177   }
    178 
    179 
    180   template <typename Distance>
    181   void NCC<Distance>::train()
     148    return new NCC<Distance>();
     149  }
     150
     151  template <typename Distance>
     152  void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
    182153  {   
    183154    if(centroids_)
    184155      delete centroids_;
    185     centroids_= new utility::Matrix(data_.rows(), target_.nof_classes());
    186     // data_ is a MatrixLookup or a MatrixLookupWeighted
    187     if(data_.weighted()) {
    188       const MatrixLookupWeighted* weighted_data =
    189         dynamic_cast<const MatrixLookupWeighted*>(&data_);     
    190       for(size_t i=0; i<data_.rows(); i++) {
    191         std::vector<statistics::AveragerWeighted> class_averager;
    192         class_averager.resize(target_.nof_classes());
    193         for(size_t j=0; j<data_.columns(); j++) {
    194           class_averager[target_(j)].add(weighted_data->data(i,j),
    195                                          weighted_data->weight(i,j));
     156    centroids_= new utility::Matrix(data.rows(), target.nof_classes());
     157    for(size_t i=0; i<data.rows(); i++) {
     158      std::vector<statistics::Averager> class_averager;
     159      class_averager.resize(target.nof_classes());
     160      for(size_t j=0; j<data.columns(); j++) {
     161        class_averager[target(j)].add(data(i,j));
     162      }
     163      for(size_t c=0;c<target.nof_classes();c++) {
     164        (*centroids_)(i,c) = class_averager[c].mean();
     165      }
     166    }
     167    trained_=true;
     168  }
     169
     170
     171  template <typename Distance>
     172  void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
     173  {   
     174    if(centroids_)
     175      delete centroids_;
     176    centroids_= new utility::Matrix(data.rows(), target.nof_classes());
     177    for(size_t i=0; i<data.rows(); i++) {
     178      std::vector<statistics::AveragerWeighted> class_averager;
     179      class_averager.resize(target.nof_classes());
     180      for(size_t j=0; j<data.columns(); j++)
     181        class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
     182      for(size_t c=0;c<target.nof_classes();c++) {
     183        if(class_averager[c].sum_w()==0) {
     184          centroids_nan_=true;
    196185        }
    197         for(size_t c=0;c<target_.nof_classes();c++) {
    198           if(class_averager[c].sum_w()==0) {
    199             centroids_nan_=true;
    200           }
    201           (*centroids_)(i,c) = class_averager[c].mean();
    202         }
    203       }
    204     }
    205     else {
    206       const MatrixLookup* unweighted_data =
    207         dynamic_cast<const MatrixLookup*>(&data_);     
    208       for(size_t i=0; i<data_.rows(); i++) {
    209         std::vector<statistics::Averager> class_averager;
    210         class_averager.resize(target_.nof_classes());
    211         for(size_t j=0; j<data_.columns(); j++) {
    212           class_averager[target_(j)].add((*unweighted_data)(i,j));
    213         }
    214         for(size_t c=0;c<target_.nof_classes();c++) {
    215           (*centroids_)(i,c) = class_averager[c].mean();
    216         }
    217       }
    218     }
    219   }
     186        (*centroids_)(i,c) = class_averager[c].mean();
     187      }
     188    }
     189    trained_=true;
     190  }
     191
    220192
    221193  template <typename Distance>
     
    226198      (centroids_,"NCC::predict called for untrained classifier");
    227199    utility::yat_assert<std::runtime_error>
    228       (data_.rows()==test.rows(),
     200      (centroids_->rows()==test.rows(),
    229201       "NCC::predict test data with incorrect number of rows");
    230202   
  • trunk/yat/classifier/SupervisedClassifier.cc

    r1000 r1157  
    3030namespace classifier {
    3131
    32   SupervisedClassifier::SupervisedClassifier(const Target& target)
    33     : target_(target), trained_(false)
     32  SupervisedClassifier::SupervisedClassifier()
     33    : trained_(false)
    3434  {
    3535  }
  • trunk/yat/classifier/SupervisedClassifier.h

    r1125 r1157  
    3939
    4040  class DataLookup2D;
     41  class MatrixLookup;
     42  class MatrixLookupWeighted;
    4143  class Target;
    4244
     
    5052  public:
    5153    ///
    52     /// @brief Constructor taking a Target object.
     54    /// @brief Constructor
    5355    ///
    54     SupervisedClassifier(const Target&);
     56    SupervisedClassifier(void);
    5557   
    5658
     
    5961    ///
    6062    virtual ~SupervisedClassifier(void);
    61 
    62     ///
    63     /// @brief Access to the training data
    64     ///
    65     virtual const DataLookup2D& data(void) const =0;
    6663
    6764
     
    7572    ///
    7673    virtual SupervisedClassifier*
    77     make_classifier(const DataLookup2D&, const Target&) const =0;
     74    make_classifier() const =0;
    7875   
    7976
     
    8885    /// Train the classifier.
    8986    ///
    90     virtual void train()=0;
     87    virtual void train(const MatrixLookup&, const Target&)=0;
     88
     89    ///
     90    /// Train the classifier.
     91    ///
     92    virtual void train(const MatrixLookupWeighted&, const Target&)=0;
    9193
    9294   
    9395  protected:
    9496   
    95     /// Target to train on.
    96     const Target& target_;
    9797    /// true if classifier successfully trained
    98     bool trained_;
    99    
     98    bool trained_;   
    10099   
    101100  }; 
Note: See TracChangeset for help on using the changeset viewer.