00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036 #include "rutz/shared_ptr.h"
00037 #include "Gist/FFN.H"
00038 #include "Raster/Raster.H"
00039 #include "Image/MatrixOps.H"
00040 #include "Gist/trainUtils.H"
00041 #include <string>
00042 #include <vector>
00043
00044 int main(const int argc, const char **argv)
00045 {
00046 LINFO("Testing with small Neural Network");
00047
00048
00049 rutz::shared_ptr<FeedForwardNetwork> ffn(new FeedForwardNetwork());
00050
00051 Image<double> wh(10,7, NO_INIT);
00052 wh.setVal(0,0, 1.0); wh.setVal(0,1, 1.0); wh.setVal(0,2, 1.0);
00053 wh.setVal(1,0, 1.0); wh.setVal(1,1, 1.0); wh.setVal(1,2, 1.0);
00054 wh.setVal(2,0, 1.0); wh.setVal(2,1, 1.0); wh.setVal(2,2, 1.0);
00055
00056 Image<double> wh2(8,5, NO_INIT);
00057 wh2.setVal(0,0, 1.0); wh2.setVal(0,1, 1.0); wh2.setVal(0,2, 1.0);
00058 wh2.setVal(1,0, 1.0); wh2.setVal(1,1, 1.0); wh2.setVal(1,2, 1.0);
00059 wh2.setVal(2,0, 1.0); wh2.setVal(2,1, 1.0); wh2.setVal(2,2, 1.0);
00060
00061 Image<double> wo(6,3, NO_INIT);
00062 wo.setVal(0,0, 1.0); wo.setVal(0,1, 1.0);
00063 wo.setVal(1,0, 1.0); wo.setVal(1,1, 1.0);
00064 wo.setVal(2,0, 1.0); wo.setVal(2,1, 1.0);
00065 wo.setVal(3,0, 1.0); wo.setVal(3,1, 1.0);
00066
00067 ffn->init3L(wh, wh2, wo, 0.1, 0.0);
00068
00069
00070 Image<double> in(1,9, NO_INIT);
00071 in.setVal(0, 0, 1.0);
00072 in.setVal(0, 1, 2.0);
00073
00074
00075 Image<double> out = ffn->run3L(in);
00076 for(int i = 0; i < out.getSize(); i++)
00077 LINFO("%d: %f", i, out.getVal(0,i));
00078
00079
00080 Image<double> target(1,3, NO_INIT);
00081 target.setVal(0, 0, 0.5);
00082 target.setVal(0, 1, 0.2);
00083 target.setVal(0, 2, 0.7);
00084 ffn->backprop3L(target);
00085
00086
00087 out = ffn->run3L(in);
00088 for(int i = 0; i < out.getHeight(); i++)
00089 LINFO("%d: %f", i, out.getVal(i));
00090
00091
00092
00093 Raster::waitForKey();
00094 LINFO("Now testing with provided files");
00095
00096 rutz::shared_ptr<FeedForwardNetwork> ffn2(new FeedForwardNetwork());
00097 std::string infoFName("../data/PAMI07data/ACB_GIST_train.txt");
00098 FFNtrainInfo pcInfo(infoFName);
00099
00100 ffn2->init3L(pcInfo.h1Name, pcInfo.h2Name, pcInfo.oName,
00101 pcInfo.redFeatSize, pcInfo.h1size, pcInfo.h2size,
00102 pcInfo.nOutput, 0.0, 0.0);
00103
00104
00105 Image<double> pcaVec =
00106 setupPcaIcaMatrix(pcInfo.trainFolder+pcInfo.evecFname,
00107 pcInfo.oriFeatSize, pcInfo.redFeatSize);
00108
00109
00110 Image<double> oriin(1,pcInfo.oriFeatSize,NO_INIT);
00111 Image<double>::iterator aptr = oriin.beginw();
00112 FILE *gfp;
00113 std::string gfname = pcInfo.trainFolder + std::string("ACB1B_000.gist");
00114 if((gfp = fopen(gfname.c_str(),"rb")) != NULL)
00115 {
00116 LINFO("gist file found: %s", gfname.c_str());
00117 for(uint i = 0; i < pcInfo.oriFeatSize; i++)
00118 {
00119 double tval; if (fread(&tval, sizeof(double), 1, gfp) != 1) LFATAL("fread failed");
00120 *aptr++ = tval;
00121 }
00122 fclose(gfp);
00123 }
00124 else LFATAL("gist file NOT found: %s", gfname.c_str());
00125
00126
00127 out = ffn2->run3L(matrixMult(pcaVec,oriin));
00128 for(int i = 0; i < out.getHeight(); i++)
00129 LINFO("%d: %f", i, out.getVal(i));
00130 }
00131
00132
00133
00134
00135
00136