source: trunk/test/ensemble_test.cc @ 1227

Last change on this file since 1227 was 1227, checked in by Peter, 15 years ago

fixes #341 and #93

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.9 KB
Line 
1// $Id: ensemble_test.cc 1227 2008-03-13 02:43:41Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Markus Ringnér, Peter Johansson
5  Copyright (C) 2007 Peter Johansson
6
7  This file is part of the yat library, http://trac.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 2 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with this program; if not, write to the Free Software
21  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
22  02111-1307, USA.
23*/
24
25#include "yat/utility/Matrix.h"
26#include "yat/classifier/SubsetGenerator.h"
27#include "yat/classifier/CrossValidationSampler.h"
28#include "yat/classifier/EnsembleBuilder.h"
29#include "yat/classifier/Kernel.h"
30#include "yat/classifier/KernelLookup.h"
31#include "yat/classifier/Kernel_SEV.h"
32#include "yat/classifier/Kernel_MEV.h"
33#include "yat/classifier/MatrixLookup.h"
34#include "yat/classifier/MatrixLookupWeighted.h"
35#include "yat/classifier/NCC.h"
36#include "yat/classifier/PolynomialKernelFunction.h"
37#include "yat/classifier/SVM.h"
38#include "yat/statistics/AUC.h"
39#include "yat/statistics/EuclideanDistance.h"
40
41#include <cassert>
42#include <fstream>
43#include <iostream>
44#include <cstdlib>
45#include <limits>
46
47
48int main(const int argc,const char* argv[])
49{ 
50  using namespace theplu::yat;
51
52  std::ostream* error;
53  if (argc>1 && argv[1]==std::string("-v"))
54    error = &std::cerr;
55  else {
56    error = new std::ofstream("/dev/null");
57    if (argc>1)
58      std::cout << "ensemble_test -v : for printing extra information\n";
59  }
60  *error << "testing ensemble" << std::endl;
61  bool ok = true;
62
63  *error << "loading data" << std::endl;
64  std::ifstream is("data/nm_data_centralized.txt");
65  utility::Matrix data_core(is);
66  is.close();
67
68  *error << "create MatrixLookup" << std::endl;
69  classifier::MatrixLookup data(data_core);
70  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
71  *error << "Building kernel" << std::endl;
72  classifier::Kernel_SEV kernel(data,*kf);
73
74
75  *error << "load target" << std::endl;
76  is.open("data/nm_target_bin.txt");
77  classifier::Target target(is);
78  is.close();
79  assert(data.columns()==target.size());
80
81  {
82    *error << "create ensemble of ncc" << std::endl;
83    classifier::NCC<statistics::EuclideanDistance> ncc;
84    classifier::CrossValidationSampler sampler(target,3,3);
85    classifier::SubsetGenerator<classifier::MatrixLookup> subdata(sampler,data);
86    classifier::EnsembleBuilder<classifier::SupervisedClassifier,
87      classifier::MatrixLookup> ensemble(ncc, data, sampler);
88    *error << "build ensemble" << std::endl;
89    ensemble.build();
90    std::vector<std::vector<statistics::Averager> > result;
91    ensemble.predict(data, result);
92  }
93
94  {
95    *error << "create ensemble of ncc" << std::endl;
96    classifier::MatrixLookupWeighted data_weighted(data);
97    classifier::NCC<statistics::EuclideanDistance> ncc;
98    classifier::CrossValidationSampler sampler(target,3,3);
99    classifier::SubsetGenerator<classifier::MatrixLookupWeighted> 
100      subdata(sampler,data_weighted);
101    classifier::EnsembleBuilder<classifier::SupervisedClassifier,
102      classifier::MatrixLookupWeighted> ensemble(ncc, data_weighted, sampler);
103    *error << "build ensemble" << std::endl;
104    ensemble.build();
105    std::vector<std::vector<statistics::Averager> > result;
106    ensemble.predict(data_weighted, result);
107  }
108
109  *error << "create KernelLookup" << std::endl;
110  classifier::KernelLookup kernel_lookup(kernel);
111  *error << "create svm" << std::endl;
112  classifier::SVM svm;
113  *error << "create Subsets" << std::endl;
114  classifier::CrossValidationSampler sampler(target,3,3);
115  classifier::SubsetGenerator<classifier::KernelLookup> cv(sampler,
116                                                           kernel_lookup);
117
118  *error << "create ensemble" << std::endl;
119  classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup> 
120    ensemble(svm, kernel_lookup, sampler);
121  *error << "build ensemble" << std::endl;
122  ensemble.build();
123  std::vector<std::vector<statistics::Averager> > result;
124  ensemble.predict(kernel_lookup, result);
125 
126  utility::Vector out(target.size(),0);
127  for (size_t i = 0; i<out.size(); ++i)
128    out(i)=ensemble.validate()[0][i].mean(); 
129  statistics::AUC roc;
130  *error << roc.score(target,out) << std::endl;
131
132  {
133    *error << "create ensemble" << std::endl;
134    classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup> 
135      ensemble(svm, kernel_lookup, sampler);
136    *error << "test validate() before build()\n";
137    ensemble.validate();
138    std::vector<std::vector<statistics::Averager> > result;
139    *error << "test predict() before build()\n";
140    ensemble.predict(kernel_lookup, result);
141  }
142  delete kf;
143
144  if (error!=&std::cerr)
145    delete error;
146
147  if(ok)
148    return 0;
149  return -1;
150 
151}
Note: See TracBrowser for help on using the repository browser.