Skip to content

Commit

Permalink
MAINT:linalg.det:Return scalars for singleton inputs (scipy#18763)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilayn authored and tylerjereddy committed Jun 28, 2023
1 parent a1c6f99 commit 0760bab
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
17 changes: 12 additions & 5 deletions scipy/linalg/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,8 @@ def det(a, overwrite_a=False, check_finite=True):
det : (...) float or complex
Determinant of `a`. For stacked arrays, a scalar is returned for each
(m, m) slice in the last two dimensions of the input. For example, an
input of shape (p, q, m, m) will produce a result of shape (p, q).
input of shape (p, q, m, m) will produce a result of shape (p, q). If
all dimensions are 1 a scalar is returned regardless of ndim.
Notes
-----
Expand Down Expand Up @@ -1066,11 +1067,17 @@ def det(a, overwrite_a=False, check_finite=True):

# Scalar case
if a1.shape[-2:] == (1, 1):
if a1.dtype.char in 'dD':
return np.squeeze(a1)
# Either ndarray with spurious singletons or a single element
if max(*a1.shape) > 1:
temp = np.squeeze(a1)
if a1.dtype.char in 'dD':
return temp
else:
return (temp.astype('d') if a1.dtype.char == 'f' else
temp.astype('D'))
else:
return (np.squeeze(a1).astype('d') if a1.dtype.char == 'f' else
np.squeeze(a1).astype('D'))
return (np.float64(a1.item()) if a1.dtype.char in 'fd' else
np.complex128(a1.item()))

# Then check overwrite permission
if not _datacopied(a1, a): # "a" still alive through "a1"
Expand Down
17 changes: 17 additions & 0 deletions scipy/linalg/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,23 @@ class TestDet:
def setup_method(self):
self.rng = np.random.default_rng(1680305949878959)

def test_1x1_all_singleton_dims(self):
a = np.array([[1]])
deta = det(a)
assert deta.dtype.char == 'd'
assert np.isscalar(deta)
assert deta == 1.
a = np.array([[[[1]]]], dtype='f')
deta = det(a)
assert deta.dtype.char == 'd'
assert np.isscalar(deta)
assert deta == 1.
a = np.array([[[1 + 3.j]]], dtype=np.complex64)
deta = det(a)
assert deta.dtype.char == 'D'
assert np.isscalar(deta)
assert deta == 1.+3.j

def test_1by1_stacked_input_output(self):
a = self.rng.random([4, 5, 1, 1], dtype=np.float32)
deta = det(a)
Expand Down

0 comments on commit 0760bab

Please sign in to comment.