Changeset 930 for trunk/yat


Ignore:
Timestamp:
Oct 4, 2007, 3:30:02 PM (14 years ago)
Author:
Markus Ringnér
Message:

Fixed support for MatrixLookup? in NCC. See ticket:259

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/yat/classifier/NCC.h

    r925 r930  
    105105  private:
    106106
    107     utility::matrix centroids_;
     107    utility::matrix* centroids_;
    108108
    109109    // data_ has to be of type DataLookup2D to accomodate both
     
    123123  template <typename Distance>
    124124  NCC<Distance>::NCC(const MatrixLookup& data, const Target& target)
    125     : SupervisedClassifier(target), data_(data)
     125    : SupervisedClassifier(target), centroids_(0), data_(data)
    126126  {
    127127  }
     
    129129  template <typename Distance>
    130130  NCC<Distance>::NCC(const MatrixLookupWeighted& data, const Target& target)
    131     : SupervisedClassifier(target), data_(data)
     131    : SupervisedClassifier(target), centroids_(0), data_(data)
    132132  {
    133133  }
     
    136136  NCC<Distance>::~NCC()   
    137137  {
    138   }
    139 
     138    if(centroids_)
     139      delete centroids_;
     140  }
    140141
    141142  template <typename Distance>
    142143  const utility::matrix& NCC<Distance>::centroids(void) const
    143144  {
    144     return centroids_;
     145    return *centroids_;
    145146  }
    146147 
     
    158159    NCC* ncc=0;
    159160    if(data.weighted()) {
    160       ncc=new NCC<Distance>(dynamic_cast<const MatrixLookupWeighted&>(data),
     161      ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookupWeighted*>(&data),
    161162                  target);
    162163    }
    163164    else {
    164       ncc=new NCC<Distance>(dynamic_cast<const MatrixLookup&>(data),
     165      ncc=new NCC<Distance>(*dynamic_cast<const MatrixLookup*>(&data),
    165166                  target);
    166167    }
     168    ncc->centroids_=0;
    167169    return ncc;
    168170  }
     
    172174  bool NCC<Distance>::train()
    173175  {   
    174     centroids_.clone(utility::matrix(data_.rows(), target_.nof_classes()));
     176    if(centroids_)
     177      delete centroids_;
     178    centroids_= new utility::matrix(data_.rows(), target_.nof_classes());
    175179    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
    176180    const MatrixLookupWeighted* weighted_data =
     
    180184    for(size_t i=0; i<data_.rows(); i++) {
    181185      for(size_t j=0; j<data_.columns(); j++) {
    182         centroids_(i,target_(j)) += data_(i,j);
     186        (*centroids_)(i,target_(j)) += data_(i,j);
    183187        if (weighted)
    184188          nof_in_class(i,target_(j))+= weighted_data->weight(i,j);
     
    187191      }
    188192    }   
    189     centroids_.div(nof_in_class);
     193    centroids_->div(nof_in_class);
    190194    trained_=true;
    191195    return trained_;
     
    194198  template <typename Distance>
    195199  void NCC<Distance>::predict(const DataLookup2D& input,                   
    196                     utility::matrix& prediction) const
     200                              utility::matrix& prediction) const
    197201  {   
    198     prediction.clone(utility::matrix(centroids_.columns(), input.columns()));   
    199 
     202    prediction.clone(utility::matrix(centroids_->columns(), input.columns()));   
     203   
    200204    // Weighted case
    201205    const MatrixLookupWeighted* testdata =
    202206      dynamic_cast<const MatrixLookupWeighted*>(&input);     
    203207    if (testdata) {
    204       MatrixLookupWeighted weighted_centroids(centroids_);
     208      MatrixLookupWeighted weighted_centroids(*centroids_);
    205209      for(size_t j=0; j<input.columns();j++) {       
    206210        DataLookupWeighted1D in(*testdata,j,false);
    207         for(size_t k=0; k<centroids_.columns();k++) {
     211        for(size_t k=0; k<centroids_->columns();k++) {
    208212          DataLookupWeighted1D centroid(weighted_centroids,k,false);
    209 
    210213          yat_assert(in.size()==centroid.size());
    211214          prediction(k,j)=statistics::
    212215            vector_distance(in.begin(),in.end(),centroid.begin(),
    213                              typename statistics::vector_distance_traits<Distance>::distance());
     216                            typename statistics::vector_distance_traits<Distance>::distance());
    214217        }
    215218      }
    216219    }
     220    // Non-weighted case
    217221    else {
    218       std::string str;
    219       str = "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
    220       throw std::runtime_error(str);
     222      const MatrixLookup* testdata =
     223        dynamic_cast<const MatrixLookup*>(&input);     
     224      if (testdata) {
     225        MatrixLookup unweighted_centroids(*centroids_);
     226        for(size_t j=0; j<input.columns();j++) {       
     227          DataLookup1D in(*testdata,j,false);
     228          for(size_t k=0; k<centroids_->columns();k++) {
     229            DataLookup1D centroid(unweighted_centroids,k,false);           
     230            yat_assert(in.size()==centroid.size());
     231            prediction(k,j)=statistics::
     232              vector_distance(in.begin(),in.end(),centroid.begin(),
     233                              typename statistics::vector_distance_traits<Distance>::distance());
     234          }
     235        }
     236      }     
     237      else {
     238        std::string str;
     239        str = "Error in NCC<Distance>::predict: DataLookup2D of unexpected class.";
     240        throw std::runtime_error(str);
     241      }
    221242    }
    222243  }
Note: See TracChangeset for help on using the changeset viewer.