Changeset 634


Ignore:
Timestamp:
Sep 5, 2006, 3:50:03 PM (15 years ago)
Author:
Markus Ringnér
Message:

Fixed NCC predict to work with MatrixLookupWeighted?

Location:
trunk
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/c++_tools/classifier/NCC.cc

    r632 r634  
    4646      ncc=new NCC(dynamic_cast<const MatrixLookupWeighted&>(cs.training_data()),
    4747                  cs.training_target(),this->distance_);
    48                   }
     48    }
    4949    else {
    5050      ncc=new NCC(dynamic_cast<const MatrixLookup&>(cs.training_data()),
     
    5656
    5757  bool NCC::train()
    58   {
    59     // Calculate centroids based on
    60     // all inputs ( = all rows in data matrix).
     58  {   
    6159    centroids_=utility::matrix(data_.rows(), target_.nof_classes());
    6260    utility::matrix nof_in_class(data_.rows(), target_.nof_classes());
    6361    for(size_t i=0; i<data_.rows(); i++) {
    6462      for(size_t j=0; j<data_.columns(); j++) {
    65         double weight=1.0;
    66         if(data_.weighted())
    67           weight=dynamic_cast<const MatrixLookupWeighted&>(data_).weight(i,j);
    6863        centroids_(i,target_(j)) += data_(i,j);
    69         nof_in_class(i,target_(j))+=weight;
     64        try {
     65          nof_in_class(i,target_(j))+=
     66            dynamic_cast<const MatrixLookupWeighted&>(data_).weight(i,j);
     67        }
     68        catch (std::bad_cast) {
     69          nof_in_class(i,target_(j))+=1.0;
     70        }
    7071      }
    71     }
    72    
     72    }   
    7373    centroids_.div_elements(nof_in_class);
    7474    trained_=true;
     
    7777
    7878
    79   void NCC::predict(const DataLookup1D& input,
     79  void NCC::predict(const DataLookup1D& input, const utility::vector& weights,
    8080                    utility::vector& prediction) const
    8181  {
    8282    prediction=utility::vector(centroids_.columns());   
    83     utility::vector w(input.size(),0);
     83
    8484    utility::vector value(input.size(),0);
    85     for(size_t i=0; i<input.size(); i++)  { // take care of missing values
     85    for(size_t i=0; i<input.size(); i++)
    8686      value(i)=input(i);
    87       if(!std::isnan(value(i)))
    88         w(i)=1.0;
    89     }
     87   
     88    // take care of nan's in centroids
    9089    for(size_t j=0; j<centroids_.columns(); j++) {
    9190      utility::vector centroid=utility::vector(centroids_,j,false);
    9291      utility::vector wc(centroid.size(),0);
    93       for(size_t i=0; i<centroid.size(); i++)  { // take care of missing values
     92      for(size_t i=0; i<centroid.size(); i++)  {
    9493        if(!std::isnan(centroid(i)))
    9594          wc(i)=1.0;
    9695      }
    97       prediction(j)=distance_(value,centroid,w,wc);   
     96      prediction(j)=distance_(value,centroid,weights,wc);   
    9897    }
    9998  }
     
    102101  void NCC::predict(const DataLookup2D& input,                   
    103102                    utility::matrix& prediction) const
    104   {
     103  {   
    105104    prediction=utility::matrix(centroids_.columns(), input.columns());   
    106     for(size_t j=0; j<input.columns();j++) {     
    107       DataLookup1D in(input,j,false);
    108       utility::vector out;
    109       predict(in,out);
    110       prediction.set_column(j,out);
     105    try {   
     106      const MatrixLookupWeighted& data=
     107        dynamic_cast<const MatrixLookupWeighted&>(input);     
     108      for(size_t j=0; j<input.columns();j++) {     
     109        DataLookup1D in(input,j,false);
     110        utility::vector weights(in.size(),0);
     111        for(size_t i=0; i<in.size();i++)
     112          weights(i)=data.weight(i,j);
     113        utility::vector out;
     114        predict(in,weights,out);
     115        prediction.set_column(j,out);
     116      }
     117    }
     118    catch (std::bad_cast) {
     119      try {
     120        dynamic_cast<const MatrixLookup&>(input);
     121        for(size_t j=0; j<input.columns();j++) {     
     122          DataLookup1D in(input,j,false);
     123          utility::vector weights(in.size(),1.0);
     124          utility::vector out;
     125          predict(in,weights,out);
     126          prediction.set_column(j,out);
     127        }
     128      }
     129      catch (std::bad_cast) {
     130        std::cerr << "Error in NCC::predict: DataLookup2D of unexpected class"
     131                  << std::endl;
     132      }
    111133    }
    112134  }
  • trunk/c++_tools/classifier/NCC.h

    r632 r634  
    6767    bool train();
    6868
    69 
    70     ///
    71     /// Calculate the distance to each centroid for a test sample
    72     ///
    73     void predict(const DataLookup1D&, utility::vector&) const;
    7469   
    7570    ///
     
    8378    const statistics::Distance& distance_;                 
    8479    const DataLookup2D& data_;
     80
     81    ///
     82    /// Calculate the distance to each centroid for a test sample
     83    ///
     84    void predict(const DataLookup1D&, const utility::vector&,
     85                 utility::vector&) const;
     86
    8587  };
    8688
  • trunk/test/ncc_test.cc

    r632 r634  
    8383  }
    8484
    85 
    86   classifier:: MatrixLookup dataview(data);
    8785  utility::matrix prediction;
    88   ncc.predict(dataview,prediction);
     86  ncc.predict(dataviewweighted,prediction);
    8987 
    9088  // Comparing the prediction to stored result
     
    10098  }
    10199  slack /= (result.columns()*result.rows());
    102   if (slack > slack_bound){
     100  if (slack > slack_bound || std::isnan(slack)){
    103101    *error << "Difference to stored prediction too large\n";
    104102    *error << "slack: " << slack << std::endl;
     
    109107
    110108  // Testing IGP 
     109  classifier:: MatrixLookup dataview(data);
    111110  *error << "testing igp" << std::endl;
    112111  classifier::IGP igp(dataview,targets,pearson);
Note: See TracChangeset for help on using the changeset viewer.