Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT:linalg.det:Return scalars for singleton inputs #18763

Merged
merged 1 commit into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 12 additions & 5 deletions scipy/linalg/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,8 @@ def det(a, overwrite_a=False, check_finite=True):
det : (...) float or complex
Determinant of `a`. For stacked arrays, a scalar is returned for each
(m, m) slice in the last two dimensions of the input. For example, an
input of shape (p, q, m, m) will produce a result of shape (p, q).
input of shape (p, q, m, m) will produce a result of shape (p, q). If
all dimensions are 1 a scalar is returned regardless of ndim.

Notes
-----
Expand Down Expand Up @@ -1066,11 +1067,17 @@ def det(a, overwrite_a=False, check_finite=True):

# Scalar case
if a1.shape[-2:] == (1, 1):
if a1.dtype.char in 'dD':
return np.squeeze(a1)
# Either ndarray with spurious singletons or a single element
if max(*a1.shape) > 1:
temp = np.squeeze(a1)
if a1.dtype.char in 'dD':
return temp
else:
return (temp.astype('d') if a1.dtype.char == 'f' else
temp.astype('D'))
else:
return (np.squeeze(a1).astype('d') if a1.dtype.char == 'f' else
np.squeeze(a1).astype('D'))
return (np.float64(a1.item()) if a1.dtype.char in 'fd' else
np.complex128(a1.item()))

# Then check overwrite permission
if not _datacopied(a1, a): # "a" still alive through "a1"
Expand Down
17 changes: 17 additions & 0 deletions scipy/linalg/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,23 @@ class TestDet:
def setup_method(self):
self.rng = np.random.default_rng(1680305949878959)

def test_1x1_all_singleton_dims(self):
a = np.array([[1]])
deta = det(a)
assert deta.dtype.char == 'd'
assert np.isscalar(deta)
assert deta == 1.
a = np.array([[[[1]]]], dtype='f')
deta = det(a)
assert deta.dtype.char == 'd'
assert np.isscalar(deta)
assert deta == 1.
a = np.array([[[1 + 3.j]]], dtype=np.complex64)
deta = det(a)
assert deta.dtype.char == 'D'
assert np.isscalar(deta)
assert deta == 1.+3.j

def test_1by1_stacked_input_output(self):
a = self.rng.random([4, 5, 1, 1], dtype=np.float32)
deta = det(a)
Expand Down