source: trunk/test/ensemble_test.cc @ 1658

Last change on this file since 1658 was 1487, checked in by Jari Häkkinen, 13 years ago

Addresses #436. GPL license copy reference should also be updated.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.7 KB
Line 
1// $Id: ensemble_test.cc 1487 2008-09-10 08:41:36Z jari $
2
3/*
4  Copyright (C) 2006, 2007 Jari Häkkinen, Peter Johansson, Markus Ringnér
5  Copyright (C) 2008 Peter Johansson, Markus Ringnér
6
7  This file is part of the yat library, http://dev.thep.lu.se/yat
8
9  The yat library is free software; you can redistribute it and/or
10  modify it under the terms of the GNU General Public License as
11  published by the Free Software Foundation; either version 3 of the
12  License, or (at your option) any later version.
13
14  The yat library is distributed in the hope that it will be useful,
15  but WITHOUT ANY WARRANTY; without even the implied warranty of
16  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  General Public License for more details.
18
19  You should have received a copy of the GNU General Public License
20  along with yat. If not, see <http://www.gnu.org/licenses/>.
21*/
22
23#include "Suite.h"
24
25#include "yat/utility/Matrix.h"
26#include "yat/classifier/SubsetGenerator.h"
27#include "yat/classifier/CrossValidationSampler.h"
28#include "yat/classifier/EnsembleBuilder.h"
29#include "yat/classifier/Kernel.h"
30#include "yat/classifier/KernelLookup.h"
31#include "yat/classifier/Kernel_SEV.h"
32#include "yat/classifier/Kernel_MEV.h"
33#include "yat/classifier/MatrixLookup.h"
34#include "yat/classifier/MatrixLookupWeighted.h"
35#include "yat/classifier/NCC.h"
36#include "yat/classifier/PolynomialKernelFunction.h"
37#include "yat/classifier/SVM.h"
38#include "yat/statistics/AUC.h"
39#include "yat/statistics/EuclideanDistance.h"
40
41#include <cassert>
42#include <fstream>
43#include <iostream>
44#include <cstdlib>
45#include <limits>
46
47
48int main(int argc, char* argv[])
49{ 
50  using namespace theplu::yat;
51  test::Suite suite(argc, argv);
52 
53  suite.err() << "testing ensemble" << std::endl;
54
55  suite.err() << "loading data" << std::endl;
56  std::ifstream is(test::filename("data/nm_data_centralized.txt").c_str());
57  utility::Matrix data_core(is);
58  is.close();
59
60  suite.err() << "create MatrixLookup" << std::endl;
61  classifier::MatrixLookup data(data_core);
62  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
63  suite.err() << "Building kernel" << std::endl;
64  classifier::Kernel_SEV kernel(data,*kf);
65
66
67  suite.err() << "load target" << std::endl;
68  is.open(test::filename("data/nm_target_bin.txt").c_str());
69  classifier::Target target(is);
70  is.close();
71  assert(data.columns()==target.size());
72
73  {
74    suite.err() << "create ensemble of ncc" << std::endl;
75    classifier::NCC<statistics::EuclideanDistance> ncc;
76    classifier::CrossValidationSampler sampler(target,3,3);
77    classifier::SubsetGenerator<classifier::MatrixLookup> subdata(sampler,data);
78    classifier::EnsembleBuilder<classifier::SupervisedClassifier,
79      classifier::MatrixLookup> ensemble(ncc, data, sampler);
80    suite.err() << "build ensemble" << std::endl;
81    ensemble.build();
82    std::vector<std::vector<statistics::Averager> > result;
83    ensemble.predict(data, result);
84  }
85
86  {
87    suite.err() << "create ensemble of ncc" << std::endl;
88    classifier::MatrixLookupWeighted data_weighted(data);
89    classifier::NCC<statistics::EuclideanDistance> ncc;
90    classifier::CrossValidationSampler sampler(target,3,3);
91    classifier::SubsetGenerator<classifier::MatrixLookupWeighted> 
92      subdata(sampler,data_weighted);
93    classifier::EnsembleBuilder<classifier::SupervisedClassifier,
94      classifier::MatrixLookupWeighted> ensemble(ncc, data_weighted, sampler);
95    suite.err() << "build ensemble" << std::endl;
96    ensemble.build();
97    std::vector<std::vector<statistics::Averager> > result;
98    ensemble.predict(data_weighted, result);
99  }
100
101  suite.err() << "create KernelLookup" << std::endl;
102  classifier::KernelLookup kernel_lookup(kernel);
103  suite.err() << "create svm" << std::endl;
104  classifier::SVM svm;
105  suite.err() << "create Subsets" << std::endl;
106  classifier::CrossValidationSampler sampler(target,3,3);
107  classifier::SubsetGenerator<classifier::KernelLookup> cv(sampler,
108                                                           kernel_lookup);
109
110  suite.err() << "create ensemble" << std::endl;
111  classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup> 
112    ensemble(svm, kernel_lookup, sampler);
113  suite.err() << "build ensemble" << std::endl;
114  ensemble.build();
115  std::vector<std::vector<statistics::Averager> > result;
116  ensemble.predict(kernel_lookup, result);
117 
118  utility::Vector out(target.size(),0);
119  for (size_t i = 0; i<out.size(); ++i)
120    out(i)=ensemble.validate()[0][i].mean(); 
121  statistics::AUC roc;
122  suite.err() << roc.score(target,out) << std::endl;
123
124  {
125    suite.err() << "create ensemble" << std::endl;
126    classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup> 
127      ensemble(svm, kernel_lookup, sampler);
128    suite.err() << "test validate() before build()\n";
129    ensemble.validate();
130    std::vector<std::vector<statistics::Averager> > result;
131    suite.err() << "test predict() before build()\n";
132    ensemble.predict(kernel_lookup, result);
133  }
134  delete kf;
135
136  return suite.return_value();
137}
Note: See TracBrowser for help on using the repository browser.