learnvision.C
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #include "Channels/ChannelBase.H"
00039 #include "Channels/ChannelVisitor.H"
00040 #include "Channels/SingleChannel.H"
00041 #include "Component/ModelManager.H"
00042 #include "Image/Image.H"
00043 #include "Image/MathOps.H"
00044 #include "Image/Pixels.H"
00045 #include "Image/ShapeOps.H"
00046 #include "Image/Transforms.H"
00047 #include "Channels/RawVisualCortex.H"
00048 #include "Raster/Raster.H"
00049 #include "Util/SimTime.H"
00050 #include "Util/Types.H"
00051
00052 #include <vector>
00053 #include <cstdio>
00054
00055 namespace
00056 {
00057
00058 class CoeffLearner : public ChannelVisitor
00059 {
00060 public:
00061 CoeffLearner(const Image<byte>& dmap, const double eta,
00062 const bool softmask,
00063 const int inthresh, const int outthresh)
00064 :
00065 itsDmap(dmap),
00066 itsEta(eta),
00067 itsSoftmask(softmask),
00068 itsInThresh(inthresh),
00069 itsOutThresh(outthresh),
00070 itsAbsSumCoeffs()
00071 {
00072 itsAbsSumCoeffs.push_back(0.0);
00073 }
00074
00075 virtual ~CoeffLearner() {}
00076
00077 double absSumCoeffs() const
00078 {
00079 ASSERT(itsAbsSumCoeffs.size() == 1);
00080 return itsAbsSumCoeffs.back();
00081 }
00082
00083 virtual void visitChannelBase(ChannelBase& chan)
00084 {
00085 LFATAL("don't know how to handle %s", chan.tagName().c_str());
00086 }
00087
00088 virtual void visitSingleChannel(SingleChannel& chan)
00089 {
00090 if (chan.visualFeature() == FLICKER)
00091 {
00092
00093
00094 return;
00095 }
00096
00097 chan.killCaches();
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126 }
00127
00128 virtual void visitComplexChannel(ComplexChannel& chan)
00129 {
00130 chan.killCaches();
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151 }
00152
00153 private:
00154 const Image<byte> itsDmap;
00155 const double itsEta;
00156 const bool itsSoftmask;
00157 const int itsInThresh;
00158 const int itsOutThresh;
00159
00160 std::vector<double> itsAbsSumCoeffs;
00161 };
00162
00163
00164 class CoeffNormalizer : public ChannelVisitor
00165 {
00166 public:
00167 CoeffNormalizer(const double div)
00168 :
00169 itsDiv(div)
00170 {}
00171
00172 virtual ~CoeffNormalizer() {}
00173
00174 virtual void visitChannelBase(ChannelBase& chan)
00175 {
00176 LFATAL("don't know how to handle %s", chan.tagName().c_str());
00177 }
00178
00179 virtual void visitSingleChannel(SingleChannel& chan)
00180 {
00181
00182 }
00183
00184 virtual void visitComplexChannel(ComplexChannel& chan)
00185 {
00186 for (uint i = 0; i < chan.numChans(); ++i)
00187 chan.subChan(i)->accept(*this);
00188 }
00189
00190 private:
00191 const double itsDiv;
00192 };
00193
00194 }
00195
00196
00197
00198
00199 int main(const int argc, const char **argv)
00200 {
00201 MYLOGVERB = LOG_INFO;
00202
00203
00204 ModelManager manager("Attention Model");
00205
00206
00207 nub::soft_ref<RawVisualCortex> vcx(new RawVisualCortex(manager));
00208 manager.addSubComponent(vcx);
00209
00210
00211 if (manager.parseCommandLine(argc, argv,
00212 "<image> <targetMask> <coeffs.pmap> "
00213 "<D|N> <inthresh> <outthresh> <eta>",
00214 7, 7) == false)
00215 return(1);
00216
00217
00218 Image< PixRGB<byte> > image = Raster::ReadRGB(manager.getExtraArg(0));
00219 Image<byte> targetmask = Raster::ReadGray(manager.getExtraArg(1));
00220
00221
00222 manager.start();
00223
00224
00225 FILE *f = fopen(manager.getExtraArg(2).c_str(), "r");
00226 if (f) {
00227 fclose(f);
00228 LINFO("Loading params from %s", manager.getExtraArg(2).c_str());
00229
00230 }
00231
00232
00233 vcx->input(InputFrame::fromRgb(&image));
00234
00235
00236 bool doDistMap = false;
00237 if (manager.getExtraArg(3).c_str()[0] == 'D') doDistMap = true;
00238 int inthresh = manager.getExtraArgAs<int>(4);
00239 int outthresh = manager.getExtraArgAs<int>(5);
00240 double eta = manager.getExtraArgAs<double>(6);
00241 const double softmask = true;
00242
00243
00244
00245
00246 Image<byte> dmap;
00247 if (doDistMap) dmap = chamfer34(targetmask);
00248 else dmap = binaryReverse(targetmask, byte(255));
00249
00250 CoeffLearner l(dmap, eta, softmask, inthresh, outthresh);
00251 vcx->accept(l);
00252
00253 const double sum = l.absSumCoeffs();
00254
00255 if (sum < 0.1)
00256 {
00257 LERROR("Sum of coeffs very small (%f). Not normalized.", sum);
00258 }
00259 else
00260 {
00261 const uint nbmaps = vcx->numSubmaps();
00262 LINFO("Coeff normalization: old sum = %f, nbmaps = %d",
00263 sum, nbmaps);
00264
00265 CoeffNormalizer n(sum / double(nbmaps));
00266 vcx->accept(n);
00267 }
00268
00269
00270 LINFO("Saving params to %s", manager.getExtraArg(2).c_str());
00271
00272
00273
00274 manager.stop();
00275
00276
00277 return 0;
00278 }
00279
00280
00281
00282
00283
00284