/*!@file Learn/LSVM.C Latent Support Vector Machine Classifier module */

// //////////////////////////////////////////////////////////////////// //
// The iLab Neuromorphic Vision C++ Toolkit - Copyright (C) 2001 by the //
// University of Southern California (USC) and the iLab at USC.         //
// See http://iLab.usc.edu for information about this project.          //
// //////////////////////////////////////////////////////////////////// //
// Major portions of the iLab Neuromorphic Vision Toolkit are protected //
// under the U.S. patent ``Computation of Intrinsic Perceptual Saliency //
// in Visual Environments, and Applications'' by Christof Koch and      //
// Laurent Itti, California Institute of Technology, 2001 (patent       //
// pending; application number 09/912,225 filed July 23, 2001; see      //
// http://pair.uspto.gov/cgi-bin/final/home.pl for current status).     //
// //////////////////////////////////////////////////////////////////// //
// This file is part of the iLab Neuromorphic Vision C++ Toolkit.       //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is free software; you can   //
// redistribute it and/or modify it under the terms of the GNU General  //
// Public License as published by the Free Software Foundation; either  //
// version 2 of the License, or (at your option) any later version.     //
//                                                                      //
// The iLab Neuromorphic Vision C++ Toolkit is distributed in the hope  //
// that it will be useful, but WITHOUT ANY WARRANTY; without even the   //
// implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR      //
// PURPOSE.  See the GNU General Public License for more details.       //
//                                                                      //
// You should have received a copy of the GNU General Public License    //
// along with the iLab Neuromorphic Vision C++ Toolkit; if not, write   //
// to the Free Software Foundation, Inc., 59 Temple Place, Suite 330,   //
// Boston, MA 02111-1307 USA.                                           //
// //////////////////////////////////////////////////////////////////// //
//
// Primary maintainer for this file: Laurent Itti <itti@usc.edu>
// $HeadURL: svn://isvn.usc.edu/software/invt/trunk/saliency/src/Learn/LSVM.C $
// $Id: LSVM.C 14581 2011-03-08 07:18:09Z dparks $
//

#include <fstream>
#include <iostream>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <cmath>     // fabs(), pow(), INFINITY
#include <algorithm> // std::sort(), min(), max()
#include <map>
#include <vector>

#include "svm.h"
#include "LSVM.H"
#include "Component/ModelComponent.H"
#include "Component/ModelParam.H"
#include "Component/OptionManager.H"


LSVM::LSVM(float gamma, int C)
{
}

LSVM::~LSVM()
{
}

void LSVM::train(const std::vector<LabeledData>& examples)
{
  seed_rand();
  DPMModel model;

  LINFO("Sorting Examples");
  // std::sort() sorts in place and returns void, so sort a local copy
  std::vector<LabeledData> sortedExamples(examples);
  std::sort(sortedExamples.begin(), sortedExamples.end());

  // find unique examples

  // collapse examples:
  // merge examples with identical labels
  //collapse(&X, sorted, num_unique);

  // initial model

  // lower bounds

  // NOTE: C, J, X, w and lb are expected to be produced by the setup
  // steps above, which are still unimplemented in this skeleton.

  // train
  LINFO("Training");
  gradientDescent(C, J, X, w, lb);

  // score examples
  LINFO("Scoring\n");
  std::vector<double> score;
  for (size_t i = 0; i < examples.size(); i++)
    score.push_back(getScore(examples[i], model, w));

  // compute loss and write it to a file
  LossInfo lossInfo = computeLoss(C, J, X, w);
}



double LSVM::getScore(const LabeledData& data, Model& model, std::vector<double>& weight)
{
  // linear score: dot product between the weight vector and the features
  double val = 0.0;
  for (uint i = 0; i < data.features.size(); i++)
    val += weight[i] * data.features[i];

  return val;
}

LSVM::LossInfo LSVM::computeLoss(double C, double J, data X, double **w)
{
  LossInfo lossInfo;

  double loss = 0;
  if (itsFullL2) // compute ||w||^2
  {
    for (int j = 0; j < X.numblocks; j++) {
      for (int k = 0; k < X.blocksizes[j]; k++) {
        loss += w[j][k] * w[j][k] * X.regmult[j];
      }
    }
  } else {
    // compute max norm^2 component
    for (int c = 0; c < X.numcomponents; c++) {
      double val = 0;
      for (int i = 0; i < X.componentsizes[c]; i++) {
        int b = X.componentblocks[c][i];
        double blockval = 0;
        for (int k = 0; k < X.blocksizes[b]; k++)
          blockval += w[b][k] * w[b][k] * X.regmult[b];
        val += blockval;
      }
      if (val > loss)
        loss = val;
    }
  }
  loss *= 0.5;

  // record the regularization term
  lossInfo.reg = loss;

  // compute loss from the training data
  for (int l = 0; l < 2; l++) {
    // which label subset to look at: -1 or 1
    int subset = (l*2) - 1;
    double subsetloss = 0.0;
    for (int i = 0; i < X.num; i++) {
      collapsed x = X.x[i];

      // only consider examples in the target subset
      char *ptr = x.seq[0];
      if (LABEL(ptr) != subset)
        continue;

      // compute max over latent placements
      int M = -1;
      double V = -INFINITY;
      for (int m = 0; m < x.num; m++) {
        double val = ex_score(x.seq[m], X, w);
        if (val > V) {
          M = m;
          V = val;
        }
      }

      // compute loss on max
      ptr = x.seq[M];
      int label = LABEL(ptr);
      double mult = C * (label == 1 ? J : 1);
      subsetloss += mult * max(0.0, 1.0 - label*V);
    }
    loss += subsetloss;
    // l == 0 processed the label -1 subset, l == 1 the label +1 subset
    if (l == 0)
      lossInfo.neg = subsetloss;
    else
      lossInfo.pos = subsetloss;
  }

  lossInfo.loss = loss;

  return lossInfo;
}
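// ---------------------------------------------------------------------
// Illustrative sketch (not part of the original toolkit API): on flat
// std::vector weights, the quantity that computeLoss() evaluates is the
// standard latent-SVM objective
//
//   L(w) = 0.5*||w||^2 + C * sum_i m_i * max(0, 1 - y_i * s_i),
//
// where s_i is the best score over latent placements of example i and
// m_i = J for positives, 1 for negatives. The helper below is a
// hypothetical, self-contained restatement of that formula for clarity.
static double hingeObjectiveSketch(const std::vector<double>& w,
                                   const std::vector<double>& scores,
                                   const std::vector<int>& labels,
                                   double C, double J)
{
  double obj = 0.0;
  for (size_t k = 0; k < w.size(); k++)       // 0.5 * ||w||^2 regularizer
    obj += 0.5 * w[k] * w[k];
  for (size_t i = 0; i < scores.size(); i++)  // weighted hinge loss term
  {
    const double mult = C * (labels[i] == 1 ? J : 1.0);
    obj += mult * std::max(0.0, 1.0 - labels[i] * scores[i]);
  }
  return obj;
}
// ---------------------------------------------------------------------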
void LSVM::gradientDescent(double C, double J, data X,
                           double **w, double **lb)
{
  int num = X.num;

  // state for random permutations
  int *perm = new int[num];

  // state for small cache
  int *W = new int[num];
  for (int i = 0; i < num; i++)
    W[i] = 0;

  double prev_loss = 1E9;

  bool converged = false;
  int stop_count = 0;
  int t = 0;
  while (t < itsNumIter && !converged) {
    // pick random permutation
    for (int i = 0; i < num; i++)
      perm[i] = i;
    for (int swapi = 0; swapi < num; swapi++) {
      int swapj = (int)(drand48()*(num-swapi)) + swapi;
      int tmp = perm[swapi];
      perm[swapi] = perm[swapj];
      perm[swapj] = tmp;
    }

    // count number of examples in the small cache
    int cnum = 0;
    for (int i = 0; i < num; i++)
      if (W[i] <= INCACHE)
        cnum++;

    int numupdated = 0;
    for (int swapi = 0; swapi < num; swapi++) {
      // select example
      int i = perm[swapi];

      // skip if example is not in small cache
      if (W[i] > INCACHE) {
        W[i]--;
        continue;
      }

      collapsed x = X.x[i];

      // learning rate
      double T = min(itsNumIter/2.0, t + 10000.0);
      double rateX = cnum * C / T;

      t++;
      // every 100000 iterations, show loss/stats and determine whether to stop
      if (t % 100000 == 0) {
        // compute the hinge loss
        LossInfo lossInfo = computeLoss(C, J, X, w);
        double delta = 1.0 - (fabs(prev_loss - lossInfo.loss) / lossInfo.loss);

        LINFO("t=%i loss=%f delta=%f", t, lossInfo.loss, delta);

        // do we need to stop?
        if (delta >= itsDeltaStop && t >= itsMinIter) {
          stop_count++;
          if (stop_count > itsStopCount)
            converged = true;
        } else if (stop_count > 0) {
          stop_count = 0;
        }
        prev_loss = lossInfo.loss;
        LINFO("%7.2f%% of max # iterations "
              "(delta = %.5f; stop count = %d)",
              100*double(t)/double(itsNumIter),
              max(delta, 0.0),
              itsStopCount - stop_count + 1);
        if (converged)
          break;
      }

      // compute max over latent placements
      int M = -1;
      double V = -INFINITY;
      for (int m = 0; m < x.num; m++) {
        double val = ex_score(x.seq[m], X, w);
        if (val > V) {
          M = m;
          V = val;
        }
      }
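      // The block below is the stochastic subgradient step of the latent
      // SVM: if the highest-scoring placement still violates the margin
      // (label * V < 1), each nonzero feature block b of that placement
      // is pushed along its feature vector,
      //   w[b] += (label > 0 ? J : -1) * rateX * learnmult[b] * data,
      // so positives are up-weighted by the class-balance factor J;
      // otherwise the example is parked in the wait cache for a while.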
      // compute the weight updates
      char *ptr = x.seq[M];
      int label = LABEL(ptr);
      if (label * V < 1.0) {
        numupdated++;
        W[i] = 0;
        float *data = EX_DATA(ptr);
        int blocks = NUM_NONZERO(ptr);
        for (int j = 0; j < blocks; j++) {
          int b = BLOCK_IDX(data);
          double mult = (label > 0 ? J : -1) * rateX * X.learnmult[b];
          data++;
          for (int k = 0; k < X.blocksizes[b]; k++)
            w[b][k] += mult * data[k];
          data += X.blocksizes[b];
        }
      } else {
        if (W[i] == INCACHE)
          W[i] = MINWAIT + (int)(drand48()*50);
        else
          W[i]++;
      }

      // periodically regularize the model
      if (t % REGFREQ == 0) {
        // apply lower bounds
        for (int j = 0; j < X.numblocks; j++)
          for (int k = 0; k < X.blocksizes[j]; k++)
            w[j][k] = max(w[j][k], lb[j][k]);

        double rateR = 1.0 / T;

        if (itsFullL2)
        {
          // update model
          for (int j = 0; j < X.numblocks; j++) {
            double mult = rateR * X.regmult[j] * X.learnmult[j];
            mult = pow((1-mult), REGFREQ);
            for (int k = 0; k < X.blocksizes[j]; k++) {
              w[j][k] = mult * w[j][k];
            }
          }
        } else {
          // assume simple mixture model
          int maxc = 0;
          double bestval = 0;
          for (int c = 0; c < X.numcomponents; c++) {
            double val = 0;
            for (int i = 0; i < X.componentsizes[c]; i++) {
              int b = X.componentblocks[c][i];
              double blockval = 0;
              for (int k = 0; k < X.blocksizes[b]; k++)
                blockval += w[b][k] * w[b][k] * X.regmult[b];
              val += blockval;
            }
            if (val > bestval) {
              maxc = c;
              bestval = val;
            }
          }
          for (int i = 0; i < X.componentsizes[maxc]; i++) {
            int b = X.componentblocks[maxc][i];
            double mult = rateR * X.regmult[b] * X.learnmult[b];
            mult = pow((1-mult), REGFREQ);
            for (int k = 0; k < X.blocksizes[b]; k++)
              w[b][k] = mult * w[b][k];
          }
        }
      }
    }
  }

  delete [] perm;
  delete [] W;

  if (converged)
    LINFO("Termination criteria reached after %d iterations.\n", t);
  else
    LINFO("Max iteration count reached after %d iterations.\n", t);
}
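// ---------------------------------------------------------------------
// Minimal usage sketch (illustrative only, not from the original file):
// loadExamples() is a hypothetical helper, and the gamma/C values are
// arbitrary. It assumes LabeledData carries the `features` vector used
// by LSVM::getScore() above.
//
//   std::vector<LabeledData> examples = loadExamples();
//   LSVM svm(0.5f /* gamma */, 100 /* C */);
//   svm.train(examples);
// ---------------------------------------------------------------------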