00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Image/DrawOps.H"
00039 #include "Image/MathOps.H"
00040 #include "Neuro/StdBrain.H"
00041 #include "Raster/Raster.H"
00042 #include "Gist/trainUtils.H"
00043
00044 #include <cstdio>
00045
00046
00047
00048 FFNtrainInfo::FFNtrainInfo(std::string fName)
00049 {
00050 if(fName.length() != 0) reset(fName);
00051 }
00052
00053
00054
00055 FFNtrainInfo::~FFNtrainInfo()
00056 {}
00057
00058
00059
00060 bool FFNtrainInfo::reset(std::string fName)
00061 {
00062 FILE *fp; char inLine[100]; char comment[200]; char temp[200];
00063
00064
00065 if((fp = fopen(fName.c_str(),"rb")) == NULL)
00066 {
00067 LINFO("training file: %s not found",fName.c_str());
00068 return false;
00069 }
00070
00071
00072 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%s %s", temp, comment);
00073 trainFolder = std::string(temp);
00074
00075
00076 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%s %s", temp, comment);
00077 testFolder = std::string(temp);
00078
00079
00080 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%d %s", &nOutput, comment);
00081
00082
00083 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%s %s", temp, comment);
00084 isPCA = (strcmp(temp,"PCA") == 0);
00085 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%s %s", temp, comment);
00086 evecFname = std::string(temp);
00087
00088
00089 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%d %s", &oriFeatSize, comment);
00090 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%d %s", &redFeatSize, comment);
00091
00092
00093 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%d %s", &h1size, comment);
00094 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%d %s", &h2size, comment);
00095
00096
00097 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%f %s", &learnRate, comment);
00098
00099
00100 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%s %s", temp, comment);
00101 trainSampleFile = std::string(temp);
00102
00103
00104 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%s %s", temp, comment);
00105 testSampleFile = std::string(temp);
00106
00107
00108 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%s %s", temp, comment);
00109 h1Name = trainFolder + std::string(temp);
00110 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%s %s", temp, comment);
00111 h2Name = trainFolder + std::string(temp);
00112 if (fgets(inLine, 1000, fp) == NULL) LFATAL("fgets() failed"); sscanf(inLine, "%s %s", temp, comment);
00113 oName = trainFolder + std::string(temp);
00114
00115 fclose(fp);
00116
00117
00118 LINFO("Training folder: %s", trainFolder.c_str());
00119 LINFO("Testing folder: %s", testFolder.c_str());
00120 LINFO("PCA?: %d (%s) %d -> %d",
00121 isPCA, evecFname.c_str(), oriFeatSize, redFeatSize);
00122 LINFO("NN: %d->%d->%d->%d LR: %f",
00123 redFeatSize, h1size, h2size, nOutput, learnRate);
00124 LINFO("train: %s", trainSampleFile.c_str());
00125 LINFO("test : %s", testSampleFile.c_str());
00126 LINFO("h1 weight file name: %s", h1Name.c_str());
00127 LINFO("h2 weight file name: %s", h2Name.c_str());
00128 LINFO("o weight file name: %s", oName.c_str());
00129
00130 return true;
00131 }
00132
00133
00134
00135
00136
00137 Image<double> setupPcaIcaMatrix(std::string inW, int oriSize, int redSize)
00138 {
00139 FILE *fp;
00140
00141
00142 Image<double> ret(oriSize, redSize, NO_INIT);
00143
00144
00145 if((fp = fopen(inW.c_str(),"rb")) == NULL)
00146 {
00147 LINFO("can't open pca file: %s fill with random values",
00148 inW.c_str());
00149 Image<double>::iterator aptr = ret.beginw();
00150 for(int i = 0; i < redSize; i++)
00151 {
00152 for(int j = 0; j < oriSize; j++)
00153 {
00154 *aptr++ = (-TUTILS_RW_RANGE/2.0) +
00155 (rand()/(RAND_MAX + 1.0) * TUTILS_RW_RANGE);
00156 }
00157 }
00158 }
00159 else
00160 {
00161 Image<double>::iterator aptr = ret.beginw();
00162 for(int i = 0; i < redSize; i++)
00163 {
00164 for(int j = 0; j < oriSize; j++)
00165 { double val; if (fread(&val,sizeof(double),1,fp) != 1) LFATAL("fread error"); *aptr++ = val; }
00166 }
00167 }
00168
00169 LINFO("PCA/ICA un-mixing matrix is set");
00170 return ret;
00171 }
00172
00173
00174
00175 Image<float> getPcaIcaFeatImage(Image<double> res, int w, int h, int s)
00176 {
00177 Image<float> img(w * s, h * s, ZEROS);
00178
00179 for(int j = 0; j < h; j++)
00180 for(int i = 0; i < w; i++)
00181 drawPatch(img, Point2D<int>(i*s+s/2,j*s+s/2),s/2, float(res.getVal(j,i)));
00182
00183 return img;
00184 }
00185
00186
00187
00188
00189
00190