Ignore:
Timestamp:
Feb 26, 2008, 2:25:19 PM (14 years ago)
Author:
Markus Ringnér
Message:

Refs #318

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NCC.h

    r1144 r1157  
    6565  public:
    6666    ///
    67     /// Constructor taking the training data and the target vector as
    68     /// input
    69     ///
    70     NCC(const MatrixLookup&, const Target&);
    71    
    72     ///
    73     /// Constructor taking the training data with weights and the
    74     /// target vector as input.
    75     ///
    76     NCC(const MatrixLookupWeighted&, const Target&);
    77 
    78     virtual ~NCC();
     67    /// @brief Constructor
     68    ///
     69    NCC(void);
     70   
     71
     72    ///
     73    /// @brief Destructor
     74    ///
     75    virtual ~NCC(void);
    7976
    8077    ///
     
    8380    const utility::Matrix& centroids(void) const;
    8481
    85     const DataLookup2D& data(void) const;
    86 
    87     NCC<Distance>* make_classifier(const DataLookup2D&,
    88                                    const Target&) const;
    89    
    90     ///
    91     /// Train the classifier using the training data. Centroids are
    92     /// calculated for each class.
    93     ///
    94     void train();
     82    NCC<Distance>* make_classifier(void) const;
     83   
     84    ///
     85    /// Train the classifier with a training data set and
     86    /// targets. Centroids are calculated for each class.
     87    ///
     88    void train(const MatrixLookup&, const Target&);
     89
     90
     91    ///
     92    /// Train the classifier with a weighted training data set and
     93    /// targets. Centroids are calculated for each class.
     94    ///
     95    void train(const MatrixLookupWeighted&, const Target&);
    9596
    9697   
     
    109110    bool centroids_nan_;
    110111    Distance distance_;
    111 
    112     // data_ has to be of type DataLookup2D to accomodate both
    113     // MatrixLookup and MatrixLookupWeighted
    114     const DataLookup2D& data_;
    115112  };
    116113
     
    124121
    125122  template <typename Distance>
    126   NCC<Distance>::NCC(const MatrixLookup& data, const Target& target)
    127     : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data)
    128   {
    129   }
    130 
    131   template <typename Distance>
    132   NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
    133     : SupervisedClassifier(target), centroids_(0), centroids_nan_(false), data_(data)
    134   {
    135   }
     123  NCC<Distance>::NCC()
     124    : SupervisedClassifier(), centroids_(0), centroids_nan_(false)
     125  {
     126  }
     127
    136128
    137129  template <typename Distance>
     
    142134  }
    143135
     136
    144137  template <typename Distance>
    145138  const utility::Matrix& NCC<Distance>::centroids(void) const
     
    150143
    151144  template <typename Distance>
    152   const DataLookup2D& NCC<Distance>::data(void) const
    153   {
    154     return data_;
    155   }
    156  
    157   template <typename Distance>
    158145  NCC<Distance>*
    159   NCC<Distance>::make_classifier(const DataLookup2D& data, const Target& target) const
     146  NCC<Distance>::make_classifier() const
    160147  {     
    161     NCC* ncc=0;
    162     try {
    163       if(data.weighted()) {
    164         ncc=new NCC<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
    165                               target);
    166       }
    167       else {
    168         ncc=new NCC<Distance>(dynamic_cast<const MatrixLookup&>(data),
    169                               target);
    170       }
    171     }
    172     catch (std::bad_cast) {
    173       std::string str = "Error in NCC<Distance>::make_classifier: DataLookup2D of unexpected class.";
    174       throw std::runtime_error(str);
    175     }
    176     return ncc;
    177   }
    178 
    179 
    180   template <typename Distance>
    181   void NCC<Distance>::train()
     148    return new NCC<Distance>();
     149  }
     150
     151  template <typename Distance>
     152  void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
    182153  {   
    183154    if(centroids_)
    184155      delete centroids_;
    185     centroids_= new utility::Matrix(data_.rows(), target_.nof_classes());
    186     // data_ is a MatrixLookup or a MatrixLookupWeighted
    187     if(data_.weighted()) {
    188       const MatrixLookupWeighted* weighted_data =
    189         dynamic_cast<const MatrixLookupWeighted*>(&data_);     
    190       for(size_t i=0; i<data_.rows(); i++) {
    191         std::vector<statistics::AveragerWeighted> class_averager;
    192         class_averager.resize(target_.nof_classes());
    193         for(size_t j=0; j<data_.columns(); j++) {
    194           class_averager[target_(j)].add(weighted_data->data(i,j),
    195                                          weighted_data->weight(i,j));
     156    centroids_= new utility::Matrix(data.rows(), target.nof_classes());
     157    for(size_t i=0; i<data.rows(); i++) {
     158      std::vector<statistics::Averager> class_averager;
     159      class_averager.resize(target.nof_classes());
     160      for(size_t j=0; j<data.columns(); j++) {
     161        class_averager[target(j)].add(data(i,j));
     162      }
     163      for(size_t c=0;c<target.nof_classes();c++) {
     164        (*centroids_)(i,c) = class_averager[c].mean();
     165      }
     166    }
     167    trained_=true;
     168  }
     169
     170
     171  template <typename Distance>
     172  void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
     173  {   
     174    if(centroids_)
     175      delete centroids_;
     176    centroids_= new utility::Matrix(data.rows(), target.nof_classes());
     177    for(size_t i=0; i<data.rows(); i++) {
     178      std::vector<statistics::AveragerWeighted> class_averager;
     179      class_averager.resize(target.nof_classes());
     180      for(size_t j=0; j<data.columns(); j++)
     181        class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
     182      for(size_t c=0;c<target.nof_classes();c++) {
     183        if(class_averager[c].sum_w()==0) {
     184          centroids_nan_=true;
    196185        }
    197         for(size_t c=0;c<target_.nof_classes();c++) {
    198           if(class_averager[c].sum_w()==0) {
    199             centroids_nan_=true;
    200           }
    201           (*centroids_)(i,c) = class_averager[c].mean();
    202         }
    203       }
    204     }
    205     else {
    206       const MatrixLookup* unweighted_data =
    207         dynamic_cast<const MatrixLookup*>(&data_);     
    208       for(size_t i=0; i<data_.rows(); i++) {
    209         std::vector<statistics::Averager> class_averager;
    210         class_averager.resize(target_.nof_classes());
    211         for(size_t j=0; j<data_.columns(); j++) {
    212           class_averager[target_(j)].add((*unweighted_data)(i,j));
    213         }
    214         for(size_t c=0;c<target_.nof_classes();c++) {
    215           (*centroids_)(i,c) = class_averager[c].mean();
    216         }
    217       }
    218     }
    219   }
     186        (*centroids_)(i,c) = class_averager[c].mean();
     187      }
     188    }
     189    trained_=true;
     190  }
     191
    220192
    221193  template <typename Distance>
     
    226198      (centroids_,"NCC::predict called for untrained classifier");
    227199    utility::yat_assert<std::runtime_error>
    228       (data_.rows()==test.rows(),
     200      (centroids_->rows()==test.rows(),
    229201       "NCC::predict test data with incorrect number of rows");
    230202   
Note: See TracChangeset for help on using the changeset viewer.