Skip to content
This repository was archived by the owner on Nov 24, 2018. It is now read-only.

Commit 4024e4e

Browse files
committed
Merge pull request #53 from gonum/Adddgetri
Adddgetri
2 parents 06e6694 + 7411c69 commit 4024e4e

File tree

6 files changed

+214
-2
lines changed

6 files changed

+214
-2
lines changed

cgo/lapack.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,38 @@ func (impl Implementation) Dgetrf(m, n int, a []float64, lda int, ipiv []int) (o
395395
return ok
396396
}
397397

398+
// Dgetri computes the inverse of the matrix A using the LU factorization computed
399+
// by Dgetrf. On entry, a contains the PLU decomposition of A as computed by
400+
// Dgetrf and on exit contains the reciprocal of the original matrix.
401+
//
402+
// Dtrtri will not perform the inversion if the matrix is singular, and returns
403+
// a boolean indicating whether the inversion was successful.
404+
//
405+
// The C interface does not support providing temporary storage. To provide compatibility
406+
// with native, lwork == -1 will not run Dgetri but will instead write the minimum
407+
// work necessary to work[0]. If len(work) < lwork, Dgetri will panic.
408+
func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) {
409+
checkMatrix(n, n, a, lda)
410+
if len(ipiv) < n {
411+
panic(badIpiv)
412+
}
413+
if lwork == -1 {
414+
work[0] = float64(n)
415+
return true
416+
}
417+
if lwork < n {
418+
panic(badWork)
419+
}
420+
if len(work) < lwork {
421+
panic(badWork)
422+
}
423+
ipiv32 := make([]int32, len(ipiv))
424+
for i, v := range ipiv {
425+
ipiv32[i] = int32(v) + 1 // Transform to one-indexed.
426+
}
427+
return clapack.Dgetri(n, a, lda, ipiv32)
428+
}
429+
398430
// Dgetrs solves a system of equations using an LU factorization.
399431
// The system of equations solved is
400432
// A * X = B if trans == blas.Trans

cgo/lapack_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ func TestDgetrf(t *testing.T) {
5757
testlapack.DgetrfTest(t, impl)
5858
}
5959

60+
// TestDgetri runs the shared testlapack.DgetriTest checks against impl.
func TestDgetri(t *testing.T) {
61+
testlapack.DgetriTest(t, impl)
62+
}
63+
6064
// TestDgetrs runs the shared testlapack.DgetrsTest checks against impl.
func TestDgetrs(t *testing.T) {
6165
testlapack.DgetrsTest(t, impl)
6266
}

native/dgetri.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package native
2+
3+
import (
4+
"github.com/gonum/blas"
5+
"github.com/gonum/blas/blas64"
6+
)
7+
8+
// Dgetri computes the inverse of the matrix A using the LU factorization computed
9+
// by Dgetrf. On entry, a contains the PLU decomposition of A as computed by
10+
// Dgetrf and on exit contains the reciprocal of the original matrix.
11+
//
12+
// Dgetri will not perform the inversion if the matrix is singular, and returns
13+
// a boolean indicating whether the inversion was successful.
14+
//
15+
// Work is temporary storage, and lwork specifies the usable memory length.
16+
// At minimum, lwork >= n and this function will panic otherwise.
17+
// Dgetri is a blocked inversion, but the block size is limited
18+
// by the temporary space available. If lwork == -1, instead of performing Dgetri,
19+
// the optimal work length will be stored into work[0].
20+
func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) {
21+
checkMatrix(n, n, a, lda)
22+
if len(ipiv) < n {
23+
panic(badIpiv)
24+
}
25+
nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1)
26+
if lwork == -1 {
27+
work[0] = float64(n * nb)
28+
return true
29+
}
30+
if lwork < n {
31+
panic(badWork)
32+
}
33+
if len(work) < lwork {
34+
panic(badWork)
35+
}
36+
if n == 0 {
37+
return true
38+
}
39+
ok = impl.Dtrtri(blas.Upper, blas.NonUnit, n, a, lda)
40+
if !ok {
41+
return false
42+
}
43+
nbmin := 2
44+
ldwork := nb
45+
if nb > 1 && nb < n {
46+
iws := max(ldwork*n, 1)
47+
if lwork < iws {
48+
nb = lwork / ldwork
49+
nbmin = max(2, impl.Ilaenv(2, "DGETRI", " ", n, -1, -1, -1))
50+
}
51+
}
52+
bi := blas64.Implementation()
53+
// TODO(btracey): Replace this with a more row-major oriented algorithm.
54+
if nb < nbmin || nb >= n {
55+
// Unblocked code.
56+
for j := n - 1; j >= 0; j-- {
57+
for i := j + 1; i < n; i++ {
58+
work[i*ldwork] = a[i*lda+j]
59+
a[i*lda+j] = 0
60+
}
61+
if j < n {
62+
bi.Dgemv(blas.NoTrans, n, n-j-1, -1, a[(j+1):], lda, work[(j+1)*ldwork:], ldwork, 1, a[j:], lda)
63+
}
64+
}
65+
} else {
66+
nn := ((n - 1) / nb) * nb
67+
for j := nn; j >= 0; j -= nb {
68+
jb := min(nb, n-j)
69+
for jj := j; jj < j+jb-1; jj++ {
70+
for i := jj + 1; i < n; i++ {
71+
work[i*ldwork+(jj-j)] = a[i*lda+jj]
72+
a[i*lda+jj] = 0
73+
}
74+
}
75+
if j+jb < n {
76+
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, jb, n-j-jb, -1, a[(j+jb):], lda, work[(j+jb)*ldwork:], ldwork, 1, a[j:], lda)
77+
bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda)
78+
}
79+
}
80+
}
81+
for j := n - 2; j >= 0; j-- {
82+
jp := ipiv[j]
83+
if jp != j {
84+
bi.Dswap(n, a[j:], lda, a[jp:], lda)
85+
}
86+
}
87+
return true
88+
}

