00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "SIFT/VisualObjectDB.H"
00039 #include "SIFT/Keypoint.H"
00040 #include "SIFT/KDTree.H"
00041 #include "Image/ColorOps.H"
00042 #include "Image/Image.H"
00043 #include "Image/Pixels.H"
00044 #include "Util/Timer.H"
00045 #include "Util/WorkThreadServer.H"
00046
00047 #include <fstream>
00048
00049
00050
00051 VisualObjectDB::VisualObjectDB() :
00052 itsName(), itsObjects(), itsKDTree(), itsKDindices()
00053 { }
00054
00055
00056 VisualObjectDB::~VisualObjectDB()
00057 { }
00058
00059
00060 bool VisualObjectDB::loadFrom(const std::string& fname, bool preloadImage)
00061 {
00062 const char *fn = fname.c_str();
00063 LINFO("Loading Visual Object database: '%s'...", fn);
00064
00065 std::ifstream inf(fn);
00066 if (inf.is_open() == false) { LERROR("Cannot open '%s' -- USING EMPTY", fn); return false; }
00067
00068 inf>>(*this);
00069
00070 inf.close();
00071 LINFO("Done. Loaded %u VisualObjects.", numObjects());
00072 return true;
00073 }
00074
00075
00076 bool VisualObjectDB::saveTo(const std::string& fname)
00077 {
00078 const char *fn = fname.c_str();
00079 LINFO("Saving database '%s'...", fn);
00080
00081 std::ofstream outf(fn);
00082 if (outf.is_open() == false) { LERROR("Cannot open %s for writing -- NOT SAVED", fn); return false; }
00083
00084 outf<<(*this);
00085
00086 outf.close();
00087 LINFO("Done. Saved %u VisualObjects.", numObjects());
00088 return true;
00089 }
00090
00091
00092 bool VisualObjectDB::addObject(const rutz::shared_ptr<VisualObject>& obj, bool uniqueName)
00093 {
00094 if (uniqueName)
00095 {
00096 std::string objectName = obj->getName();
00097
00098 std::vector< rutz::shared_ptr<VisualObject> >::const_iterator
00099 vo = itsObjects.begin(), stop = itsObjects.end();
00100
00101 while(vo != stop)
00102 {
00103 if ((*vo)->getName().compare(objectName) == 0) return false;
00104 ++ vo;
00105 }
00106 }
00107
00108
00109 itsObjects.push_back(obj);
00110
00111
00112 itsKDTree.reset();
00113 itsKDindices.clear();
00114
00115 return true;
00116 }
00117
00118
00119
00120 class moreVOM
00121 {
00122 public:
00123 moreVOM(const float kcoeff, const float acoeff) :
00124 itsKcoeff(kcoeff), itsAcoeff(acoeff)
00125 { }
00126
00127 bool operator()(const rutz::shared_ptr<VisualObjectMatch>& x,
00128 const rutz::shared_ptr<VisualObjectMatch>& y)
00129 { return ( x->getScore(itsKcoeff, itsAcoeff) >
00130 y->getScore(itsKcoeff, itsAcoeff) ); }
00131
00132 private:
00133 float itsKcoeff, itsAcoeff;
00134 };
00135
00136
00137 void VisualObjectDB::buildKDTree()
00138 {
00139
00140 if (itsKDTree.is_valid()) return;
00141
00142 LINFO("Building KDTree for %"ZU" objects...", itsObjects.size());
00143
00144
00145
00146 itsKDindices.clear(); uint objidx = 0U;
00147 std::vector< rutz::shared_ptr<Keypoint> > allkps;
00148
00149 std::vector< rutz::shared_ptr<VisualObject> >::const_iterator
00150 obj = itsObjects.begin(), stop = itsObjects.end();
00151
00152 while (obj != stop)
00153 {
00154 const std::vector< rutz::shared_ptr<Keypoint> >& kps = (*obj)->getKeypoints();
00155 uint kidx = 0U;
00156 std::vector< rutz::shared_ptr<Keypoint> >::const_iterator
00157 kp = kps.begin(), stopk = kps.end();
00158
00159 while(kp != stopk)
00160 {
00161
00162 allkps.push_back(*kp);
00163
00164
00165 itsKDindices.push_back(std::pair<uint, uint>(objidx, kidx));
00166
00167 ++ kp; ++kidx;
00168 }
00169 ++ obj; ++objidx;
00170 }
00171
00172
00173
00174
00175 itsKDTree.reset(new KDTree(allkps));
00176
00177 LINFO("Done. KDTree initialized with %"ZU" keypoints.", allkps.size());
00178 }
00179
00180
00181 uint VisualObjectDB::
00182 getObjectMatches(const rutz::shared_ptr<VisualObject> obj,
00183 std::vector< rutz::shared_ptr<VisualObjectMatch> >& matches,
00184 const VisualObjectMatchAlgo algo, const uint maxn,
00185 const float kcoeff, const float acoeff,
00186 const float minscore, const uint mink,
00187 const uint kthresh, const bool sortbypf)
00188 {
00189 LDEBUG("Matching '%s' against database...", obj->getName().c_str());
00190 Timer tim(1000000);
00191 matches.clear(); uint nm = 0U;
00192
00193 switch(algo)
00194 {
00195
00196 case VOMA_SIMPLE:
00197 {
00198 const uint nobj = itsObjects.size();
00199
00200 std::vector<uint> sidx;
00201 if (sortbypf) computeSortedIndices(sidx, obj);
00202
00203 for (uint i = 0; i < nobj; i ++)
00204 {
00205
00206 const uint index = sortbypf ? sidx[i] : i;
00207
00208
00209 rutz::shared_ptr<VisualObjectMatch>
00210 match(new VisualObjectMatch(obj, itsObjects[index],
00211 algo, kthresh));
00212
00213
00214 match->prune(std::max(25U, mink * 5U), mink);
00215
00216
00217 if (match->size() >= mink &&
00218 match->getScore(kcoeff, acoeff) >= minscore &&
00219 match->checkSIFTaffine())
00220 {
00221 matches.push_back(match); ++nm;
00222
00223
00224 if (nm >= maxn) break;
00225 }
00226
00227
00228
00229 }
00230 }
00231 break;
00232
00233
00234 case VOMA_KDTREE:
00235 case VOMA_KDTREEBBF:
00236 {
00237 const uint nobj = itsObjects.size();
00238 const uint kthresh2 = kthresh * kthresh;
00239
00240
00241 buildKDTree();
00242
00243
00244 const uint tstnkp = obj->numKeypoints();
00245 if (tstnkp == 0U) break;
00246 const int maxdsq = obj->getKeypoint(0)->maxDistSquared();
00247
00248
00249
00250 std::vector< std::vector<KeypointMatch> > kpm(nobj);
00251
00252
00253
00254 for (uint i = 0; i < tstnkp; i++)
00255 {
00256 int distsq1 = maxdsq, distsq2 = maxdsq;
00257 rutz::shared_ptr<Keypoint> tstkey = obj->getKeypoint(i);
00258
00259
00260 uint matchIndex = (algo == VOMA_KDTREEBBF) ?
00261 itsKDTree->nearestNeighborBBF(tstkey, 40, distsq1, distsq2) :
00262 itsKDTree->nearestNeighbor(tstkey, distsq1, distsq2);
00263
00264
00265 if (100U * distsq1 < kthresh2 * distsq2)
00266 {
00267 const uint refobjnum = itsKDindices[matchIndex].first;
00268 const uint refkpnum = itsKDindices[matchIndex].second;
00269
00270
00271
00272
00273
00274
00275 KeypointMatch m;
00276 m.refkp = tstkey;
00277 m.tstkp = itsObjects[refobjnum]->getKeypoint(refkpnum);
00278 m.distSq = distsq1;
00279 m.distSq2 = distsq2;
00280 kpm[refobjnum].push_back(m);
00281 }
00282 }
00283
00284
00285 for (uint i = 0U; i < nobj; i ++)
00286 {
00287
00288 if (kpm[i].size() >= mink)
00289 {
00290
00291 rutz::shared_ptr<VisualObjectMatch>
00292 match(new VisualObjectMatch(obj, itsObjects[i], kpm[i]));
00293
00294
00295 match->prune(std::max(25U, mink * 5U), mink);
00296
00297
00298 if (match->size() >= mink &&
00299 match->getScore(kcoeff, acoeff) >= minscore &&
00300 match->checkSIFTaffine())
00301 {
00302 matches.push_back(match); ++nm;
00303
00304
00305 if (nm >= maxn) break;
00306 }
00307 }
00308 }
00309 }
00310 break;
00311
00312
00313 default:
00314 LFATAL("Unknown matching algo %d", int(algo));
00315 }
00316
00317
00318 std::sort(matches.begin(), matches.end(), moreVOM(kcoeff, acoeff));
00319
00320 uint64 t = tim.get();
00321 LDEBUG("Found %u database object matches for '%s' in %.3fms",
00322 nm, obj->getName().c_str(), float(t) * 0.001F);
00323
00324 return nm;
00325 }
00326
00327
00328 class MatchJob : public JobServer::Job {
00329 public:
00330 MatchJob(const rutz::shared_ptr<VisualObject> obj_, const rutz::shared_ptr<VisualObject> obj2_,
00331 pthread_mutex_t *mut_, std::vector< rutz::shared_ptr<VisualObjectMatch> >& matches_,
00332 const float kcoeff_, const float acoeff_, const float minscore_, const uint mink_, const uint kthresh_) :
00333 JobServer::Job(), obj(obj_), obj2(obj2_), mut(mut_), matches(matches_), kcoeff(kcoeff_),
00334 acoeff(acoeff_), minscore(minscore_), mink(mink_), kthresh(kthresh_) { }
00335
00336 virtual ~MatchJob() { }
00337
00338 virtual void run() {
00339
00340 rutz::shared_ptr<VisualObjectMatch> match(new VisualObjectMatch(obj, obj2, VOMA_SIMPLE, kthresh));
00341
00342
00343 match->prune(std::max(25U, mink * 5U), mink);
00344
00345
00346 if (match->size() >= mink && match->getScore(kcoeff, acoeff) >= minscore && match->checkSIFTaffine())
00347 {
00348 pthread_mutex_lock(mut);
00349 matches.push_back(match);
00350 pthread_mutex_unlock(mut);
00351 }
00352 }
00353
00354 virtual const char* jobType() const { return "MatchJob"; }
00355
00356 private:
00357 const rutz::shared_ptr<VisualObject> obj;
00358 const rutz::shared_ptr<VisualObject> obj2;
00359 pthread_mutex_t *mut;
00360 std::vector< rutz::shared_ptr<VisualObjectMatch> >& matches;
00361 const float kcoeff;
00362 const float acoeff;
00363 const float minscore;
00364 const uint mink;
00365 const uint kthresh;
00366 };
00367
00368
00369 uint VisualObjectDB::
00370 getObjectMatchesParallel(const rutz::shared_ptr<VisualObject> obj,
00371 std::vector< rutz::shared_ptr<VisualObjectMatch> >& matches,
00372 const uint numthreads, const float kcoeff, const float acoeff,
00373 const float minscore, const uint mink, const uint kthresh, const bool sortbypf)
00374 {
00375 LDEBUG("Parallel matching '%s' against database...", obj->getName().c_str());
00376 Timer tim(1000000); matches.clear();
00377
00378 pthread_mutex_t mut;
00379 if (pthread_mutex_init(&mut, NULL)) PLFATAL("Error creating mutex");
00380 WorkThreadServer wts("Match Server", numthreads);
00381
00382 const uint nobj = itsObjects.size();
00383
00384 std::vector<uint> sidx;
00385 if (sortbypf) computeSortedIndices(sidx, obj);
00386
00387 for (uint i = 0; i < nobj; i ++)
00388 {
00389
00390 const uint index = sortbypf ? sidx[i] : i;
00391
00392
00393 wts.enqueueJob(rutz::make_shared(new MatchJob(obj, itsObjects[index], &mut, matches, kcoeff, acoeff,
00394 minscore, mink, kthresh)));
00395 }
00396
00397
00398 wts.flushQueue();
00399 if (pthread_mutex_destroy(&mut)) PLERROR("Error in pthread_mutex_destroy");
00400
00401
00402 std::sort(matches.begin(), matches.end(), moreVOM(kcoeff, acoeff));
00403
00404 uint64 t = tim.get();
00405 LDEBUG("Found %"ZU" database object matches for '%s' in %.3fms", matches.size(),
00406 obj->getName().c_str(), float(t) * 0.001F);
00407
00408 return matches.size();
00409 }
00410
00411
00412 void VisualObjectDB::computeSortedIndices(std::vector<uint>& indices,
00413 const rutz::shared_ptr<VisualObject>&
00414 obj) const
00415 {
00416
00417 indices.clear();
00418
00419
00420 std::vector< std::pair<double, uint> > lst;
00421
00422 for(uint i = 0; i < itsObjects.size(); i++)
00423 lst.push_back(std::pair<double, uint>(itsObjects[i]->
00424 getFeatureDistSq(obj), i));
00425
00426
00427
00428
00429
00430 std::sort(lst.begin(), lst.end());
00431
00432
00433 std::vector< std::pair<double, uint> >::const_iterator
00434 ptr = lst.begin(), stop = lst.end();
00435 while (ptr != stop)
00436 { indices.push_back(ptr->second); ++ptr; }
00437 }
00438
00439
00440 std::istream& operator>>(std::istream& is, VisualObjectDB& vdb)
00441 {
00442 vdb.createVisualObjectDB(is, vdb);
00443 return is;
00444 }
00445
00446
00447
00448 void VisualObjectDB::
00449 createVisualObjectDB(std::istream& is, VisualObjectDB& vdb, bool preloadImage)
00450 {
00451 std::string name;
00452 std::getline(is, name);
00453
00454 vdb.setName(name);
00455
00456 uint siz; is>>siz;
00457
00458 vdb.itsObjects.clear(); vdb.itsObjects.resize(siz);
00459
00460 std::vector< rutz::shared_ptr<VisualObject> >::iterator
00461 vo = vdb.itsObjects.begin(), stop = vdb.itsObjects.end();
00462
00463 while (vo != stop)
00464 {
00465 rutz::shared_ptr<VisualObject> newvo(new VisualObject());
00466 is>>(*newvo);
00467 *vo++ = newvo;
00468 }
00469 }
00470
00471
00472 std::ostream& operator<<(std::ostream& os, const VisualObjectDB& vdb)
00473 {
00474 os<<vdb.getName()<<std::endl;
00475 os<<vdb.itsObjects.size()<<std::endl;
00476
00477 std::vector< rutz::shared_ptr<VisualObject> >::const_iterator
00478 vo = vdb.itsObjects.begin(), stop = vdb.itsObjects.end();
00479
00480 while (vo != stop) { os<<(**vo); ++ vo; }
00481
00482 return os;
00483 }
00484
00485
00486
00487
00488
00489