source: branches/0.4-stable/test/ensemble_test.cc @ 1743

Last change on this file since 1743 was 1743, checked in by Peter, 12 years ago

updating copyright statements

  • Property svn:eol-style set to native
  • Property svn:keywords set to Id
File size: 4.9 KB
Line 
1// $Id: ensemble_test.cc 1743 2009-01-23 14:20:30Z peter $
2
3/*
4  Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
5  Copyright (C) 2007 Jari Häkkinen, Peter Johansson
6  Copyright (C) 2008 Peter Johansson, Markus Ringnér
7
8  This file is part of the yat library, http://dev.thep.lu.se/yat
9
10  The yat library is free software; you can redistribute it and/or
11  modify it under the terms of the GNU General Public License as
12  published by the Free Software Foundation; either version 2 of the
13  License, or (at your option) any later version.
14
15  The yat library is distributed in the hope that it will be useful,
16  but WITHOUT ANY WARRANTY; without even the implied warranty of
17  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
18  General Public License for more details.
19
20  You should have received a copy of the GNU General Public License
21  along with this program; if not, write to the Free Software
22  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
23  02111-1307, USA.
24*/
25
26#include "Suite.h"
27
28#include "yat/utility/Matrix.h"
29#include "yat/classifier/SubsetGenerator.h"
30#include "yat/classifier/CrossValidationSampler.h"
31#include "yat/classifier/EnsembleBuilder.h"
32#include "yat/classifier/Kernel.h"
33#include "yat/classifier/KernelLookup.h"
34#include "yat/classifier/Kernel_SEV.h"
35#include "yat/classifier/Kernel_MEV.h"
36#include "yat/classifier/MatrixLookup.h"
37#include "yat/classifier/MatrixLookupWeighted.h"
38#include "yat/classifier/NCC.h"
39#include "yat/classifier/PolynomialKernelFunction.h"
40#include "yat/classifier/SVM.h"
41#include "yat/statistics/AUC.h"
42#include "yat/statistics/EuclideanDistance.h"
43
44#include <cassert>
45#include <fstream>
46#include <iostream>
47#include <cstdlib>
48#include <limits>
49
50
51int main(int argc, char* argv[])
52{ 
53  using namespace theplu::yat;
54  test::Suite suite(argc, argv);
55 
56  suite.err() << "testing ensemble" << std::endl;
57
58  suite.err() << "loading data" << std::endl;
59  std::ifstream is(test::filename("data/nm_data_centralized.txt").c_str());
60  utility::Matrix data_core(is);
61  is.close();
62
63  suite.err() << "create MatrixLookup" << std::endl;
64  classifier::MatrixLookup data(data_core);
65  classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction(); 
66  suite.err() << "Building kernel" << std::endl;
67  classifier::Kernel_SEV kernel(data,*kf);
68
69
70  suite.err() << "load target" << std::endl;
71  is.open(test::filename("data/nm_target_bin.txt").c_str());
72  classifier::Target target(is);
73  is.close();
74  assert(data.columns()==target.size());
75
76  {
77    suite.err() << "create ensemble of ncc" << std::endl;
78    classifier::NCC<statistics::EuclideanDistance> ncc;
79    classifier::CrossValidationSampler sampler(target,3,3);
80    classifier::SubsetGenerator<classifier::MatrixLookup> subdata(sampler,data);
81    classifier::EnsembleBuilder<classifier::SupervisedClassifier,
82      classifier::MatrixLookup> ensemble(ncc, data, sampler);
83    suite.err() << "build ensemble" << std::endl;
84    ensemble.build();
85    std::vector<std::vector<statistics::Averager> > result;
86    ensemble.predict(data, result);
87  }
88
89  {
90    suite.err() << "create ensemble of ncc" << std::endl;
91    classifier::MatrixLookupWeighted data_weighted(data);
92    classifier::NCC<statistics::EuclideanDistance> ncc;
93    classifier::CrossValidationSampler sampler(target,3,3);
94    classifier::SubsetGenerator<classifier::MatrixLookupWeighted> 
95      subdata(sampler,data_weighted);
96    classifier::EnsembleBuilder<classifier::SupervisedClassifier,
97      classifier::MatrixLookupWeighted> ensemble(ncc, data_weighted, sampler);
98    suite.err() << "build ensemble" << std::endl;
99    ensemble.build();
100    std::vector<std::vector<statistics::Averager> > result;
101    ensemble.predict(data_weighted, result);
102  }
103
104  suite.err() << "create KernelLookup" << std::endl;
105  classifier::KernelLookup kernel_lookup(kernel);
106  suite.err() << "create svm" << std::endl;
107  classifier::SVM svm;
108  suite.err() << "create Subsets" << std::endl;
109  classifier::CrossValidationSampler sampler(target,3,3);
110  classifier::SubsetGenerator<classifier::KernelLookup> cv(sampler,
111                                                           kernel_lookup);
112
113  suite.err() << "create ensemble" << std::endl;
114  classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup> 
115    ensemble(svm, kernel_lookup, sampler);
116  suite.err() << "build ensemble" << std::endl;
117  ensemble.build();
118  std::vector<std::vector<statistics::Averager> > result;
119  ensemble.predict(kernel_lookup, result);
120 
121  utility::Vector out(target.size(),0);
122  for (size_t i = 0; i<out.size(); ++i)
123    out(i)=ensemble.validate()[0][i].mean(); 
124  statistics::AUC roc;
125  suite.err() << roc.score(target,out) << std::endl;
126
127  {
128    suite.err() << "create ensemble" << std::endl;
129    classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup> 
130      ensemble(svm, kernel_lookup, sampler);
131    suite.err() << "test validate() before build()\n";
132    ensemble.validate();
133    std::vector<std::vector<statistics::Averager> > result;
134    suite.err() << "test predict() before build()\n";
135    ensemble.predict(kernel_lookup, result);
136  }
137  delete kf;
138
139  return suite.return_value();
140}
Note: See TracBrowser for help on using the repository browser.