native/dtrtri.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import (
99
// into a. This is the BLAS level 3 version of the algorithm which builds upon
1010
// Dtrti2 to operate on matrix blocks instead of only individual columns.
1111
//
12-
// Dtrti returns whether the matrix a is singular or whether it's not singular.
13-
// If the matrix is singular the inversion is not performed.
12+
// Dtrtri will not perform the inversion if the matrix is singular, and returns
13+
// a boolean indicating whether the inversion was successful.
1414
func (impl Implementation) Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) (ok bool) {
1515
checkMatrix(n, n, a, lda)
1616
if uplo != blas.Upper && uplo != blas.Lower {

native/lapack_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ func TestDgeqrf(t *testing.T) {
3636
testlapack.DgeqrfTest(t, impl)
3737
}
3838

39+
// TestDgetri runs the shared testlapack.DgetriTest checks against impl.
func TestDgetri(t *testing.T) {
40+
testlapack.DgetriTest(t, impl)
41+
}
42+
3943
// TestDgetf2 runs the shared testlapack.Dgetf2Test checks against impl.
func TestDgetf2(t *testing.T) {
4044
testlapack.Dgetf2Test(t, impl)
4145
}

testlapack/dgetri.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package testlapack
2+
3+
import (
4+
"math"
5+
"math/rand"
6+
"testing"
7+
8+
"github.com/gonum/blas"
9+
"github.com/gonum/blas/blas64"
10+
)
11+
12+
// Dgetrier is the implementation surface DgetriTest needs: everything in
// Dgetrfer (used to produce the LU factorization) plus Dgetri itself.
type Dgetrier interface {
13+
Dgetrfer
14+
Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) bool
15+
}
16+
17+
func DgetriTest(t *testing.T, impl Dgetrier) {
18+
bi := blas64.Implementation()
19+
for _, test := range []struct {
20+
n, lda int
21+
}{
22+
{5, 0},
23+
{5, 8},
24+
{45, 0},
25+
{45, 50},
26+
{65, 0},
27+
{65, 70},
28+
{150, 0},
29+
{150, 250},
30+
} {
31+
n := test.n
32+
lda := test.lda
33+
if lda == 0 {
34+
lda = n
35+
}
36+
// Generate a random well conditioned matrix
37+
perm := rand.Perm(n)
38+
a := make([]float64, n*lda)
39+
for i := 0; i < n; i++ {
40+
a[i*lda+perm[i]] = 1
41+
}
42+
for i := range a {
43+
a[i] += 0.01 * rand.Float64()
44+
}
45+
aCopy := make([]float64, len(a))
46+
copy(aCopy, a)
47+
ipiv := make([]int, n)
48+
// Compute LU decomposition.
49+
impl.Dgetrf(n, n, a, lda, ipiv)
50+
// Compute inverse.
51+
work := make([]float64, 1)
52+
impl.Dgetri(n, a, lda, ipiv, work, -1)
53+
work = make([]float64, int(work[0]))
54+
lwork := len(work)
55+
56+
ok := impl.Dgetri(n, a, lda, ipiv, work, lwork)
57+
if !ok {
58+
t.Errorf("Unexpected singular matrix.")
59+
}
60+
61+
// Check that A(inv) * A = I.
62+
ans := make([]float64, len(a))
63+
bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, aCopy, lda, a, lda, 0, ans, lda)
64+
isEye := true
65+
for i := 0; i < n; i++ {
66+
for j := 0; j < n; j++ {
67+
if i == j {
68+
// This tolerance is so high because computing matrix inverses
69+
// is very unstable.
70+
if math.Abs(ans[i*lda+j]-1) > 2e-2 {
71+
isEye = false
72+
}
73+
} else {
74+
if math.Abs(ans[i*lda+j]) > 2e-2 {
75+
isEye = false
76+
}
77+
}
78+
}
79+
}
80+
if !isEye {
81+
t.Errorf("Inv(A) * A != I. n = %v, lda = %v", n, lda)
82+
}
83+
}
84+
}

0 commit comments

Comments
 (0)