00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038 #ifndef IMAGE_LAPACK_C_DEFINED
00039 #define IMAGE_LAPACK_C_DEFINED
00040
00041 #ifdef HAVE_LAPACK
00042
00043 #include "Image/lapack.H"
00044
00045 #include "Image/Image.H"
00046 #include "Image/LinearAlgebraFlags.H"
00047 #include "Image/MatrixOps.H"
00048 #include "Image/f77lapack.H"
00049 #include "Util/log.H"
00050 #include "rutz/trace.h"
00051 #include <stdio.h>
00052
00053 namespace
00054 {
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123 void svd_lapack(Image<double>& A, Image<double>& Sigma,
00124 Image<double>& U, Image<double>& VT)
00125 {
00126 GVX_TRACE(__PRETTY_FUNCTION__);
00127
00128 char jobz = '?';
00129 f77_integer info = 0;
00130 int M = A.getWidth();
00131 int N = A.getHeight();
00132 int MNmin = std::min(M,N);
00133 f77_integer Ml = M;
00134 f77_integer Nl = N;
00135 f77_integer lda = A.getWidth();
00136
00137 if (Sigma.getSize() != MNmin)
00138 LFATAL("Sigma is not of correct size");
00139
00140 if ((U.getWidth() == M && U.getHeight() == M)
00141 && (VT.getWidth() == N && VT.getHeight() == N))
00142 jobz = 'A';
00143 else if ((U.getWidth() == M && U.getHeight() == MNmin)
00144 && (VT.getWidth() == MNmin && VT.getHeight() == N))
00145 jobz = 'S';
00146 else if (M >= N
00147 && U.getWidth() == 0
00148 && (VT.getWidth() == N && VT.getHeight() == N))
00149 jobz = 'O';
00150 else if (M < N
00151 && (U.getWidth() == M && U.getHeight() == M)
00152 && VT.getWidth() == 0)
00153 jobz = 'O';
00154 else
00155 LFATAL("U or VT is not of correct size");
00156
00157 f77_integer ldu = U.getWidth();
00158 f77_integer ldvt = VT.getWidth();
00159
00160 int liwork = 8*MNmin;
00161 Image<f77_integer> iwork(liwork, 1, NO_INIT);
00162
00163 f77_integer lwork = -1;
00164 Image<double> work(1, 1, ZEROS);
00165
00166 dgesdd_(&jobz, &Ml, &Nl, A.getArrayPtr(), &lda,
00167 Sigma.getArrayPtr(), U.getArrayPtr(), &ldu,
00168 VT.getArrayPtr(), &ldvt,
00169 work.getArrayPtr(), &lwork, iwork.getArrayPtr(),
00170 &info);
00171 lwork = int(work[0]);
00172 work.resize(lwork, 1);
00173
00174
00175 dgesdd_(&jobz, &Ml, &Nl, A.getArrayPtr(), &lda,
00176 Sigma.getArrayPtr(), U.getArrayPtr(), &ldu,
00177 VT.getArrayPtr(), &ldvt,
00178 work.getArrayPtr(), &lwork, iwork.getArrayPtr(),
00179 &info);
00180
00181 if (info != 0)
00182 LFATAL("Internal error in LAPACK: dgesdd() (info=%ld)", info);
00183 }
00184
00185
00186 void svdf_lapack(Image<float>& A, Image<float>& Sigma,
00187 Image<float>& U, Image<float>& VT)
00188 {
00189 GVX_TRACE(__PRETTY_FUNCTION__);
00190
00191 char jobz = '?';
00192 f77_integer info = 0;
00193 int M = A.getWidth();
00194 int N = A.getHeight();
00195 int MNmin = std::min(M,N);
00196 f77_integer Ml = M;
00197 f77_integer Nl = N;
00198 f77_integer lda = A.getWidth();
00199
00200 if (Sigma.getSize() != MNmin)
00201 LFATAL("Sigma is not of correct size");
00202
00203 if ((U.getWidth() == M && U.getHeight() == M)
00204 && (VT.getWidth() == N && VT.getHeight() == N))
00205 jobz = 'A';
00206 else if ((U.getWidth() == M && U.getHeight() == MNmin)
00207 && (VT.getWidth() == MNmin && VT.getHeight() == N))
00208 jobz = 'S';
00209 else if (M >= N
00210 && U.getWidth() == 0
00211 && (VT.getWidth() == N && VT.getHeight() == N))
00212 jobz = 'O';
00213 else if (M < N
00214 && (U.getWidth() == M && U.getHeight() == M)
00215 && VT.getWidth() == 0)
00216 jobz = 'O';
00217 else
00218 LFATAL("U or VT is not of correct size");
00219
00220 f77_integer ldu = U.getWidth();
00221 f77_integer ldvt = VT.getWidth();
00222
00223 int liwork = 8*MNmin;
00224 Image<f77_integer> iwork(liwork, 1, NO_INIT);
00225
00226 f77_integer lwork = -1;
00227 Image<float> work(1, 1, ZEROS);
00228
00229 sgesdd_(&jobz, &Ml, &Nl, A.getArrayPtr(), &lda,
00230 Sigma.getArrayPtr(), U.getArrayPtr(), &ldu,
00231 VT.getArrayPtr(), &ldvt,
00232 work.getArrayPtr(), &lwork, iwork.getArrayPtr(),
00233 &info);
00234 lwork = int(work[0]);
00235 work.resize(lwork, 1);
00236
00237
00238 sgesdd_(&jobz, &Ml, &Nl, A.getArrayPtr(), &lda,
00239 Sigma.getArrayPtr(), U.getArrayPtr(), &ldu,
00240 VT.getArrayPtr(), &ldvt,
00241 work.getArrayPtr(), &lwork, iwork.getArrayPtr(),
00242 &info);
00243
00244 if (info != 0)
00245 LFATAL("Internal error in LAPACK: sgesdd() (info=%ld)", info);
00246 }
00247
00248 }
00249
00250
00251 void lapack::svd(const Image<double>& A,
00252 Image<double>& U, Image<double>& S, Image<double>& V,
00253 const SvdFlag flags)
00254 {
00255 const int N = A.getWidth();
00256 const int M = A.getHeight();
00257
00258
00259 if (M < N)
00260 LFATAL("expected M >= N, got M=%d and N=%d", M, N);
00261
00262 Image<double> MMtxp = transpose(A);
00263 Image<double> SS(N, 1, ZEROS);
00264 Image<double> UUtxp(M,
00265 (flags & SVD_FULL) ? M : N,
00266 ZEROS);
00267 Image<double> VV(N, N, ZEROS);
00268
00269
00270
00271
00272
00273 svd_lapack(MMtxp, SS, UUtxp, VV);
00274
00275 U = transpose(UUtxp);
00276 V = VV;
00277 S = Image<double>(N,
00278 (flags & SVD_FULL) ? M : N,
00279 ZEROS);
00280 for (int c = 0; c < N; ++c)
00281 S[Point2D<int>(c, c)] = SS[c];
00282 }
00283
00284
00285 void lapack::svdf(const Image<float>& A,
00286 Image<float>& U, Image<float>& S, Image<float>& V,
00287 const SvdFlag flags)
00288 {
00289 const int N = A.getWidth();
00290 const int M = A.getHeight();
00291
00292
00293 if (M < N)
00294 LFATAL("expected M >= N, got M=%d and N=%d", M, N);
00295
00296 Image<float> MMtxp = transpose(A);
00297 Image<float> SS(N, 1, ZEROS);
00298 Image<float> UUtxp(M,
00299 (flags & SVD_FULL) ? M : N,
00300 ZEROS);
00301 Image<float> VV(N, N, ZEROS);
00302
00303
00304
00305
00306
00307 svdf_lapack(MMtxp, SS, UUtxp, VV);
00308
00309 U = transpose(UUtxp);
00310 V = VV;
00311 S = Image<float>(N,
00312 (flags & SVD_FULL) ? M : N,
00313 ZEROS);
00314 for (int c = 0; c < N; ++c)
00315 S[Point2D<int>(c, c)] = SS[c];
00316 }
00317
00318
00319 Image<double> lapack::dgemv(const Image<double>* v,
00320 const Image<double>* Mat)
00321 {
00322 GVX_TRACE(__PRETTY_FUNCTION__);
00323
00324 const int wv = v->getWidth(), hv = v->getHeight();
00325 const int wm = Mat->getWidth(), hm = Mat->getHeight();
00326
00327 ASSERT(wv == hm);
00328 ASSERT(hv == 1);
00329
00330 Image<double> y(wm, hv , NO_INIT);
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342 char trans = 'N';
00343 f77_integer M = Mat->getWidth(),
00344 N = Mat->getHeight(), lda = Mat->getWidth(),
00345 incv = 1, incy = 1;
00346
00347 double alpha = 1.0, beta = 0.0;
00348
00349 assert(Mat->getWidth() == y.getSize());
00350 assert(Mat->getHeight() == v->getSize());
00351
00352 {GVX_TRACE("dgemv");
00353 dgemv_(&trans, &M, &N, &alpha,
00354 Mat->getArrayPtr(), &lda,
00355 v->getArrayPtr(), &incv,
00356 &beta, y.getArrayPtr(), &incy);
00357 }
00358
00359 return y;
00360 }
00361
00362
00363 Image<float> lapack::sgemv(const Image<float>* v,
00364 const Image<float>* Mat)
00365 {
00366 GVX_TRACE(__PRETTY_FUNCTION__);
00367
00368 const int wv = v->getWidth(), hv = v->getHeight();
00369 const int wm = Mat->getWidth(), hm = Mat->getHeight();
00370
00371 ASSERT(wv == hm);
00372 ASSERT(hv == 1);
00373
00374 Image<float> y(wm, hv , NO_INIT);
00375
00376
00377
00378 char trans = 'N';
00379 f77_integer M = Mat->getWidth(),
00380 N = Mat->getHeight(), lda = Mat->getWidth(),
00381 incv = 1, incy = 1;
00382
00383 float alpha = 1.0, beta = 0.0;
00384
00385 assert(Mat->getWidth() == y.getSize());
00386 assert(Mat->getHeight() == v->getSize());
00387
00388 {GVX_TRACE("sgemv");
00389 sgemv_(&trans, &M, &N, &alpha,
00390 Mat->getArrayPtr(), &lda,
00391 v->getArrayPtr(), &incv,
00392 &beta, y.getArrayPtr(), &incy);
00393 }
00394
00395 return y;
00396 }
00397
00398
00399 Image<double> lapack::dgemm(const Image<double>* A,
00400 const Image<double>* B)
00401 {
00402 GVX_TRACE(__PRETTY_FUNCTION__);
00403
00404
00405
00406
00407
00408
00409
00410
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421 Image<double> C(B->getWidth(), A->getHeight(), NO_INIT);
00422
00423 char t = 'N';
00424
00425 f77_integer m = B->getWidth(), k = B->getHeight(), n = A->getHeight();
00426 f77_integer lda = B->getWidth(), ldb = A->getWidth(), ldc = C.getWidth();
00427 assert(B->getWidth() == C.getWidth());
00428 assert(A->getHeight() == C.getHeight());
00429 assert(B->getHeight() == A->getWidth());
00430
00431 double alpha = 1.0, beta = 0.0;
00432
00433 {GVX_TRACE("dgemm");
00434 dgemm_(&t, &t, &m, &n, &k, &alpha,
00435 B->getArrayPtr(), &lda,
00436 A->getArrayPtr(), &ldb,
00437 &beta, C.getArrayPtr(), &ldc);
00438 }
00439
00440 return C;
00441 }
00442
00443
00444 Image<float> lapack::sgemm(const Image<float>* A,
00445 const Image<float>* B)
00446 {
00447 GVX_TRACE(__PRETTY_FUNCTION__);
00448
00449
00450
00451 Image<float> C(B->getWidth(), A->getHeight(), NO_INIT);
00452
00453 char t = 'N';
00454
00455 f77_integer m = B->getWidth(), k = B->getHeight(), n = A->getHeight();
00456 f77_integer lda = B->getWidth(), ldb = A->getWidth(), ldc = C.getWidth();
00457 assert(B->getWidth() == C.getWidth());
00458 assert(A->getHeight() == C.getHeight());
00459 assert(B->getHeight() == A->getWidth());
00460
00461 float alpha = 1.0, beta = 0.0;
00462
00463 {GVX_TRACE("sgemm");
00464 sgemm_(&t, &t, &m, &n, &k, &alpha,
00465 B->getArrayPtr(), &lda,
00466 A->getArrayPtr(), &ldb,
00467 &beta, C.getArrayPtr(), &ldc);
00468 }
00469
00470 return C;
00471 }
00472
00473
00474 Image<double> lapack::dpotrf(const Image<double>* Mat)
00475 {
00476
00477 GVX_TRACE(__PRETTY_FUNCTION__);
00478
00479
00480 Image<double> A = *Mat;
00481
00482 f77_integer order = Mat->getWidth(), lda = Mat->getWidth(),
00483 flags = 0;
00484
00485 {GVX_TRACE("dpotrf");
00486
00487 char uplo = 'L';
00488
00489 dpotrf_(&uplo, &order,
00490 A.getArrayPtr(), &lda, &flags);
00491 }
00492
00493 return A;
00494 }
00495
00496 double lapack::det(const Image<double>* Mat)
00497 {
00498
00499 GVX_TRACE(__PRETTY_FUNCTION__);
00500
00501
00502
00503 Image<double> A = *Mat;
00504
00505 f77_integer m = Mat->getHeight(), n = Mat->getWidth(), lda = Mat->getWidth(), info = 0;
00506 f77_integer ipvt[std::min(m,n)];
00507
00508 {
00509 GVX_TRACE("dgetrf");
00510
00511 dgetrf_(&m,&n, A.getArrayPtr(), &lda, ipvt, &info);
00512 if (info > 0)
00513 {
00514 LINFO("ERROR: singuler matrix");
00515 return -1.0;
00516 }
00517 }
00518
00519
00520 double det = 1.0;
00521 bool neg = false;
00522 for(int i=0; i<m; i++)
00523 {
00524 det *= A.getVal(i,i);
00525 if (ipvt[i] != (i+1))
00526 neg = !neg;
00527 }
00528
00529
00530
00531 return neg ? -det : det;
00532 }
00533
00534
00535 #endif // HAVE_LAPACK
00536
00537
00538
00539
00540
00541
00542
00543 #endif // IMAGE_LAPACK_C_DEFINED