Changeset 1865
- Timestamp:
- Mar 16, 2009, 12:36:36 PM (15 years ago)
- Location:
- trunk
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/test/svm_multi_class_test.cc
r1864 r1865 35 35 void test_construction(test::Suite&); 36 36 void test_predict(test::Suite&); 37 void test_predict2(test::Suite&); 37 38 38 39 int main( int argc, char* argv[]) … … 42 43 test_construction(suite); 43 44 test_predict(suite); 45 test_predict2(suite); 44 46 45 47 return suite.return_value(); … … 111 113 } 112 114 } 115 116 void test_predict2(test::Suite& suite) 117 { 118 using namespace classifier; 119 std::string file = test::filename("data/sorlie_centroid_data.txt"); 120 std::ifstream is(file.c_str()); 121 suite.err() << "load data `" << file << "'" << std::endl; 122 MatrixLookupWeighted data(is, '\t'); 123 is.close(); 124 PolynomialKernelFunction linear; 125 suite.err() << "calculating kernel" << std::endl; 126 Kernel_SEV kernel_raw(data, linear); 127 file = test::filename("data/sorlie_centroid_classes.txt"); 128 suite.err() << "load classes `" << file << "'" << std::endl; 129 is.open(file.c_str()); 130 Target target(is); 131 is.close(); 132 133 std::vector<size_t> index; 134 for (size_t i=0; i<50; ++i) 135 index.push_back(i); 136 for (size_t i=70; i<79; ++i) 137 index.push_back(i); 138 utility::Index train_index(index); 139 140 Target target_train(target, train_index); 141 KernelLookup kernel_train(kernel_raw, train_index, train_index); 142 SvmMultiClass svm; 143 suite.err() << "training svm" << std::endl; 144 svm.train(kernel_train, target_train); 145 146 index.clear(); 147 for (size_t i=50; i<70; ++i) 148 index.push_back(i); 149 utility::Index test_index(index); 150 151 KernelLookup kernel_test(kernel_raw, train_index, test_index); 152 utility::Matrix result; 153 suite.err() << "Predicting on test data" << std::endl; 154 svm.predict(kernel_test, result); 155 156 if (!suite.add(result.rows()==5 && result.columns()==20)) { 157 suite.err() << "ERROR: incorrect dimension in result Matrix\n" 158 << "found " << result.rows() << "x" << result.columns() << "\n" 159 << "expected 5x79\n"; 160 } 161 if (!suite.add(std::isnan(result(3, 0))) ) { 162 suite.err() << "ERROR: expected result(4,0) to be nan\n" 163 << " found " << result(4,0) << std::endl; 164 } 165 166 } 167 -
trunk/yat/classifier/SvmMultiClass.cc
r1861 r1865 69 69 std::numeric_limits<double>::quiet_NaN()); 70 70 for (size_t i=0; i<svm_.size(); ++i) { 71 yat::utility::Matrix tmp; 72 svm_[i].predict(input, tmp); 73 prediction.row_view(i) = tmp.row_const_view(0); 71 if (svm_[i].trained()) { 72 yat::utility::Matrix tmp; 73 svm_[i].predict(input, tmp); 74 prediction.row_view(i) = tmp.row_const_view(0); 75 } 74 76 } 75 77
Note: See TracChangeset
for help on using the changeset viewer.