ART1.C
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039 #include "Learn/ART1.H"
00040 #include "Util/Assert.H"
00041 #include "Util/log.H"
00042 #include <math.h>
00043 #include <fcntl.h>
00044 #include <limits>
00045 #include <string>
00046
00047 ART1::ART1(const int inputSize, const int numClasses) :
00048 itsInputSize(inputSize),
00049 itsNumClasses(numClasses)
00050 {
00051
00052
00053 itsF1.units.resize(itsInputSize);
00054 for(uint i=0; i<itsF1.units.size(); i++)
00055 itsF1.units[i].weights.resize(itsNumClasses);
00056
00057
00058 itsF2.units.resize(itsNumClasses);
00059 for(uint i=0; i<itsF2.units.size(); i++)
00060 itsF2.units[i].weights.resize(itsInputSize);
00061
00062
00063 itsA1 = 1;
00064 itsB1 = 1.5;
00065 itsC1 = 5;
00066 itsD1 = 0.9;
00067 itsL = 3;
00068 itsRho = 0.9;
00069
00070
00071 for(uint i=0; i<itsF1.units.size(); i++)
00072 for(uint j=0; j<itsF2.units.size(); j++)
00073 {
00074 itsF1.units[i].weights[j] = (itsB1 - 1) / itsD1 + 0.2;
00075 itsF2.units[j].weights[i] = itsL / (itsL - 1 + itsInputSize) - 0.1;
00076 }
00077
00078 }
00079
00080 ART1::~ART1()
00081 {
00082 }
00083
00084 void ART1::setInput(const std::vector<bool> input)
00085 {
00086 double act;
00087 for(uint i=0; i<itsF1.units.size(); i++)
00088 {
00089 act = input[i] / (1 + itsA1 * (input[i] + itsB1) + itsC1);
00090 itsF1.units[i].output = (act > 0);
00091 }
00092 }
00093
00094 int ART1::propagateToF2()
00095 {
00096
00097 double maxOut = -HUGE_VAL;
00098 int winner = -1;
00099 for (uint i=0; i<itsF2.units.size(); i++) {
00100 if (!itsF2.units[i].inhibited) {
00101 double sum = 0;
00102 for (uint j=0; j<itsF1.units.size(); j++) {
00103 sum += itsF2.units[i].weights[j] * itsF1.units[j].output;
00104 }
00105 if (sum > maxOut) {
00106 maxOut = sum;
00107 winner = i;
00108 }
00109 }
00110 itsF2.units[i].output = false;
00111 }
00112 if (winner != -1)
00113 itsF2.units[winner].output = true;
00114
00115 return winner;
00116 }
00117
00118 void ART1::propagateToF1(const std::vector<bool> input, const int winner)
00119 {
00120 for (uint i=0; i<itsF1.units.size(); i++) {
00121 double sum = itsF1.units[i].weights[winner] *
00122 itsF2.units[winner].output;
00123 double act = (input[i] + itsD1 * sum - itsB1) /
00124 (1 + itsA1 * (input[i] + itsD1 * sum) + itsC1);
00125 itsF1.units[i].output = (act > 0);
00126 }
00127 }
00128
00129
00130 void ART1::adjustWeights(const int winner)
00131 {
00132
00133 for (uint i=0; i<itsF1.units.size(); i++) {
00134 if (itsF1.units[i].output) {
00135 double mag = 0;
00136 for (uint j=0; j<itsF1.units.size(); j++)
00137 mag += itsF1.units[j].output;
00138
00139 itsF1.units[i].weights[winner] = 1;
00140 itsF2.units[winner].weights[i] = itsL / (itsL - 1 + mag);
00141 } else {
00142 itsF1.units[i].weights[winner] = 0;
00143 itsF2.units[winner].weights[i] = 0;
00144 }
00145 }
00146 }
00147
00148
00149
00150
00151 int ART1::evolveNet(std::string in)
00152 {
00153
00154 std::vector<bool> input(itsInputSize);
00155
00156 for(uint i=0; i<in.size(); i++)
00157 input[i] = (in[i] == 'O');
00158
00159 for(uint i=0; i<itsF2.units.size(); i++)
00160 itsF2.units[i].inhibited = false;
00161
00162 bool resonance = false;
00163 bool exhausted = false;
00164 int winner = -1;
00165 do
00166 {
00167 setInput(input);
00168
00169 winner = propagateToF2();
00170 if (winner != -1) {
00171 propagateToF1(input, winner);
00172
00173
00174 double magInput = 0;
00175 for(uint i=0; i<input.size(); i++)
00176 magInput += input[i];
00177
00178
00179 double magInput_ = 0;
00180 for(uint i=0; i<itsF1.units.size(); i++)
00181 magInput_ += itsF1.units[i].output;
00182
00183 if ((magInput_ / magInput) < itsRho)
00184 itsF2.units[winner].inhibited = true;
00185 else
00186 resonance = true;
00187 } else {
00188 exhausted = true;
00189 }
00190 } while (! (resonance || exhausted));
00191
00192 if (resonance)
00193 adjustWeights(winner);
00194
00195 if (exhausted)
00196 LINFO("New input and all Classes exhausted");
00197
00198 return winner;
00199
00200 }
00201
00202
00203
00204
00205
00206