00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #ifndef CHANNELS_OPTIMALGAINS_C_DEFINED
00039 #define CHANNELS_OPTIMALGAINS_C_DEFINED
00040
00041 #include "Channels/OptimalGains.H"
00042
00043 #include "Channels/ChannelFacets.H"
00044 #include "Channels/ComplexChannel.H"
00045 #include "Channels/SingleChannel.H"
00046 #include "Component/ParamMap.H"
00047 #include "Image/MathOps.H"
00048 #include "Image/ShapeOps.H"
00049 #include "Util/sformat.H"
00050
00051 #include <vector>
00052
00053
00054 OptimalGainsFinder::OptimalGainsFinder(const Image<byte>& targetMask,
00055 const Image<byte>& distractorMask,
00056 rutz::shared_ptr<ParamMap> pmap,
00057 const bool doMax)
00058 :
00059 itsTargetMask(targetMask),
00060 itsDistractorMask(distractorMask),
00061 itsPmap(pmap),
00062 itsDoMax(doMax)
00063 { }
00064
00065
00066 OptimalGainsFinder::~OptimalGainsFinder()
00067 { }
00068
00069
00070 void OptimalGainsFinder::visitChannelBase(ChannelBase& chan)
00071 { LFATAL("don't know how to handle %s", chan.tagName().c_str()); }
00072
00073
00074 void OptimalGainsFinder::visitSingleChannel(SingleChannel& chan)
00075 {
00076
00077 rutz::shared_ptr<ChannelFacetGainSingle> gfacet;
00078 if (chan.hasFacet<ChannelFacetGainSingle>())
00079 gfacet = chan.getFacet<ChannelFacetGainSingle>();
00080 else
00081 { gfacet.reset(new ChannelFacetGainSingle(chan)); chan.setFacet(gfacet); }
00082
00083
00084
00085
00086
00087
00088 const uint num = chan.numSubmaps();
00089 float sumSNR = 0.0f, SNR[num];
00090
00091 Image<byte> tmap, dmap;
00092 if (itsTargetMask.initialized())
00093 tmap = rescale(itsTargetMask, chan.getMapDims());
00094 if (itsDistractorMask.initialized())
00095 dmap = rescale(itsDistractorMask, chan.getMapDims());
00096
00097 for (uint idx = 0; idx < num; idx ++)
00098 {
00099 const Image<float> submap = chan.getSubmap(idx);
00100 float sT = 0.0f, sD = 0.0f;
00101 float junk1, junk2, junk3;
00102
00103
00104 if (itsDoMax)
00105 {
00106
00107
00108 if (tmap.initialized())
00109 getMaskedMinMax(submap, tmap, junk1, sT, junk2, junk3);
00110
00111 if (dmap.initialized())
00112 getMaskedMinMax(submap, dmap, junk1, sD, junk2, junk3);
00113 }
00114 else
00115 {
00116
00117
00118 if (tmap.initialized())
00119 getMaskedMinMaxAvg(submap, tmap, junk1, junk2, sT);
00120
00121 if (dmap.initialized())
00122 getMaskedMinMaxAvg(submap, dmap, junk1, junk2, sD);
00123 }
00124
00125 SNR[idx] = (sT + OPTIGAIN_BG_FIRING) / (sD + OPTIGAIN_BG_FIRING);
00126 sumSNR += SNR[idx];
00127
00128
00129 itsPmap->putDoubleParam(sformat("salienceT(%d)", idx), sT);
00130 itsPmap->putDoubleParam(sformat("salienceD(%d)", idx), sD);
00131
00132 uint c = 0, s = 0; chan.getLevelSpec().indexToCS(idx, c, s);
00133 LDEBUG("%s(%d,%d): sT=%f, sD=%f", chan.tagName().c_str(), c, s, sT, sD);
00134 }
00135 sumSNR /= num;
00136
00137
00138 for (uint idx = 0; idx < num; idx ++)
00139 {
00140 const float gain = SNR[idx] / sumSNR;
00141 uint clev = 0, slev = 0; chan.getLevelSpec().indexToCS(idx, clev, slev);
00142 LINFO("%s(%d,%d): gain = %f, SNR = %f", chan.tagName().c_str(),
00143 clev, slev, gain, SNR[idx]);
00144 gfacet->setVal(idx, gain);
00145 }
00146
00147
00148
00149 chan.killCaches();
00150 (void) chan.getOutput();
00151 }
00152
00153
00154 void OptimalGainsFinder::visitComplexChannel(ComplexChannel& chan)
00155 {
00156
00157 rutz::shared_ptr<ChannelFacetGainComplex> gfacet;
00158 if (chan.hasFacet<ChannelFacetGainComplex>())
00159 gfacet = chan.getFacet<ChannelFacetGainComplex>();
00160 else
00161 { gfacet.reset(new ChannelFacetGainComplex(chan)); chan.setFacet(gfacet); }
00162
00163
00164 const uint num = chan.numChans();
00165 rutz::shared_ptr<ParamMap> pmapsave = itsPmap;
00166 for (uint idx = 0; idx < num; idx ++)
00167 {
00168
00169 itsPmap.reset(new ParamMap());
00170 chan.subChan(idx)->accept(*this);
00171
00172
00173 itsPmap->putIntParam("subchanidx", idx);
00174
00175
00176 pmapsave->putSubpmap(chan.subChan(idx)->tagName(), itsPmap);
00177 }
00178 itsPmap.swap(pmapsave);
00179
00180
00181
00182
00183
00184 float sumSNR = 0.0f, SNR[num];
00185
00186 Image<byte> tmap, dmap;
00187 if (itsTargetMask.initialized())
00188 tmap = rescale(itsTargetMask, chan.getMapDims());
00189 if (itsDistractorMask.initialized())
00190 dmap = rescale(itsDistractorMask, chan.getMapDims());
00191
00192
00193 for (uint idx = 0; idx < num; idx ++)
00194 {
00195
00196 const Image<float> submap = chan.subChan(idx)->getOutput();
00197 float sT = 0.0f, sD = 0.0f;
00198 float junk1, junk2, junk3;
00199
00200
00201 if (itsDoMax)
00202 {
00203
00204
00205 if (tmap.initialized())
00206 getMaskedMinMax(submap, tmap, junk1, sT, junk2, junk3);
00207
00208 if (dmap.initialized())
00209 getMaskedMinMax(submap, dmap, junk1, sD, junk2, junk3);
00210 }
00211 else
00212 {
00213
00214
00215 if (tmap.initialized())
00216 getMaskedMinMaxAvg(submap, tmap, junk1, junk2, sT);
00217
00218 if (dmap.initialized())
00219 getMaskedMinMaxAvg(submap, dmap, junk1, junk2, sD);
00220 }
00221
00222 SNR[idx] = (sT + OPTIGAIN_BG_FIRING) / (sD + OPTIGAIN_BG_FIRING);
00223 sumSNR += SNR[idx];
00224
00225
00226 itsPmap->putDoubleParam(sformat("salienceT(%d)", idx), sT);
00227 itsPmap->putDoubleParam(sformat("salienceD(%d)", idx), sD);
00228
00229 LDEBUG("%s: sT=%f, sD=%f", chan.subChan(idx)->tagName().c_str(), sT, sD);
00230 }
00231 sumSNR /= num;
00232
00233
00234 for (uint idx = 0; idx < num; idx ++)
00235 {
00236 const float gain = SNR[idx] / sumSNR;
00237 LINFO("%s: gain = %f, SNR = %f", chan.subChan(idx)->tagName().c_str(),
00238 gain, SNR[idx]);
00239 gfacet->setVal(idx, gain);
00240 }
00241
00242
00243
00244 chan.killCaches();
00245 (void) chan.getOutput();
00246 }
00247
00248
00249 rutz::shared_ptr<ParamMap> OptimalGainsFinder::pmap() const
00250 { return itsPmap; }
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260 #endif // CHANNELS_OPTIMALGAINS_C_DEFINED