source: branches/0.5-stable/test/ensemble_test.cc @ 2135

Last change on this file since 2135 was 2135, checked in by Peter, 13 years ago

updating copyright statements

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 5.7 KB
Line 
// $Id: ensemble_test.cc 2135 2009-12-24 12:41:49Z peter $

/*
  Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
  Copyright (C) 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
  Copyright (C) 2009 Peter Johansson

  This file is part of the yat library, http://dev.thep.lu.se/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 3 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with yat. If not, see <http://www.gnu.org/licenses/>.
*/
24
#include "Suite.h"

#include "yat/utility/Matrix.h"
#include "yat/classifier/SubsetGenerator.h"
#include "yat/classifier/CrossValidationSampler.h"
#include "yat/classifier/EnsembleBuilder.h"
#include "yat/classifier/Kernel.h"
#include "yat/classifier/KernelLookup.h"
#include "yat/classifier/Kernel_SEV.h"
#include "yat/classifier/Kernel_MEV.h"
#include "yat/classifier/MatrixLookup.h"
#include "yat/classifier/MatrixLookupWeighted.h"
#include "yat/classifier/NCC.h"
#include "yat/classifier/PolynomialKernelFunction.h"
#include "yat/classifier/SVM.h"
#include "yat/statistics/AUC.h"
#include "yat/statistics/EuclideanDistance.h"

#include <cassert>
#include <fstream>
#include <iostream>
#include <cstdlib>
#include <limits>
48
49
50int main(int argc, char* argv[])
51{ 
52  using namespace theplu::yat;
53  test::Suite suite(argc, argv);
54 
55  suite.err() << "testing ensemble" << std::endl;
56
57  suite.err() << "loading data" << std::endl;
58  std::ifstream is(test::filename("data/nm_data_centralized.txt").c_str());
59  utility::Matrix data_core(is);
60  is.close();
61
62  suite.err() << "create MatrixLookup" << std::endl;
63  classifier::MatrixLookup data(data_core);
64  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
65  suite.err() << "Building kernel" << std::endl;
66  classifier::Kernel_SEV kernel(data,*kf);
67
68
69  suite.err() << "load target" << std::endl;
70  is.open(test::filename("data/nm_target_bin.txt").c_str());
71  classifier::Target target(is);
72  is.close();
73  assert(data.columns()==target.size());
74
75  {
76    suite.err() << "create ensemble of ncc" << std::endl;
77    classifier::NCC<statistics::EuclideanDistance> ncc;
78    classifier::CrossValidationSampler sampler(target,3,3);
79    classifier::SubsetGenerator<classifier::MatrixLookup> subdata(sampler,data);
80    classifier::EnsembleBuilder<classifier::SupervisedClassifier,
81      classifier::MatrixLookup> ensemble(ncc, data, sampler);
82    suite.err() << "build ensemble" << std::endl;
83    ensemble.build();
84    std::vector<std::vector<statistics::Averager> > result;
85    ensemble.predict(data, result);
86  }
87
88  {
89    suite.err() << "create ensemble of ncc" << std::endl;
90    classifier::MatrixLookupWeighted data_weighted(data);
91    classifier::NCC<statistics::EuclideanDistance> ncc;
92    classifier::CrossValidationSampler sampler(target,3,3);
93    classifier::SubsetGenerator<classifier::MatrixLookupWeighted> 
94      subdata(sampler,data_weighted);
95    classifier::EnsembleBuilder<classifier::SupervisedClassifier,
96      classifier::MatrixLookupWeighted> ensemble(ncc, data_weighted, sampler);
97    suite.err() << "build ensemble" << std::endl;
98    ensemble.build();
99    std::vector<std::vector<statistics::Averager> > result;
100    ensemble.predict(data_weighted, result);
101  }
102
103  suite.err() << "create KernelLookup" << std::endl;
104  classifier::KernelLookup kernel_lookup(kernel);
105  suite.err() << "create svm" << std::endl;
106  classifier::SVM svm;
107  suite.err() << "create Subsets" << std::endl;
108  classifier::CrossValidationSampler sampler(target,3,3);
109  classifier::SubsetGenerator<classifier::KernelLookup> cv(sampler,
110                                                           kernel_lookup);
111
112  suite.err() << "create ensemble" << std::endl;
113  classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup> 
114    ensemble(svm, kernel_lookup, sampler);
115  suite.err() << "build ensemble" << std::endl;
116  ensemble.build();
117  utility::Vector out(target.size(),0);
118  for (size_t i = 0; i<out.size(); ++i) {
119    out(i)=ensemble.validate()[0][i].mean(); 
120  }
121  statistics::AUC roc;
122  suite.err() << roc.score(target,out) << std::endl;
123
124  std::vector<std::vector<statistics::Averager> > result;
125  ensemble.predict(kernel_lookup, result);
126  for (size_t i = 0; i<result.size(); ++i) {
127    for (size_t j=0; j<result[0].size(); ++j) {
128      if (!suite.add(result[i][j].variance() > 0)) {
129        suite.err() << "error: element " << i << " " << j << "\n";
130        suite.err() << "expected finite prediction varince\n";
131        suite.err() << "found: " << result[i][j].variance() << "\n";
132      }
133    } 
134  }
135
136  {
137    suite.err() << "test ensemble of SVMs with weighted kernel" << std::endl;
138    classifier::MatrixLookupWeighted wdata(data_core);
139    classifier::Kernel_SEV kernel(wdata, *kf);
140    classifier::KernelLookup wkl(kernel);
141    classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup> 
142      ensemble(svm, wkl, sampler);
143    suite.err() << "build ensemble" << std::endl;
144    ensemble.build();
145    ensemble.validate(); 
146    std::vector<std::vector<statistics::Averager> > result;
147    ensemble.predict(wkl, result);
148  }
149
150  {
151    suite.err() << "create ensemble" << std::endl;
152    classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup> 
153      ensemble(svm, kernel_lookup, sampler);
154    suite.err() << "test validate() before build()\n";
155    ensemble.validate();
156    std::vector<std::vector<statistics::Averager> > result;
157    suite.err() << "test predict() before build()\n";
158    ensemble.predict(kernel_lookup, result);
159  }
160  delete kf;
161
162  return suite.return_value();
163}
Note: See TracBrowser for help on using the repository browser.