From f1340c73ab0f80a1f1090bea6b86498cc8b60a1e Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 18 Jan 2023 13:45:04 +0100 Subject: [PATCH] CI Adapt handling of discarded fused typed memoryview (#25425) Co-authored-by: Olivier Grisel --- sklearn/datasets/_svmlight_format_fast.pyx | 5 +++-- sklearn/datasets/_svmlight_format_io.py | 8 +++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sklearn/datasets/_svmlight_format_fast.pyx b/sklearn/datasets/_svmlight_format_fast.pyx index cb0a3885ed6cb..b578584e5ac47 100644 --- a/sklearn/datasets/_svmlight_format_fast.pyx +++ b/sklearn/datasets/_svmlight_format_fast.pyx @@ -187,6 +187,7 @@ def _dump_svmlight_file( bint y_is_sp, ): cdef bint X_is_integral + cdef bint query_id_is_not_empty = query_id.size > 0 X_is_integral = X.dtype.kind == "i" if X_is_integral: value_pattern = "%d:%d" @@ -198,7 +199,7 @@ def _dump_svmlight_file( label_pattern = "%.16g" line_pattern = "%s" - if query_id is not None: + if query_id_is_not_empty: line_pattern += " qid:%d" line_pattern += " %s\n" @@ -246,7 +247,7 @@ def _dump_svmlight_file( else: labels_str = label_pattern % y[i,0] - if query_id is not None: + if query_id_is_not_empty: feat = (labels_str, query_id[i], s) else: feat = (labels_str, s) diff --git a/sklearn/datasets/_svmlight_format_io.py b/sklearn/datasets/_svmlight_format_io.py index 16aae0de4f2b0..2a141e1732ff7 100644 --- a/sklearn/datasets/_svmlight_format_io.py +++ b/sklearn/datasets/_svmlight_format_io.py @@ -506,7 +506,13 @@ def dump_svmlight_file( if hasattr(X, "sort_indices"): X.sort_indices() - if query_id is not None: + if query_id is None: + # NOTE: query_id is passed to Cython functions using a fused type on query_id. + # Yet as of Cython>=3.0, memory views can't be None otherwise the runtime + # would not know which concrete implementation to dispatch the Python call to. + # TODO: simplify interfaces and implementations in _svmlight_format_fast.pyx. 
+ query_id = np.array([], dtype=np.int32) + else: query_id = np.asarray(query_id) if query_id.shape[0] != y.shape[0]: raise ValueError(