test-hmax2.C
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "HMAX/Hmax.H"
00039 #include "Image/Image.H"
00040 #include "Image/MathOps.H"
00041 #include "Raster/Raster.H"
00042 #include "Util/Timer.H"
00043 #include "Util/Types.H"
00044 #include "Util/log.H"
00045
00046 #include <fstream>
00047 #include <iostream>
00048
00049 int tItr = 1;
00050 int c2Size = 0;
00051 int target = 0;
00052 std::string targetS1 = "tri";
00053 std::string targetS2 = "Tri";
00054 int imgHeight = 0;
00055 int imgWidth = 0;
00056 int trnFlag = 0;
00057
00058 float w[] = {1,1};
00059
00060
00061
00062 #define NORI 4
00063
00064 int main(int argc, char **argv)
00065 {
00066
00067 char imname[1024]; imname[0] = '\0';
00068 strncpy(imname, argv[1], 1023);
00069 trnFlag = atoi(argv[2]);
00070 std::string inName = imname;
00071
00072 if (argc != 3)
00073 { std::cerr<<"USAGE: test-hmax2 <dir> <trnFlag(1 or 0)>"<<std::endl; exit(1); }
00074
00075
00076
00077 float eta = 0.3;
00078 std::ofstream wFile("weightFile");
00079 std::ofstream c2File("c2ResSum");
00080 Image<float> wt;
00081 Image<float> wtC2res;
00082
00083 Image<byte> input;
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099 std::vector<int> scss(5);
00100 scss[0] = 0; scss[1] = 2; scss[2] = 5; scss[3] = 8; scss[4] = 12;
00101 std::vector<int> spss(4);
00102 spss[0] = 4; spss[1] = 6; spss[2] = 9; spss[3] = 12;
00103 Hmax hmax(NORI, spss, scss);
00104 Timer tim; tim.reset();
00105
00106 if (trnFlag == 0) tItr = 1;
00107
00108
00109 std::vector<std::string> fileList = hmax.readDir(inName);
00110 int listSize = fileList.size();
00111 Image<float> oc2resp[listSize];
00112
00113
00114 for(int imgInd = 0; imgInd < listSize; imgInd++) {
00115 input = Raster::ReadGray(fileList[imgInd], RASFMT_PNM);
00116 Image<float> inputf(input);
00117 oc2resp[imgInd]=hmax.origGetC2(inputf);
00118
00119 imgHeight = inputf.getHeight();
00120 imgWidth = inputf.getWidth();
00121 c2Size = oc2resp[imgInd].getHeight();
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135 if (c2File.is_open()) {
00136 c2File << sum(oc2resp[imgInd]) << ", ";
00137 if (imgInd%10 == 0)
00138 c2File << std::endl;
00139 }
00140 }
00141
00142
00143 wt.resize(c2Size, c2Size, true);
00144 srand(time(NULL));
00145 for(int y = 0; y < c2Size; y++) {
00146 for(int x = 0; x < c2Size; x++) {
00147 float r = rand()%10;
00148 wt.setVal(x,y,r/10);
00149 }
00150 }
00151
00152 for(int itr = 0; itr < tItr; itr++) {
00153 eta = eta - (itr * (9*eta)/(10*tItr) );
00154 for(int imgInd = 0; imgInd < listSize; imgInd++) {
00155
00156 if(trnFlag == 1) {
00157 float udWt = 0;
00158
00159 std::string::size_type where1 = fileList[imgInd].find(targetS1);
00160 std::string::size_type where2 = fileList[imgInd].find(targetS2);
00161 if ((where1 == std::string::npos) && (where2 == std::string::npos))
00162 target = -1;
00163 else target = 1;
00164
00165
00166
00167 for(int y = 0; y < c2Size; y++) {
00168 for(int x = 0; x < c2Size; x++) {
00169 if (target == 1)
00170 udWt = wt.getVal(x,y) * (1 + eta*oc2resp[imgInd].getVal(x,y));
00171 else if (target == -1)
00172 udWt = wt.getVal(x,y) * (1 - eta*oc2resp[imgInd].getVal(x,y));
00173 wt.setVal(x,y,udWt);
00174 }
00175 }
00176
00177 inplaceNormalize(wt, 0.0f, 1.0f);
00178 }
00179
00180
00181 else {
00182 wtC2res.resize(c2Size, c2Size, true);
00183 for(int y = 0; y < 16; y ++) {
00184 for(int x = 0; x < 16; x ++)
00185 wtC2res.setVal(x,y,w[(y*16) + x]);
00186 }
00187 wt = wtC2res * oc2resp[imgInd];
00188 }
00189
00190
00191 if (wFile.is_open()) {
00192
00193 wFile <<"## img.." <<fileList[imgInd]<<" itr "<<itr<<" eta "<<eta<<std::endl;
00194 for(int y = 0; y < 16; y ++) {
00195 for(int x = 0; x < 16; x ++)
00196 wFile << wt.getVal(x, y) << ", ";
00197 }
00198 wFile << std::endl;
00199 wFile << "the sum ... "<<sum(wt)<<" the mean "<<mean(wt)<<std::endl;
00200
00201
00202
00203
00204
00205
00206
00207 }
00208 }
00209 }
00210
00211
00212
00213
00214
00215
00216 return 0;
00217 }
00218
00219
00220
00221
00222
00223