mtx.cc

00001 
00003 
00004 //
00005 // Copyright (c) 2001-2004 California Institute of Technology
00006 // Copyright (c) 2004-2007 University of Southern California
00007 // Rob Peters <rjpeters at usc dot edu>
00008 //
00009 // created: Mon Mar 12 12:39:12 2001
00010 // commit: $Id: mtx.cc 10065 2007-04-12 05:54:56Z rjpeters $
00011 // $HeadURL: file:///lab/rjpeters/svnrepo/code/trunk/groovx/src/mtx/mtx.cc $
00012 //
00013 // --------------------------------------------------------------------
00014 //
00015 // This file is part of GroovX.
00016 //   [http://ilab.usc.edu/rjpeters/groovx/]
00017 //
00018 // GroovX is free software; you can redistribute it and/or modify it
00019 // under the terms of the GNU General Public License as published by
00020 // the Free Software Foundation; either version 2 of the License, or
00021 // (at your option) any later version.
00022 //
00023 // GroovX is distributed in the hope that it will be useful, but
00024 // WITHOUT ANY WARRANTY; without even the implied warranty of
00025 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00026 // General Public License for more details.
00027 //
00028 // You should have received a copy of the GNU General Public License
00029 // along with GroovX; if not, write to the Free Software Foundation,
00030 // Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
00031 //
00033 
00034 #ifndef GROOVX_PKGS_MTX_MTX_CC_UTC20050626084022_DEFINED
00035 #define GROOVX_PKGS_MTX_MTX_CC_UTC20050626084022_DEFINED
00036 
00037 #include "mtx.h"
00038 
00039 #include "rutz/cstrstream.h"
00040 #include "rutz/error.h"
00041 #include "rutz/fstring.h"
00042 #include "rutz/sfmt.h"
00043 
00044 #include <algorithm>
00045 #include <iostream>
00046 #include <iomanip>
00047 #include <numeric>
00048 #include <sstream>
00049 #include <vector>
00050 
00051 #include "rutz/trace.h"
00052 #include "rutz/debug.h"
00053 GVX_DBG_REGISTER
00054 
00055 using rutz::fstring;
00056 
namespace
{
  // Exchanges the contents of two buffers of 'nelems' doubles each.
  //
  // Precondition: the two ranges [buf1, buf1+nelems) and
  // [buf2, buf2+nelems) must not overlap.
  //
  // An element-wise swap needs no scratch buffer. So, unlike copying
  // through a temporary, it never needs a heap allocation (which could
  // throw std::bad_alloc) for large buffers. It also needs no <cstring>.
  inline void memswap(double* buf1, double* buf2, size_t nelems)
  {
    std::swap_ranges(buf1, buf1 + nelems, buf2);
  }
}
00083 
00084 namespace range_checking
00085 {
00086   void raise_exception(const fstring& msg, const char* f, int ln);
00087 }
00088 
00089 void range_checking::raise_exception(const fstring& msg,
00090                                      const char* f, int ln)
00091 {
00092   dbg_print_nl(3, msg);
00093   const fstring errmsg =
00094     rutz::sfmt("range check failed in file '%s' at line #%d: %s",
00095                f, ln, msg.c_str());
00096   throw rutz::error(errmsg, SRC_POS);
00097 }
00098 
00099 void range_checking::geq(const void* x, const void* lim,
00100                          const char* f, int ln)
00101 {
00102   if (x>=lim) ; // OK
00103   else raise_exception("geq: pointer range error", f, ln);
00104 }
00105 
00106 void range_checking::lt(const void* x, const void* lim,
00107                         const char* f, int ln)
00108 {
00109   if (x<lim) ; // OK
00110   else raise_exception("less: pointer range error", f, ln);
00111 }
00112 
00113 void range_checking::leq(const void* x, const void* lim,
00114                          const char* f, int ln)
00115 {
00116   if (x<=lim) ; // OK
00117   else raise_exception("leq: pointer range error", f, ln);
00118 }
00119 
00120 void range_checking::in_half_open(const void* x,
00121                                   const void* llim, const void* ulim,
00122                                   const char* f, int ln)
00123 {
00124   if (x>=llim && x<ulim) ; // OK
00125   else raise_exception("in_half_open: pointer range error", f, ln);
00126 }
00127 
00128 void range_checking::in_full_open(const void* x,
00129                                   const void* llim, const void* ulim,
00130                                   const char* f, int ln)
00131 {
00132   if (x>=llim && x<=ulim) ; // OK
00133   else raise_exception("in_full_open: pointer range error", f, ln);
00134 }
00135 
00136 void range_checking::geq(int x, int lim, const char* f, int ln)
00137 {
00138   if (x>=lim) ; // OK
00139   else raise_exception(rutz::sfmt("geq: integer range error "
00140                                   "%d !>= %d", x, lim),
00141                        f, ln);
00142 }
00143 
00144 void range_checking::lt(int x, int lim, const char* f, int ln)
00145 {
00146   if (x<lim) ; // OK
00147   else raise_exception(rutz::sfmt("less: integer range error "
00148                                   "%d !< %d", x, lim),
00149                        f, ln);
00150 }
00151 
00152 void range_checking::leq(int x, int lim, const char* f, int ln)
00153 {
00154   if (x<=lim) ; // OK
00155   else raise_exception(rutz::sfmt("leq: integer range error "
00156                                   "%d !<= %d", x, lim),
00157                        f, ln);
00158 }
00159 
00160 void range_checking::in_half_open(int x,
00161                                   int llim, int ulim,
00162                                   const char* f, int ln)
00163 {
00164   if (x>=llim && x<ulim) ; // OK
00165   else raise_exception(rutz::sfmt("in_half_open: integer range error "
00166                                   "%d !in [%d, %d[", x, llim, ulim),
00167                        f, ln);
00168 }
00169 
00170 void range_checking::in_full_open(int x,
00171                                   int llim, int ulim,
00172                                   const char* f, int ln)
00173 {
00174   if (x>=llim && x<=ulim) ; // OK
00175   else raise_exception(rutz::sfmt("in_full_open: integer range error "
00176                                   "%d !in [%d, %d]", x, llim, ulim),
00177                        f, ln);
00178 }
00179 
00180 
00182 //
00183 // slice member definitions
00184 //
00186 
00187 slice slice::operator()(const index_range& rng) const
00188 {
00189 GVX_TRACE("slice::operator");
00190   RC_in_half_open(rng.begin(), 0, m_nelems);
00191   RC_geq(rng.count(), 0);
00192   RC_leq(rng.end(), m_nelems);
00193 
00194   return slice(m_data_source, storage_offset(rng.begin()),
00195                m_stride, rng.count());
00196 }
00197 
00198 void slice::print(std::ostream& s) const
00199 {
00200 GVX_TRACE("slice::print");
00201   for (mtx_const_iter iter = begin(); iter.has_more(); ++iter)
00202     {
00203       s << std::setw(12) << std::setprecision(7) << double(*iter);
00204     }
00205   s << std::endl;
00206 }
00207 
00208 void slice::print_stdout() const
00209 {
00210 GVX_TRACE("slice::print_stdout");
00211   print(std::cout);
00212 }
00213 
00214 double slice::sum() const
00215 {
00216 GVX_TRACE("slice::sum");
00217   double s = 0.0;
00218   for (mtx_const_iter i = begin(); i.has_more(); ++i)
00219     s += *i;
00220   return s;
00221 }
00222 
00223 double slice::min() const
00224 {
00225 GVX_TRACE("slice::min");
00226   mtx_const_iter i = begin();
00227   double mn = *i;
00228   for (; i.has_more(); ++i)
00229     if (*i < mn) mn = *i;
00230   return mn;
00231 }
00232 
00233 double slice::max() const
00234 {
00235 GVX_TRACE("slice::max");
00236   mtx_const_iter i = begin();
00237   double m = *i;
00238   for (; i.has_more(); ++i)
00239     if (*i > m) m = *i;
00240   return m;
00241 }
00242 
namespace
{
  // Pairs an element's value with its original position. Sorting a
  // sequence of these by value therefore also records the permutation;
  // slice::get_sort_order() uses it that way.
  struct val_index
  {
    double val;
    unsigned int index;

