Skip to content

Commit

Permalink
Merge pull request #23148 from charris/backport-23079
Browse files Browse the repository at this point in the history
BUG: Fix integer / float scalar promotion
  • Loading branch information
charris committed Feb 2, 2023
2 parents 62af62a + c115e12 commit 8dfa47d
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 28 deletions.
25 changes: 9 additions & 16 deletions numpy/core/src/umath/scalarmath.c.src
Expand Up @@ -1179,6 +1179,11 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
* (Half, Float, Double, LongDouble,
* CFloat, CDouble, CLongDouble)*4,
* (Half, Float, Double, LongDouble)*3#
* #NAME = (BYTE, UBYTE, SHORT, USHORT, INT, UINT,
* LONG, ULONG, LONGLONG, ULONGLONG)*12,
* (HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE)*4,
* (HALF, FLOAT, DOUBLE, LONGDOUBLE)*3#
* #type = (npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong)*12,
* (npy_half, npy_float, npy_double, npy_longdouble,
Expand All @@ -1202,24 +1207,12 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
* (npy_half, npy_float, npy_double, npy_longdouble,
* npy_cfloat, npy_cdouble, npy_clongdouble)*4,
* (npy_half, npy_float, npy_double, npy_longdouble)*3#
* #oname = (byte, ubyte, short, ushort, int, uint,
* long, ulong, longlong, ulonglong)*11,
* double*10,
* (half, float, double, longdouble,
* cfloat, cdouble, clongdouble)*4,
* (half, float, double, longdouble)*3#
* #OName = (Byte, UByte, Short, UShort, Int, UInt,
* Long, ULong, LongLong, ULongLong)*11,
* Double*10,
* (Half, Float, Double, LongDouble,
* CFloat, CDouble, CLongDouble)*4,
* (Half, Float, Double, LongDouble)*3#
* #ONAME = (BYTE, UBYTE, SHORT, USHORT, INT, UINT,
* LONG, ULONG, LONGLONG, ULONGLONG)*11,
* DOUBLE*10,
* (HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE)*4,
* (HALF, FLOAT, DOUBLE, LONGDOUBLE)*3#
*/
#define IS_@name@
/* drop the "true_" from "true_divide" for floating point warnings: */
Expand All @@ -1234,7 +1227,7 @@ static PyObject *
@name@_@oper@(PyObject *a, PyObject *b)
{
PyObject *ret;
@otype@ arg1, arg2, other_val;
@type@ arg1, arg2, other_val;

/*
* Check if this operation may be considered forward. Note `is_forward`
Expand Down Expand Up @@ -1263,7 +1256,7 @@ static PyObject *
PyObject *other = is_forward ? b : a;

npy_bool may_need_deferring;
conversion_result res = convert_to_@oname@(
conversion_result res = convert_to_@name@(
other, &other_val, &may_need_deferring);
if (res == CONVERSION_ERROR) {
return NULL; /* an error occurred (should never happen) */
Expand Down Expand Up @@ -1305,7 +1298,7 @@ static PyObject *
*/
return PyGenericArrType_Type.tp_as_number->nb_@oper@(a,b);
case CONVERT_PYSCALAR:
if (@ONAME@_setitem(other, (char *)&other_val, NULL) < 0) {
if (@NAME@_setitem(other, (char *)&other_val, NULL) < 0) {
return NULL;
}
break;
Expand Down Expand Up @@ -1345,7 +1338,7 @@ static PyObject *
#if @twoout@
int retstatus = @name@_ctype_@oper@(arg1, arg2, &out, &out2);
#else
int retstatus = @oname@_ctype_@oper@(arg1, arg2, &out);
int retstatus = @name@_ctype_@oper@(arg1, arg2, &out);
#endif

#if @fperr@
Expand Down
59 changes: 47 additions & 12 deletions numpy/core/tests/test_scalarmath.py
Expand Up @@ -75,17 +75,7 @@ def test_leak(self):
np.add(1, 1)


@pytest.mark.slow
@settings(max_examples=10000, deadline=2000)
@given(sampled_from(reasonable_operators_for_scalars),
hynp.arrays(dtype=hynp.scalar_dtypes(), shape=()),
hynp.arrays(dtype=hynp.scalar_dtypes(), shape=()))
def test_array_scalar_ufunc_equivalence(op, arr1, arr2):
"""
This is a thorough test attempting to cover important promotion paths
and ensuring that arrays and scalars stay as aligned as possible.
However, if it creates troubles, it should maybe just be removed.
"""
def check_ufunc_scalar_equivalence(op, arr1, arr2):
scalar1 = arr1[()]
scalar2 = arr2[()]
assert isinstance(scalar1, np.generic)
Expand All @@ -95,6 +85,11 @@ def test_array_scalar_ufunc_equivalence(op, arr1, arr2):
comp_ops = {operator.ge, operator.gt, operator.le, operator.lt}
if op in comp_ops and (np.isnan(scalar1) or np.isnan(scalar2)):
pytest.xfail("complex comp ufuncs use sort-order, scalars do not.")
if op == operator.pow and arr2.item() in [-1, 0, 0.5, 1, 2]:
# array**scalar special case can have different result dtype
# (Other powers may have issues also, but are not hit here.)
# TODO: It would be nice to resolve this issue.
pytest.skip("array**2 can have incorrect/weird result dtype")

# ignore fpe's since they may just mismatch for integers anyway.
with warnings.catch_warnings(), np.errstate(all="ignore"):
Expand All @@ -107,7 +102,47 @@ def test_array_scalar_ufunc_equivalence(op, arr1, arr2):
op(scalar1, scalar2)
else:
scalar_res = op(scalar1, scalar2)
assert_array_equal(scalar_res, res)
assert_array_equal(scalar_res, res, strict=True)


@pytest.mark.slow
@settings(max_examples=10000, deadline=2000)
@given(sampled_from(reasonable_operators_for_scalars),
hynp.arrays(dtype=hynp.scalar_dtypes(), shape=()),
hynp.arrays(dtype=hynp.scalar_dtypes(), shape=()))
def test_array_scalar_ufunc_equivalence(op, arr1, arr2):
"""
This is a thorough test attempting to cover important promotion paths
and ensuring that arrays and scalars stay as aligned as possible.
However, if it creates troubles, it should maybe just be removed.
"""
check_ufunc_scalar_equivalence(op, arr1, arr2)


@pytest.mark.slow
@given(sampled_from(reasonable_operators_for_scalars),
hynp.scalar_dtypes(), hynp.scalar_dtypes())
def test_array_scalar_ufunc_dtypes(op, dt1, dt2):
# Same as above, but don't worry about sampling weird values so that we
# do not have to sample as much
arr1 = np.array(2, dtype=dt1)
arr2 = np.array(3, dtype=dt2) # some power do weird things.

check_ufunc_scalar_equivalence(op, arr1, arr2)


@pytest.mark.parametrize("fscalar", [np.float16, np.float32])
def test_int_float_promotion_truediv(fscalar):
# Promotion for mixed int and float32/float16 must not go to float64
i = np.int8(1)
f = fscalar(1)
expected = np.result_type(i, f)
assert (i / f).dtype == expected
assert (f / i).dtype == expected
# But normal int / int true division goes to float64:
assert (i / i).dtype == np.dtype("float64")
# For int16, result has to be ast least float32 (takes ufunc path):
assert (np.int16(1) / f).dtype == np.dtype("float32")


class TestBaseMath:
Expand Down

0 comments on commit 8dfa47d

Please sign in to comment.