Skip to content

Commit

Permalink
Fix wrong output for collect_list/collect_set of lists column (#1…
Browse files Browse the repository at this point in the history
…5243)

This fixes a bug in the reduction code that shows up specifically in `collect_list`/`collect_set` of a lists column. The output of these reduction ops should be a list scalar holding a column that has exactly the same type structure as the input. However, when the input column contains only nulls, the output list scalar holds an empty column with the wrong type structure.

Closes #14924.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - David Wendt (https://github.com/davidwendt)
  - Bradley Dice (https://github.com/bdice)

URL: #15243
  • Loading branch information
ttnghia committed Mar 13, 2024
1 parent fe9642b commit 2020ddd
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
11 changes: 5 additions & 6 deletions cpp/src/reductions/reductions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -177,15 +177,14 @@ std::unique_ptr<scalar> reduce(column_view const& col,
std::move(*reduction::detail::make_empty_histogram_like(col.child(0))), true, stream, mr);
}

if (output_dtype.id() == type_id::LIST) {
if (col.type() == output_dtype) { return make_empty_scalar_like(col, stream, mr); }
// Under some circumstance, the output type will become the List of input type,
// such as: collect_list or collect_set. So, we have to handcraft the default scalar.
if (agg.kind == aggregation::COLLECT_LIST || agg.kind == aggregation::COLLECT_SET) {
auto scalar = make_list_scalar(empty_like(col)->view(), stream, mr);
scalar->set_valid_async(false, stream);
return scalar;
}
if (output_dtype.id() == type_id::STRUCT) { return make_empty_scalar_like(col, stream, mr); }

// `make_default_constructed_scalar` does not support nested type.
if (cudf::is_nested(output_dtype)) { return make_empty_scalar_like(col, stream, mr); }

auto result = make_default_constructed_scalar(output_dtype, stream, mr);
if (agg.kind == aggregation::ANY || agg.kind == aggregation::ALL) {
Expand Down
51 changes: 50 additions & 1 deletion cpp/tests/reductions/collect_ops_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -367,3 +367,52 @@ TEST_F(CollectTest, CollectEmptys)
ret = collect_set(all_nulls, cudf::make_collect_set_aggregation<cudf::reduce_aggregation>());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(int_col{}, dynamic_cast<cudf::list_scalar*>(ret.get())->view());
}

// Reduces an all-null `int32_t` column with `collect_list` and `collect_set` (both excluding
// nulls) and expects the resulting list scalar to hold an empty `int32_t` column.
TEST_F(CollectTest, CollectAllNulls)
{
  using int_col = cudf::test::fixed_width_column_wrapper<int32_t>;
  using namespace cudf::test::iterators;

  // Six rows, every one of them marked null by the `all_nulls()` validity iterator.
  auto const input    = int_col{{0, 0, 0, 0, 0, 0}, all_nulls()};
  auto const expected = int_col{};

  // collect_list, nulls excluded.
  {
    auto const agg =
      cudf::make_collect_list_aggregation<cudf::reduce_aggregation>(cudf::null_policy::EXCLUDE);
    auto const result = cudf::reduce(input, *agg, cudf::data_type{cudf::type_id::LIST});

    // Fail the test cleanly instead of dereferencing a null pointer if the reduction did not
    // return a list scalar.
    auto* const list_result = dynamic_cast<cudf::list_scalar*>(result.get());
    ASSERT_NE(list_result, nullptr);
    CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, list_result->view());
  }

  // collect_set, nulls excluded.
  {
    auto const agg = cudf::make_collect_set_aggregation<cudf::reduce_aggregation>(
      cudf::null_policy::EXCLUDE, cudf::null_equality::UNEQUAL, cudf::nan_equality::ALL_EQUAL);
    auto const result = cudf::reduce(input, *agg, cudf::data_type{cudf::type_id::LIST});

    // Same guard as above: a failed cast must be reported as a test failure, not a crash.
    auto* const list_result = dynamic_cast<cudf::list_scalar*>(result.get());
    ASSERT_NE(list_result, nullptr);
    CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, list_result->view());
  }
}

// Reduces an all-null nested lists column with `collect_list` and `collect_set` (both excluding
// nulls) and expects the resulting list scalar to hold a column equal to `empty_like(input)`,
// i.e. an empty column compared against the input's own empty counterpart.
TEST_F(CollectTest, CollectAllNullsWithLists)
{
  using LCW = cudf::test::lists_column_wrapper<int32_t>;
  using namespace cudf::test::iterators;

  // list<list<int>>
  auto const input    = LCW{{LCW{LCW{1, 2, 3}, LCW{4, 5, 6}}, LCW{{1, 2, 3}}}, all_nulls()};
  auto const expected = cudf::empty_like(input);

  // collect_list, nulls excluded.
  {
    auto const agg =
      cudf::make_collect_list_aggregation<cudf::reduce_aggregation>(cudf::null_policy::EXCLUDE);
    auto const result = cudf::reduce(input, *agg, cudf::data_type{cudf::type_id::LIST});

    // Fail the test cleanly instead of dereferencing a null pointer if the reduction did not
    // return a list scalar.
    auto* const list_result = dynamic_cast<cudf::list_scalar*>(result.get());
    ASSERT_NE(list_result, nullptr);
    CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected->view(), list_result->view());
  }

  // collect_set, nulls excluded.
  {
    auto const agg = cudf::make_collect_set_aggregation<cudf::reduce_aggregation>(
      cudf::null_policy::EXCLUDE, cudf::null_equality::UNEQUAL, cudf::nan_equality::ALL_EQUAL);
    auto const result = cudf::reduce(input, *agg, cudf::data_type{cudf::type_id::LIST});

    // Same guard as above: a failed cast must be reported as a test failure, not a crash.
    auto* const list_result = dynamic_cast<cudf::list_scalar*>(result.get());
    ASSERT_NE(list_result, nullptr);
    CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected->view(), list_result->view());
  }
}

0 comments on commit 2020ddd

Please sign in to comment.