00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128
#include "Component/ModelManager.H"
#include "Gist/FFN.H"
#include "Raster/Raster.H"
#include "Util/MathFunctions.H"
#include "Gist/trainUtils.H"
#include "Image/MatrixOps.H"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
00137
00138 #define ABSOLUTE 0
00139 #define MIXTURE 1
00140 #define ERR_THRESHOLD .01
00141 #define MAX_EPOCH 1000
00142
00143
00144 void setupCases
00145 (std::string folder, std::string fname, bool equalize);
00146 void train();
00147 void test();
00148 void run(int isTest);
00149 void diff
00150 (Image<double> ideal, Image<double> out, double &tErr, int &tFc, int &tIc);
00151
00152
00153
00154 rutz::shared_ptr<FeedForwardNetwork> ffn;
00155 int nSamples = 0;
00156 std::vector<Image<double> > in;
00157 std::vector<Image<double> > out;
00158 Image<double> pcaIcaMatrix;
00159
00160
00161 rutz::shared_ptr<FFNtrainInfo> info;
00162
00163
00164
00165 int main(const int argc, const char **argv)
00166 {
00167
00168 ModelManager manager("Feed-Forward Network trainer");
00169
00170
00171 if (manager.parseCommandLine(argc, argv, "<input_train.txt>",
00172 1, 1) == false)
00173 return(1);
00174
00175
00176 info.reset(new FFNtrainInfo(manager.getExtraArg(0)));
00177
00178
00179
00180 ffn.reset(new FeedForwardNetwork());
00181 ffn->init3L(info->h1Name, info->h2Name, info->oName,
00182 info->redFeatSize, info->h1size, info->h2size, info->nOutput,
00183 info->learnRate, 0.0);
00184
00185
00186 if(info->isPCA)
00187 {
00188 pcaIcaMatrix = setupPcaIcaMatrix
00189 (info->trainFolder+info->evecFname,
00190 info->oriFeatSize, info->redFeatSize);
00191 }
00192
00193
00194 printf("would you like to skip training and just test the network? "
00195 "(y/n - default y)");
00196 char spC = getchar(); getchar();
00197 if(spC == 'n')
00198 {
00199 printf("would you like to equalize the number of samples? "
00200 "(y/n default y)");
00201 char spC = getchar(); getchar();
00202 bool equalize = true;
00203 if(spC == 'n') equalize = false;
00204
00205 setupCases(info->trainFolder, info->trainSampleFile, equalize);
00206 train();
00207 Raster::waitForKey();
00208 }
00209
00210
00211 setupCases(info->testFolder, info->testSampleFile, false);
00212 test();
00213 Raster::waitForKey();
00214
00215
00216 ffn->write3L(info->h1Name, info->h2Name, info->oName);
00217 }
00218
00219
00220
00221 void setupCases(std::string folder, std::string fname, bool equalize)
00222 {
00223 char comment[200]; FILE *fp; char inLine[100];
00224
00225
00226 std::string name = folder + fname;
00227 if((fp = fopen(name.c_str(),"rb")) == NULL)
00228 {
00229 LINFO("samples file: %s not found", name.c_str());
00230
00231
00232 out.resize(0);
00233 in.resize(0);
00234 nSamples = 0;
00235
00236 return;
00237 }
00238 LINFO("tName: %s",name.c_str());
00239
00240
00241 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets failed"); sscanf(inLine, "%d %s", &nSamples, comment);
00242
00243
00244 uint tNout;
00245 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets failed"); sscanf(inLine, "%d %s", &tNout, comment);
00246 if(tNout != info->nOutput)
00247 LFATAL("Num categories differ: %d != %d", tNout, info->nOutput);
00248
00249
00250 char gtOpt[100]; int gtType = -1;
00251 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets failed"); sscanf(inLine, "%s %s", gtOpt, comment);
00252 if(strcmp(gtOpt,"ABSOLUTE") == 0)
00253 gtType = ABSOLUTE;
00254 else if(strcmp(gtOpt,"MIXTURE" ) == 0)
00255 gtType = MIXTURE;
00256 else
00257 LFATAL("unknown ground truth type %s",gtOpt);
00258
00259
00260 out.resize(nSamples);
00261 in.resize(nSamples);
00262
00263
00264 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets failed");
00265
00266 char cName[100]; char sName[100]; char iName[100]; char ext[100];
00267 int cStart, cNum; int gTruth;
00268 FILE *ifp;
00269 int count = 0; int tSamples = 0;
00270 std::vector<uint> nSamples;
00271 while(fgets(inLine, 1000, fp) != NULL)
00272 {
00273 if(gtType == ABSOLUTE)
00274 {
00275
00276 sscanf(inLine, "%s %d %d %d %s", cName, &cStart, &cNum, &gTruth, ext);
00277 sprintf(sName,"%s%s", folder.c_str(), cName);
00278 printf(" sName: %s %d %d %d %s\n",sName, cStart, cNum, gTruth, ext);
00279 }
00280 else if(gtType == MIXTURE)
00281 {
00282
00283
00284
00285
00286
00287
00288
00289 LFATAL("MIXTURE ground truth type not yet implemented");
00290 }
00291 else LFATAL("unknown ground truth type %s",gtOpt);
00292
00293 nSamples.push_back(cNum);
00294
00295
00296 for(int j = cStart; j < cStart+cNum; j++)
00297 {
00298 tSamples++;
00299
00300 sprintf(iName,"%s%06d%s", sName,j,ext);
00301
00302
00303 if((ifp = fopen(iName,"rb")) != NULL)
00304 {
00305 Image<double> tData(1,info->oriFeatSize, NO_INIT);
00306 Image<double>::iterator aptr = tData.beginw();
00307
00308 for(int i = 0; i < tData.getSize(); i++)
00309 {
00310 double val; if (fread(&val, sizeof(double), 1, ifp) != 1) LFATAL("fread failed");
00311 *aptr++ = val;
00312 }
00313
00314 LINFO("feature file found: %s (%d)",
00315 iName,gTruth);
00316 fclose(ifp);
00317
00318
00319 if(info->isPCA) in[count] = matrixMult(pcaIcaMatrix, tData);
00320 else in[count] = tData;
00321
00322
00323 if(gtType == ABSOLUTE)
00324 {
00325 Image<double> res(1,info->nOutput, ZEROS);
00326 res.setVal(0, gTruth, 1.0);
00327 out[count] = res;
00328 }
00329 else if(gtType == MIXTURE)
00330 {
00331 LFATAL("MIXTURE ground truth type not yet implemented");
00332 }
00333 else LFATAL("unknown ground truth type %s",gtOpt);
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345
00346
00347 count++;
00348 }
00349 else LFATAL("file: %s not found\n",iName);
00350 }
00351 }
00352
00353
00354 if(equalize)
00355 {
00356
00357 uint max = 0;
00358
00359
00360 max = *max_element(nSamples.begin(),nSamples.end());
00361 LINFO("max element: %d", max);
00362
00363 uint offset = 0;
00364 for(uint i = 0; i < nSamples.size(); i++)
00365 {
00366 LINFO("extra samples for class[%3d]: %d - %d -> %d",
00367 i, max, nSamples[i], max - nSamples[i]);
00368 for(uint j = 0; j < max - nSamples[i]; j++)
00369 {
00370
00371 uint index = rand()/(RAND_MAX + 1.0) * nSamples[i];
00372 LINFO("[%d] Duplicating class[%3d] sample[%3d]"
00373 " -> actual ind: %3d",
00374 j, i, index, index + offset);
00375 index = index + offset;
00376
00377 in.push_back(in[index]);
00378 out.push_back(out[index]);
00379 }
00380 offset += nSamples[i];
00381 }
00382 LINFO("Total samples before equalized: %d \n",tSamples);
00383 tSamples = in.size();
00384 }
00385
00386 LINFO("Actual total samples: %d \n",tSamples);
00387 fclose(fp);
00388 }
00389
00390
00391
00392 void train() { run(0);};
00393
00394
00395
00396 void test() { run(1);};
00397
00398
00399
00400 void run(int isTest)
00401 {
00402 LINFO("Run the samples");
00403 double errSum = double(nSamples);
00404 double err; Image<double> ffnOut;
00405 int nfc = nSamples; int fc;
00406 int nfcClass[info->nOutput][info->nOutput];
00407 int nTrials = 0;
00408 int target = 0;
00409
00410 if(nSamples == 0) return;
00411 int order[nSamples];
00412 for(int i = 0; i < nSamples; i++) order[i] = i;
00413
00414 while(nTrials < MAX_EPOCH && !isTest && nfc > int(nSamples*ERR_THRESHOLD))
00415 {
00416
00417 for(uint i = 0; i < info->nOutput; i++)
00418 for(uint j = 0; j < info->nOutput; j++)
00419 nfcClass[i][j] = 0;
00420 errSum = 0.0; nfc = 0;
00421
00422
00423 randShuffle(order, nSamples);
00424
00425 for(int i = 0; i < nSamples; i++)
00426 {
00427
00428 ffn->run3L(in[order[i]]);
00429 ffnOut = ffn->getOutput();
00430
00431
00432 diff(out[order[i]], ffnOut, err, fc, target);
00433
00434
00435 if(fc != -1)
00436 {
00437 nfc++;
00438 nfcClass[target][fc]++;
00439 }
00440 else
00441 nfcClass[target][target]++;
00442
00443
00444 errSum += err;
00445
00446 if(fc != -1)
00447 {
00448
00449 ffn->backprop3L(out[order[i]]);
00450
00451 }
00452 }
00453 nTrials++;
00454
00455
00456 if(nTrials %1 == 0)
00457 {
00458 printf("Trial_%04d_Err: %f nfc: %5d/%5d -> %f%%\n",
00459 nTrials, errSum/nSamples,
00460 nfc,nSamples,(double)(nfc)/(0.0 + nSamples)*100.0);
00461
00462 printf("class |");
00463 for(uint k = 0; k < info->nOutput; k++)
00464 printf(" %4d", k);
00465 printf("\n");
00466 for(uint k = 0; k < info->nOutput; k++)
00467 printf("------");
00468 printf("\n");
00469 for(uint k = 0; k < info->nOutput; k++)
00470 {
00471 printf("%6d|",k);
00472 for(uint j = 0; j < info->nOutput; j++)
00473 printf(" %4d",nfcClass[k][j]);
00474 printf("\n");
00475 }
00476 }
00477 printf("\n");
00478 }
00479
00480
00481 if(isTest)
00482 {
00483 nfc = 0; errSum = 0.0; err = 0;
00484 for(uint i = 0; i < info->nOutput; i++)
00485 for(uint j = 0; j < info->nOutput; j++)
00486 nfcClass[i][j] = 0;
00487
00488 for(int i = 0; i < nSamples; i++)
00489 {
00490
00491 ffn->run3L(in[i]);
00492
00493
00494 ffnOut = ffn->getOutput();
00495
00496
00497 diff(out[i], ffnOut, err, fc, target);
00498
00499
00500 if(fc != -1)
00501 {
00502 nfc++;
00503 nfcClass[target][fc]++;
00504 }
00505 else
00506 nfcClass[target][target]++;
00507
00508
00509 errSum += err;
00510
00511 if((fc != -1) | 1)
00512 {
00513 printf("sample %5d: ",i);
00514 for(uint j = 0; j < info->nOutput; j++)
00515 printf("%.3f ",out[i][j]);
00516 printf(" -:- ");
00517 for(uint j = 0; j < info->nOutput; j++)
00518 printf("%.3f ",ffnOut[j]);
00519 }
00520 if(fc != -1) printf(" WRONG! NO:%d [%d][%d] = %d \n",
00521 nfc, target, fc, nfcClass[target][fc]);
00522 else printf("\n");
00523 }
00524 }
00525
00526
00527 printf("Final Trial_%04d_Err: %f nfc: %5d/%5d -> %.3f%%\n",
00528 nTrials,errSum/nSamples,
00529 nfc,nSamples,(double)(nfc)/(0.0 + nSamples)*100.0);
00530
00531 printf("class |");
00532 for(uint k = 0; k < info->nOutput; k++)
00533 printf(" %5d",k);
00534 printf(" Total pct. err \n-------");
00535 for(uint k = 0; k < info->nOutput; k++)
00536 printf("------");
00537 printf("\n");
00538 for(uint k = 0; k < info->nOutput; k++)
00539 {
00540 int t = 0, e = 0;
00541 printf("%6d|",k);
00542 for(uint j = 0; j < info->nOutput; j++)
00543 {
00544 printf(" %5d",nfcClass[k][j]);
00545 if(k == j)
00546 t = nfcClass[k][j];
00547 else
00548 e += nfcClass[k][j];
00549 }
00550 if(e+t == 0)
00551 printf(" %6d/%6d N/A%%\n",0,0);
00552 else
00553 printf(" %6d/%6d %6.2f%%\n",e,e+t, float(e)/float(e+t)*100.0);
00554 }
00555
00556 for(uint k = 0; k < info->nOutput; k++)
00557 printf("------");
00558 printf("-------\nFalse+|");
00559 for(uint k = 0; k < info->nOutput; k++)
00560 {
00561 int e = 0;
00562 for(uint j = 0; j < info->nOutput; j++)
00563 {
00564 if(k == j)
00565 ;
00566 else
00567 e += nfcClass[j][k];
00568 }
00569 printf(" %5d",e);
00570 }
00571 printf("\ntotal |");
00572 for(uint k = 0; k < info->nOutput; k++)
00573 {
00574 int t = 0, e = 0;
00575 for(uint j = 0; j < info->nOutput; j++)
00576 {
00577 if(k == j)
00578 t = nfcClass[j][k];
00579 else
00580 e += nfcClass[j][k];
00581 }
00582 printf(" %5d",e+t);
00583 }
00584 printf("\nerr: |");
00585 for(uint k = 0; k < info->nOutput; k++)
00586 {
00587 int t = 0, e = 0;
00588 for(uint j = 0; j < info->nOutput; j++)
00589 {
00590 if(k == j)
00591 t = nfcClass[j][k];
00592 else
00593 e += nfcClass[j][k];
00594 }
00595 if(e+t == 0)
00596 printf(" N/A");
00597 else
00598 printf(" %5.2f",float(e)/float(e+t)*100.0);
00599 }
00600 printf("\n");
00601 }
00602
00603
00604
00605 void diff
00606 (Image<double> ideal, Image<double> out,
00607 double &tErr, int &tFc, int &tIc)
00608 {
00609 tErr = 0.0;
00610 Image<double>::iterator iptr = ideal.beginw();
00611 Image<double>::iterator optr = out.beginw();
00612 for(uint i = 0; i < info->nOutput; i++)
00613 tErr += fabs(*iptr++ - *optr++);
00614
00615 int iMaxI = 0, oMaxI = 0;
00616 iptr = ideal.beginw(); optr = out.beginw();
00617 double iMax = *iptr++, oMax = *optr++;
00618 for(uint i = 1; i < info->nOutput; i++)
00619 {
00620 double ival = *iptr++;
00621 double oval = *optr++;
00622 if(ival > iMax) { iMax = ival; iMaxI = i; }
00623 if(oval > oMax) { oMax = oval; oMaxI = i; }
00624 }
00625
00626
00627 tFc = -1; if(iMaxI != oMaxI) tFc = oMaxI;
00628 tIc = iMaxI;
00629 }
00630
00631
00632
00633
00634
00635
00636
00637
00638
00639
00640
00641
00642
00643
00644
00645
00646
00647
00648
00649
00650
00651
00652
00653
00654
00655
00656
00657
00658
00659
00660
00661
00662
00663
00664
00665
00666
00667
00668
00669
00670
00671
00672
00673
00674
00675
00676
00677
00678
00679
00680
00681