00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040 #include "Component/ModelManager.H"
00041 #include "Image/Image.H"
00042 #include "Image/ImageSet.H"
00043 #include "Image/ShapeOps.H"
00044 #include "Image/CutPaste.H"
00045 #include "Image/DrawOps.H"
00046 #include "Image/FilterOps.H"
00047 #include "Image/ColorOps.H"
00048 #include "Image/Transforms.H"
00049 #include "Image/MathOps.H"
00050 #include "Neuro/StdBrain.H"
00051 #include "Neuro/VisualCortex.H"
00052 #include "Neuro/NeuroOpts.H"
00053 #include "Media/FrameSeries.H"
00054 #include "Transport/FrameInfo.H"
00055 #include "Raster/GenericFrame.H"
00056 #include "Media/TestImages.H"
00057 #include "Media/SceneGenerator.H"
00058 #include "Media/MediaSimEvents.H"
00059 #include "Channels/DescriptorVec.H"
00060 #include "Channels/ComplexChannel.H"
00061 #include "Channels/SubmapAlgorithmBiased.H"
00062 #include "Simulation/SimEventQueue.H"
00063 #include "Simulation/SimulationOpts.H"
00064 #include "Simulation/SimEventQueueConfigurator.H"
00065 #include "Neuro/NeuroSimEvents.H"
00066 #include "GUI/DebugWin.H"
00067 #include "ObjRec/MaskBiaser.H"
00068 #include "ObjRec/ObjRecSalBayes.H"
00069
00070
00071 ObjRecSalBayes::ObjRecSalBayes(ModelManager& mgr, const std::string& descrName,
00072 const std::string& tagName) :
00073 ModelComponent(mgr, descrName, tagName),
00074 itsDebug(false),
00075 itsQ(mgr)
00076
00077 {
00078
00079 itsStdBrain = nub::soft_ref<StdBrain>(new StdBrain(mgr));
00080 addSubComponent(itsStdBrain);
00081
00082 mgr.setOptionValString(&OPT_RawVisualCortexChans, "IOC");
00083 mgr.setOptionValString(&OPT_SaliencyMapType, "Fast");
00084 mgr.setOptionValString(&OPT_SMfastInputCoeff, "1");
00085 mgr.setOptionValString(&OPT_TaskRelevanceMapType, "None");
00086
00087 mgr.setOptionValString(&OPT_WinnerTakeAllType, "Fast");
00088
00089
00090
00091
00092
00093 mgr.setOptionValString(&OPT_IORtype, "Disc");
00094
00095 itsFoveaRadius = 50;
00096
00097 itsMgr = &mgr;
00098
00099 }
00100
00101 void ObjRecSalBayes::start2()
00102 {
00103 ComplexChannel *cc =
00104 &*dynCastWeak<ComplexChannel>(itsStdBrain->getVC());
00105
00106 itsDescVec = new DescriptorVec(*itsMgr, "Descriptor Vector", "DecscriptorVec", cc);
00107
00108 itsBayesNet = new Bayes(itsDescVec->getFVSize(), 0);
00109
00110 itsDescVec->setFoveaSize(itsFoveaRadius);
00111
00112 }
00113
00114 ObjRecSalBayes::~ObjRecSalBayes()
00115 {
00116 }
00117
00118 void ObjRecSalBayes::extractFeatures(const Image<PixRGB<byte> > &img)
00119 {
00120 const int learnNumSacc = 100;
00121 Point2D<int> winner = evolveBrain(img);
00122 for (int sacc=0; sacc<learnNumSacc; sacc++)
00123 {
00124
00125 if (itsDebug){
00126 Image<PixRGB<byte> > tmpImg = img;
00127 drawCircle(tmpImg, winner, 50, PixRGB<byte>(255, 0, 0), 3);
00128 SHOWIMG(tmpImg);
00129 }
00130
00131
00132 itsDescVec->setFovea(winner);
00133
00134 if (itsDebug){
00135 SHOWIMG(itsDescVec->getFoveaImage());
00136 }
00137
00138
00139 itsDescVec->buildRawDV();
00140
00141
00142 std::vector<double> FV = itsDescVec->getFV();
00143
00144
00145 printf("%i %i %i ", sacc, winner.i, winner.j);
00146 for(uint i=0; i<FV.size(); i++)
00147 printf("%f ", FV[i]);
00148 printf("\n");
00149
00150
00151 Image<PixRGB<byte> > nullImg;
00152 winner = evolveBrain(nullImg);
00153
00154 }
00155
00156 }
00157
00158
00159
00160
00161 void ObjRecSalBayes::train(const Image<PixRGB<byte> > &img, const std::string label)
00162 {
00163
00164 const int learnNumSacc = 1;
00165 Point2D<int> winner = evolveBrain(img);
00166 for (int sacc=0; sacc<learnNumSacc; sacc++)
00167 {
00168
00169 if (itsDebug){
00170 Image<PixRGB<byte> > tmpImg = img;
00171 drawCircle(tmpImg, winner, 50, PixRGB<byte>(255, 0, 0), 3);
00172 SHOWIMG(tmpImg);
00173 }
00174
00175
00176 itsDescVec->setFovea(winner);
00177
00178 if (itsDebug){
00179 SHOWIMG(itsDescVec->getFoveaImage());
00180 }
00181
00182
00183 itsDescVec->buildRawDV();
00184
00185
00186 std::vector<double> FV = itsDescVec->getFV();
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197 printf("OD: '%s' %i %i %i ",label.c_str(), sacc, winner.i, winner.j);
00198 for(uint i=0; i<FV.size(); i++)
00199 printf("%f ", FV[i]);
00200
00201
00202 itsBayesNet->learn(FV, label.c_str());
00203
00204 Image<PixRGB<byte> > nullImg;
00205 winner = evolveBrain(nullImg);
00206
00207 }
00208
00209 }
00210
00211 void ObjRecSalBayes::finalizeTraining()
00212 {
00213
00214 }
00215
00216 std::string ObjRecSalBayes::predict(const Image<PixRGB<byte> > &img)
00217 {
00218
00219 double prob = 0, statSig = 0;
00220 Point2D<int> winner = evolveBrain(img);
00221
00222
00223 if (itsDebug){
00224 Image<PixRGB<byte> > tmpImg = img;
00225 drawCircle(tmpImg, winner, 50, PixRGB<byte>(255, 0, 0), 3);
00226 SHOWIMG(tmpImg);
00227 }
00228
00229 itsDescVec->setFovea(winner);
00230 itsDescVec->buildRawDV();
00231
00232
00233 std::vector<double> FV = itsDescVec->getFV();
00234
00235
00236
00237
00238
00239
00240
00241 int cls = -1;
00242 cls = itsBayesNet->classify(FV);
00243
00244 statSig = itsBayesNet->getStatSig(FV, 0);
00245 LINFO("Class %i prob: %f %f\n", cls, prob, statSig);
00246
00247 if (cls == -1)
00248 return std::string("NOMATCH");
00249
00250 std::string clsName(itsBayesNet->getClassName(cls));
00251
00252 return clsName;
00253 }
00254
00255
00256 Point2D<int> ObjRecSalBayes::evolveBrain(const Image<PixRGB<byte> > &img)
00257 {
00258
00259
00260 LINFO("Evolve Brain");
00261
00262 if (img.initialized())
00263 {
00264
00265 rutz::shared_ptr<SimEventInputFrame>
00266 e(new SimEventInputFrame(itsStdBrain.get(), GenericFrame(img), 0));
00267 itsQ.post(e);
00268
00269 itsDescVec->setInputImg(img);
00270 }
00271
00272 SimTime end_time = itsQ.now() + SimTime::MSECS(3.0);
00273
00274 while (itsQ.now() < end_time)
00275 {
00276 itsStdBrain->evolve(itsQ);
00277
00278
00279 if (SeC<SimEventWTAwinner> e =
00280 itsQ.check<SimEventWTAwinner>(itsStdBrain.get()))
00281 {
00282 const Point2D<int> winner = e->winner().p;
00283
00284
00285 if (itsDebug)
00286 {
00287 if (SeC<SimEventSaliencyMapOutput> smo =
00288 itsQ.check<SimEventSaliencyMapOutput>(itsStdBrain.get(), SEQ_ANY))
00289 {
00290 Image<float> img = smo->sm();
00291 SHOWIMG(rescale(img, img.getWidth()*16, img.getHeight()*16));
00292 }
00293 }
00294 while (itsQ.now() < end_time)
00295 itsQ.evolve();
00296
00297 return winner;
00298 }
00299
00300 itsQ.evolve();
00301
00302 }
00303 return Point2D<int>();
00304 }
00305
00306
00307
00308
00309
00310
00311
00312