    // Deliberately implicit: lets std::vector<val_index> be built directly
    // from a range of doubles. 'index' now starts at 0 instead of being
    // left indeterminate; callers assign the real position afterwards.
    val_index(double v=0.0) : val(v), index(0) {}

    // Orders by value only; 'index' does not take part in the comparison.
    bool operator<(const val_index& v2) const { return val < v2.val; }
  };
}
00255 
00256 mtx slice::get_sort_order() const
00257 {
00258 GVX_TRACE("slice::get_sort_order");
00259 
00260   std::vector<val_index> buf(this->begin(), this->end());
00261 
00262   for (unsigned int i = 0; i < buf.size(); ++i)
00263     {
00264       buf[i].index = i;
00265     }
00266 
00267   std::sort(buf.begin(), buf.end());
00268 
00269   mtx index = mtx::uninitialized(1, this->nelems());
00270 
00271   for (int i = 0; i < nelems(); ++i)
00272     {
00273       GVX_ASSERT(buf[i].index < static_cast<unsigned int>(nelems()));
00274       index.at(0,i) = buf[i].index;
00275     }
00276 
00277   return index;
00278 }
00279 
00280 bool slice::operator==(const slice& other) const
00281 {
00282 GVX_TRACE("slice::operator==(const slice&)");
00283   if (m_nelems != other.m_nelems) return false;
00284 
00285   for (mtx_const_iter a = this->begin(), b = other.begin();
00286        a.has_more();
00287        ++a, ++b)
00288     if (*a != *b) return false;
00289 
00290   return true;
00291 }
00292 
00293 void slice::sort()
00294 {
00295 GVX_TRACE("slice::sort");
00296   std::sort(begin_nc(), end_nc());
00297 }
00298 
00299 void slice::reorder(const mtx& index_)
00300 {
00301 GVX_TRACE("slice::reorder");
00302   mtx index(index_.as_column());
00303 
00304   if (index.mrows() != nelems())
00305     throw rutz::error("dimension mismatch in slice::reorder", SRC_POS);
00306 
00307   mtx neworder = mtx::uninitialized(this->nelems(), 1);
00308 
00309   for (int i = 0; i < nelems(); ++i)
00310     neworder.at(i,0) = (*this)[int(index.at(i,0))];
00311 
00312   *this = neworder.column(0);
00313 }
00314 
00315 slice& slice::operator+=(const slice& other)
00316 {
00317 GVX_TRACE("slice::operator+=(const slice&)");
00318   if (m_nelems != other.nelems())
00319     throw rutz::error("dimension mismatch in slice::operator+=", SRC_POS);
00320 
00321   mtx_const_iter rhs = other.begin();
00322 
00323   for (mtx_iter lhs = begin_nc(); lhs.has_more(); ++lhs, ++rhs)
00324     *lhs += *rhs;
00325 
00326   return *this;
00327 }
00328 
00329 slice& slice::operator-=(const slice& other)
00330 {
00331 GVX_TRACE("slice::operator-=(const slice&)");
00332   if (m_nelems != other.nelems())
00333     throw rutz::error("dimension mismatch in slice::operator-=", SRC_POS);
00334 
00335   mtx_const_iter rhs = other.begin();
00336 
00337   for (mtx_iter lhs = begin_nc(); lhs.has_more(); ++lhs, ++rhs)
00338     *lhs -= *rhs;
00339 
00340   return *this;
00341 }
00342 
00343 slice& slice::operator=(double val)
00344 {
00345 GVX_TRACE("slice::operator=(double)");
00346   for (mtx_iter itr = begin_nc(); itr.has_more(); ++itr)
00347     *itr = val;
00348 
00349   return *this;
00350 }
00351 
00352 slice& slice::operator=(const slice& other)
00353 {
00354 GVX_TRACE("slice::operator=(const slice&)");
00355   if (m_nelems != other.nelems())
00356     throw rutz::error("dimension mismatch in slice::operator=", SRC_POS);
00357 
00358   mtx_const_iter rhs = other.begin();
00359 
00360   for (mtx_iter lhs = begin_nc(); lhs.has_more(); ++lhs, ++rhs)
00361     *lhs = *rhs;
00362 
00363   return *this;
00364 }
00365 
00366 slice& slice::operator=(const mtx& other)
00367 {
00368 GVX_TRACE("slice::operator=(const mtx&)");
00369   if (m_nelems != other.nelems())
00370     throw rutz::error("dimension mismatch in slice::operator=", SRC_POS);
00371 
00372   int i = 0;
00373   for (mtx_iter lhs = begin_nc(); lhs.has_more(); ++lhs, ++i)
00374     *lhs = other.at(i);
00375 
00376   return *this;
00377 }
00378 
00379 void mtx_specs::swap(mtx_specs& other)
00380 {
00381   std::swap(m_shape, other.m_shape);
00382   std::swap(m_rowstride, other.m_rowstride);
00383   std::swap(m_offset, other.m_offset);
00384 }
00385 
00386 mtx_specs mtx_specs::as_shape(const mtx_shape& s) const
00387 {
00388 GVX_TRACE("mtx_specs::as_shape");
00389   if (s.nelems() != this->nelems())
00390     {
00391       const fstring msg =
00392         rutz::sfmt("as_shape(): dimension mismatch: "
00393                    "current nelems == %d; requested %dx%d",
00394                    nelems(), s.mrows(), s.ncols());
00395       throw rutz::error(msg, SRC_POS);
00396     }
00397 
00398   if (m_rowstride != mrows())
00399     throw rutz::error("as_shape(): cannot reshape a submatrix", SRC_POS);
00400 
00401   mtx_specs result = *this;
00402   result.m_shape = s;
00403   result.m_rowstride = s.mrows();
00404 
00405   return result;
00406 }
00407 
00408 void mtx_specs::select_rows(const row_index_range& rng)
00409 {
00410 GVX_TRACE("mtx_specs::select_rows");
00411   if (rng.begin() < 0)
00412     throw rutz::error("select_rows(): row index must be >= 0", SRC_POS);
00413 
00414   if (rng.count() <= 0)
00415     throw rutz::error("select_rows(): number of rows must be > 0", SRC_POS);
00416 
00417   if (rng.end() > mrows())
00418     throw rutz::error("select_rows(): upper row index out of range", SRC_POS);
00419 
00420   m_offset += rng.begin();
00421   m_shape = mtx_shape(rng.count(), ncols());
00422 }
00423 
00424 void mtx_specs::select_cols(const col_index_range& rng)
00425 {
00426 GVX_TRACE("mtx_specs::select_cols");
00427   if (rng.begin() < 0)
00428     throw rutz::error("select_cols(): column index must be >= 0", SRC_POS);
00429 
00430   if (rng.count() <= 0)
00431     throw rutz::error("select_cols(): number of columns must be > 0", SRC_POS);
00432 
00433   if (rng.end() > ncols())
00434     throw rutz::error("select_cols(): upper column index out of range", SRC_POS);
00435 
00436   m_offset += rng.begin()*m_rowstride;
00437   m_shape = mtx_shape(mrows(), rng.count());
00438 }
00439 
00441 //
00442 // mtx_base member definitions
00443 //
00445 
00446 template <class Data>
00447 void mtx_base<Data>::swap(mtx_base& other)
00448 {
00449   mtx_specs::swap(other);
00450   m_data.swap(other.m_data);
00451 }
00452 
00453 template <class Data>
00454 mtx_base<Data>::mtx_base(const mtx_base& other) :
00455   mtx_specs(other),
00456   m_data(other.m_data)
00457 {}
00458 
00459 template <class Data>
00460 mtx_base<Data>::mtx_base(int mrows, int ncols, const Data& data) :
00461   mtx_specs(mrows, ncols),
00462   m_data(data)
00463 {}
00464 
00465 template <class Data>
00466 mtx_base<Data>::mtx_base(const mtx_specs& specs, const Data& data) :
00467   mtx_specs(specs),
00468   m_data(data)
00469 {}
00470 
00471 template <class Data>
00472 mtx_base<Data>::~mtx_base() {}
00473 
00474 template class mtx_base<data_holder>;
00475 
00476 template class mtx_base<data_ref_holder>;
00477 
00479 //
00480 // sub_mtx_ref member definitions
00481 //
00483 
00484 sub_mtx_ref& sub_mtx_ref::operator=(const sub_mtx_ref& other)
00485 {
00486 GVX_TRACE("sub_mtx_ref::operator=(const sub_mtx_ref&)");
00487   if (this->nelems() != other.nelems())
00488     throw rutz::error("sub_mtx_ref::operator=(): dimension mismatch",
00489                       SRC_POS);
00490 
00491   std::copy(other.colmaj_begin(), other.colmaj_end(),
00492             this->colmaj_begin_nc());
00493 
00494   return *this;
00495 }
00496 
00497 sub_mtx_ref& sub_mtx_ref::operator=(const mtx& other)
00498 {
00499 GVX_TRACE("sub_mtx_ref::operator=(const mtx&)");
00500   if (this->nelems() != other.nelems())
00501     throw rutz::error("sub_mtx_ref::operator=(): dimension mismatch",
00502                       SRC_POS);
00503 
00504   std::copy(other.colmaj_begin(), other.colmaj_end(),
00505             this->colmaj_begin_nc());
00506 
00507   return *this;
00508 }
00509 
00511 //
00512 // mtx member definitions
00513 //
00515 
00516 mtx mtx::colmaj_copy_of(const double* data, int mrows, int ncols)
00517 {
00518 GVX_TRACE("mtx::colmaj_copy_of");
00519 
00520   return mtx(mtx_shape(mrows, ncols),
00521              data_holder(const_cast<double*>(data),
00522                          mrows, ncols, COPY));
00523 }
00524 
00525 mtx mtx::colmaj_borrow_from(double* data, int mrows, int ncols)
00526 {
00527 GVX_TRACE("mtx::colmaj_borrow_from");
00528 
00529   return mtx(mtx_shape(mrows, ncols),
00530              data_holder(data, mrows, ncols, BORROW));
00531 }
00532 
00533 mtx mtx::colmaj_refer_to(double* data, int mrows, int ncols)
00534 {
00535 GVX_TRACE("mtx::colmaj_refer_to");
00536 
00537   return mtx(mtx_shape(mrows, ncols),
00538              data_holder(data, mrows, ncols, REFER));
00539 }
00540 
00541 mtx mtx::zeros(const mtx_shape& s)
00542 {
00543   return mtx(s, data_holder(s.mrows(), s.ncols(), ZEROS));
00544 }
00545 
00546 mtx mtx::uninitialized(const mtx_shape& s)
00547 {
00548   return mtx(s, data_holder(s.mrows(), s.ncols(), NO_INIT));
00549 }
00550 
00551 mtx mtx::from_stream(std::istream& s)
00552 {
00553 GVX_TRACE("mtx::from_stream");
00554 
00555   fstring buf;
00556   int mrows = -1;
00557   int ncols = -1;
00558 
00559   s >> buf;
00560   if (buf != "mrows")
00561     throw rutz::error(rutz::sfmt("parse error while scanning mtx "
00562                                  "from stream: expected 'mrows', got '%s'",
00563                                  buf.c_str()),
00564                       SRC_POS);
00565 
00566   s >> mrows;
00567   if (mrows < 0)
00568     throw rutz::error("parse error while scanning mtx "
00569                       "from stream: expected mrows>=0", SRC_POS);
00570 
00571   s >> buf;
00572   if (buf != "ncols")
00573     throw rutz::error(rutz::sfmt("parse error while scanning mtx "
00574                                  "from stream: expected 'ncols', got '%s'",
00575                                  buf.c_str()),
00576                       SRC_POS);
00577 
00578   s >> ncols;
00579   if (ncols < 0)
00580     throw rutz::error("parse error while scanning mtx "
00581                       "from stream: expected ncols>=0", SRC_POS);
00582 
00583   mtx result = mtx::zeros(mrows, ncols);
00584 
00585   for (int r = 0; r < mrows; ++r)
00586     for (int c = 0; c < ncols; ++c)
00587       {
00588         if (s.eof())
00589           throw rutz::error("premature EOF while scanning mtx "
00590                             "from stream", SRC_POS);
00591         double d = 0.0;
00592         s >> d;
00593         result.at(r,c) = d;
00594       }
00595 
00596   if (s.fail())
00597     throw rutz::error("error while scanning mtx from stream", SRC_POS);
00598 
00599   return result;
00600 }
00601 
00602 mtx mtx::from_string(const char* s)
00603 {
00604 GVX_TRACE("mtx::from_string");
00605 
00606   rutz::imemstream ms(s);
00607   return mtx::from_stream(ms);
00608 }
00609 
00610 
00611 const mtx& mtx::empty_mtx()
00612 {
00613 GVX_TRACE("mtx::empty_mtx");
00614   static mtx* m = 0;
00615   if (m == 0)
00616     {
00617       m = new mtx(mtx::zeros(0,0));
00618     }
00619   return *m;
00620 }
00621 
00622 mtx::mtx(const slice& s) :
00623   Base(s.nelems(), 1, data_holder(s.nelems(), 1, NO_INIT))
00624 {
00625 GVX_TRACE("mtx::mtx");
00626   std::copy(s.begin(), s.end(), this->colmaj_begin_nc());
00627 }
00628 
00629 mtx::~mtx() {}
00630 
00631 void mtx::resize(int mrows_new, int ncols_new)
00632 {
00633 GVX_TRACE("mtx::resize");
00634   if (mrows() == mrows_new && ncols() == ncols_new)
00635     return;
00636   else
00637     {
00638       mtx newsize = mtx::zeros(mrows_new, ncols_new);
00639       this->swap(newsize);
00640     }
00641 }
00642 
00643 mtx mtx::contig() const
00644 {
00645 GVX_TRACE("mtx::contig");
00646   if (mrows() == rowstride())
00647     return *this;
00648 
00649   mtx result = mtx::uninitialized(this->shape());
00650 
00651   std::copy(this->colmaj_begin(), this->colmaj_end(),
00652             result.colmaj_begin_nc());
00653 
00654   return result;
00655 }
00656 
00657 namespace
00658 {
00659   void format_mtx(const mtx& m,
00660                   std::ostream& s,
00661                   const char* mtx_name,
00662                   bool trailing_newline)
00663   {
00664     if (mtx_name != 0 && mtx_name[0] != '\0')
00665       s << '[' << mtx_name << "] ";
00666 
00667     s << "mrows " << m.mrows() << " ncols " << m.ncols();
00668     for(int i = 0; i < m.mrows(); ++i)
00669       {
00670         s << '\n';
00671         for(int j = 0; j < m.ncols(); ++j)
00672           s << ' '
00673             << std::setw(18)
00674             << std::setprecision(17)
00675             << m.at(i,j);
00676       }
00677 
00678     if (trailing_newline)
00679       s << '\n';
00680   }
00681 }
00682 
00683 void mtx::print(std::ostream& s, const char* mtx_name) const
00684 {
00685 GVX_TRACE("mtx::print");
00686   format_mtx(*this, s, mtx_name, true);
00687 }
00688 
00689 void mtx::print_stdout() const
00690 {
00691 GVX_TRACE("mtx::print_stdout");
00692   format_mtx(*this, std::cout, 0, true);
00693 }
00694 
00695 void mtx::print_stdout_named(const char* mtx_name) const
00696 {
00697 GVX_TRACE("mtx::print_stdout_named");
00698   format_mtx(*this, std::cout, mtx_name, true);
00699 }
00700 
00701 rutz::fstring mtx::as_string() const
00702 {
00703 GVX_TRACE("mtx::as_string");
00704   std::ostringstream oss;
00705 
00706   format_mtx(*this, oss, 0, false);
00707 
00708   return rutz::fstring(oss.str().c_str());
00709 }
00710 
00711 void mtx::scan(std::istream& s)
00712 {
00713 GVX_TRACE("mtx::scan");
00714 
00715   *this = mtx::from_stream(s);
00716 }
00717 
00718 void mtx::scan_string(const char* s)
00719 {
00720 GVX_TRACE("mtx::scan_string");
00721 
00722   *this = mtx::from_string(s);
00723 }
00724 
00725 void mtx::reorder_rows(const mtx& index_)
00726 {
00727 GVX_TRACE("mtx::reorder_rows");
00728 
00729   mtx index(index_.as_column());
00730 
00731   if (index.mrows() != mrows())
00732     throw rutz::error("dimension mismatch in mtx::reorder_rows",
00733                       SRC_POS);
00734 
00735   mtx neworder = mtx::uninitialized(this->shape());
00736 
00737   for (int r = 0; r < mrows(); ++r)
00738     neworder.row(r) = row(int(index.at(r,0)));
00739 
00740   *this = neworder;
00741 }
00742 
00743 void mtx::reorder_columns(const mtx& index_)
00744 {
00745 GVX_TRACE("mtx::reorder_columns");
00746 
00747   mtx index(index_.as_column());
00748 
00749   if (index.mrows() != ncols())
00750     throw rutz::error("dimension mismatch in mtx::reorder_columns",
00751                       SRC_POS);
00752 
00753   mtx neworder = mtx::uninitialized(this->shape());
00754 
00755   for (int c = 0; c < ncols(); ++c)
00756     neworder.column(c) = column(int(index.at(c,0)));
00757 
00758   *this = neworder;
00759 }
00760 
00761 void mtx::swap_columns(int c1, int c2)
00762 {
00763 GVX_TRACE("mtx::swap_columns");
00764 
00765   if (c1 == c2) return;
00766 
00767   memswap(address_nc(0,c1), address_nc(0,c2), mrows());
00768 }
00769 
00770 mtx mtx::mean_row() const
00771 {
00772 GVX_TRACE("mtx::mean_row");
00773 
00774   mtx res = mtx::uninitialized(1, ncols());
00775 
00776   mtx_iter resiter = res.row(0).begin_nc();
00777 
00778   for (int c = 0; c < ncols(); ++c, ++resiter)
00779     *resiter = column(c).mean();
00780 
00781   return res;
00782 }
00783 
00784 mtx mtx::mean_column() const
00785 {
00786 GVX_TRACE("mtx::mean_column");
00787 
00788   mtx res = mtx::uninitialized(mrows(), 1);
00789 
00790   mtx_iter resiter = res.column(0).begin_nc();
00791 
00792   for (int r = 0; r < mrows(); ++r, ++resiter)
00793     *resiter = row(r).mean();
00794 
00795   return res;
00796 }
00797 
00798 mtx::const_iterator mtx::find_min() const
00799 {
00800 GVX_TRACE("mtx::find_min");
00801 
00802   if (nelems() == 0)
00803     throw rutz::error("find_min(): the matrix must be non-empty",
00804                       SRC_POS);
00805 
00806   return std::min_element(begin(), end());
00807 }
00808 
00809 mtx::const_iterator mtx::find_max() const
00810 {
00811 GVX_TRACE("mtx::find_max");
00812 
00813   if (nelems() == 0)
00814     throw rutz::error("find_max(): the matrix must be non-empty",
00815                       SRC_POS);
00816 
00817   return std::max_element(begin(), end());
00818 }
00819 
00820 double mtx::min() const
00821 {
00822 GVX_TRACE("mtx::min");
00823 
00824   if (nelems() == 0)
00825     throw rutz::error("min(): the matrix must be non-empty",
00826                       SRC_POS);
00827 
00828   return *(std::min_element(colmaj_begin(), colmaj_end()));
00829 }
00830 
00831 double mtx::max() const
00832 {
00833 GVX_TRACE("mtx::max");
00834 
00835   if (nelems() == 0)
00836     throw rutz::error("max(): the matrix must be non-empty",
00837                       SRC_POS);
00838 
00839   return *(std::max_element(colmaj_begin(), colmaj_end()));
00840 }
00841 
00842 double mtx::sum() const
00843 {
00844 GVX_TRACE("mtx::sum");
00845   return std::accumulate(begin(), end(), 0.0);
00846 }
00847 
00848 mtx& mtx::operator+=(const mtx& other)
00849 {
00850 GVX_TRACE("mtx::operator+=(const mtx&)");
00851   if (ncols() != other.ncols())
00852     throw rutz::error("dimension mismatch in mtx::operator+=",
00853                       SRC_POS);
00854 
00855   for (int i = 0; i < ncols(); ++i)
00856     column(i) += other.column(i);
00857 
00858   return *this;
00859 }
00860 
00861 mtx& mtx::operator-=(const mtx& other)
00862 {
00863 GVX_TRACE("mtx::operator-=(const mtx&)");
00864   if (ncols() != other.ncols())
00865     throw rutz::error("dimension mismatch in mtx::operator-=",
00866                       SRC_POS);
00867 
00868   for (int i = 0; i < ncols(); ++i)
00869     column(i) -= other.column(i);
00870 
00871   return *this;
00872 }
00873 
00874 bool mtx::operator==(const mtx& other) const
00875 {
00876 GVX_TRACE("mtx::operator==(const mtx&)");
00877   if ( (mrows() != other.mrows()) || (ncols() != other.ncols()) )
00878     return false;
00879   for (int c = 0; c < ncols(); ++c)
00880     if ( column(c) != other.column(c) ) return false;
00881   return true;
00882 }
00883 
00884 void mtx::VMmul_assign(const slice& vec, const mtx& mtx,
00885                        slice& result)
00886 {
00887 GVX_TRACE("mtx::VMmul_assign");
00888 
00889   // e.g mrows == vec.nelems == 3   ncols == 4
00890   //
00891   //               | e11  e12  e13  e14 |
00892   // [w1 w2 w3] *  | e21  e22  e23  e24 | =
00893   //               | e31  e32  e33  e34 |
00894   //
00895   //
00896   // [ w1*e11+w2*e21+w3*e31  w1*e12+w2*e22+w3*e32  ... ]
00897 
00898   if ( (vec.nelems() != mtx.mrows()) ||
00899        (result.nelems() != mtx.ncols()) )
00900     throw rutz::error("dimension mismatch in mtx::VMmul_assign",
00901                       SRC_POS);
00902 
00903   mtx_const_iter veciter = vec.begin();
00904 
00905   mtx_iter result_itr = result.begin_nc();
00906 
00907   for (int col = 0; col < mtx.ncols(); ++col, ++result_itr)
00908     *result_itr = inner_product(veciter, mtx.column_iter(col));
00909 }
00910 
00911 void mtx::assign_MMmul(const mtx& m1, const mtx& m2)
00912 {
00913 GVX_TRACE("mtx::assign_MMmul");
00914   if ( (m1.ncols() != m2.mrows()) ||
00915        (this->ncols() != m2.ncols()) )
00916     throw rutz::error("dimension mismatch in mtx::VMmul_assign",
00917                       SRC_POS);
00918 
00919   for (int n = 0; n < mrows(); ++n)
00920     {
00921       mtx_iter row_element = this->row_iter(n);
00922 
00923       mtx_const_iter veciter = m1.row_iter(n);
00924 
00925       for (int col = 0; col < m2.ncols(); ++col, ++row_element)
00926         *row_element = inner_product(veciter, m2.column_iter(col));
00927     }
00928 }
00929 
00930 namespace
00931 {
00932   template <class Op>
00933   mtx unary_op(const mtx& src, Op op)
00934   {
00935     mtx result = mtx::uninitialized(src.shape());
00936 
00937     std::transform(src.colmaj_begin(), src.colmaj_end(),
00938                    result.colmaj_begin_nc(),
00939                    op);
00940 
00941     return result;
00942   }
00943 }
00944 
00945 mtx operator+(const mtx& m, double x)
00946 {
00947   return unary_op(m, std::bind2nd(std::plus<double>(), x));
00948 }
00949 
00950 mtx operator-(const mtx& m, double x)
00951 {
00952   return unary_op(m, std::bind2nd(std::minus<double>(), x));
00953 }
00954 
00955 mtx operator*(const mtx& m, double x)
00956 {
00957   return unary_op(m, std::bind2nd(std::multiplies<double>(), x));
00958 }
00959 
00960 mtx operator/(const mtx& m, double x)
00961 {
00962   return unary_op(m, std::bind2nd(std::divides<double>(), x));
00963 }
00964 
00965 
00966 namespace
00967 {
00968   template <class Op>
00969   mtx binary_op(const mtx& m1, const mtx& m2, Op op)
00970   {
00971     if (! m1.same_size(m2) )
00972       throw rutz::error("dimension mismatch in binary_op(mtx, mtx)",
00973                         SRC_POS);
00974 
00975     mtx result = mtx::uninitialized(m1.shape());
00976 
00977     std::transform(m1.colmaj_begin(), m1.colmaj_end(),
00978                    m2.colmaj_begin(),
00979                    result.colmaj_begin_nc(),
00980                    op);
00981 
00982     return result;
00983   }
00984 }
00985 
00986 mtx operator+(const mtx& m1, const mtx& m2)
00987 {
00988 GVX_TRACE("operator+(mtx, mtx)");
00989   return binary_op(m1, m2, std::plus<double>());
00990 }
00991 
00992 mtx operator-(const mtx& m1, const mtx& m2)
00993 {
00994 GVX_TRACE("operator-(mtx, mtx)");
00995   return binary_op(m1, m2, std::minus<double>());
00996 }
00997 
00998 mtx arr_mul(const mtx& m1, const mtx& m2)
00999 {
01000 GVX_TRACE("arr_mul(mtx, mtx)");
01001   return binary_op(m1, m2, std::multiplies<double>());
01002 }
01003 
01004 mtx arr_div(const mtx& m1, const mtx& m2)
01005 {
01006 GVX_TRACE("arr_div(mtx, mtx)");
01007   return binary_op(m1, m2, std::divides<double>());
01008 }
01009 
01010 mtx min(const mtx& m1, const mtx& m2)
01011 {
01012 GVX_TRACE("min(mtx, mtx)");
01013   return binary_op(m1, m2, dash::min());
01014 }
01015 
01016 mtx max(const mtx& m1, const mtx& m2)
01017 {
01018 GVX_TRACE("max(mtx, mtx)");
01019   return binary_op(m1, m2, dash::max());
01020 }
01021 
// Embedded version-control identification string. The 'used' attribute
// keeps it in the object file even though nothing references it. The
// literal was truncated (unterminated) after "file:"; it is restored here
// from the $Id/$HeadURL keywords recorded in this file's header.
static const char __attribute__((used)) vcid_groovx_pkgs_mtx_mtx_cc_utc20050626084022[] = "$Id: mtx.cc 10065 2007-04-12 05:54:56Z rjpeters $ $HeadURL: file:///lab/rjpeters/svnrepo/code/trunk/groovx/src/mtx/mtx.cc $";
01023 #endif // !GROOVX_PKGS_MTX_MTX_CC_UTC20050626084022_DEFINED

The software described here is Copyright (c) 1998-2005, Rob Peters.
This page was generated Wed Dec 3 06:49:39 2008 by Doxygen version 1.5.5.