Skip to content

Commit

Permalink
Implement to_device() for python bindings, add unit tests (#1611)
Browse files Browse the repository at this point in the history
* Implement to_device() for python bindings
  • Loading branch information
ozancaglayan committed Feb 8, 2024
1 parent 8c6715e commit 4c7b956
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ jobs:
name: Set up QEMU

- name: Build wheels
uses: pypa/cibuildwheel@v2.15.0
uses: pypa/cibuildwheel@v2.16.5
with:
package-dir: python
output-dir: python/wheelhouse
Expand Down
23 changes: 23 additions & 0 deletions python/cpp/storage_view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ namespace ctranslate2 {
.value("int32", DataType::INT32)
;

py::enum_<Device>(m, "Device")
.value("cpu", Device::CPU)
.value("cuda", Device::CUDA)
;

py::class_<StorageView>(
m, "StorageView",
R"pbdoc(
Expand Down Expand Up @@ -206,6 +211,24 @@ namespace ctranslate2 {
A new ``StorageView`` instance.
)pbdoc")

.def("to_device",
[](const StorageView& view, Device device) {
ScopedDeviceSetter device_setter(view.device(), view.device_index());
StorageView converted = view.to(device);
synchronize_stream(view.device());
return converted;
},
py::arg("device"),
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Converts the storage to another device.
Arguments:
device: The device to copy the data to.
Returns:
A new ``StorageView`` instance.
)pbdoc")
;
}

Expand Down
1 change: 1 addition & 0 deletions python/ctranslate2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AsyncScoringResult,
AsyncTranslationResult,
DataType,
Device,
Encoder,
EncoderForwardOutput,
ExecutionStats,
Expand Down
26 changes: 26 additions & 0 deletions python/tests/test_storage_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def test_storageview_cpu(dtype, name):
assert test_utils.array_equal(x, y)


def test_storageview_to_device():
x = np.ones(10, dtype="float32")
cx = ctranslate2.StorageView.from_array(x)

assert cx.device == "cpu"
assert cx.dtype == ctranslate2.DataType.float32

cpu_x = cx.to_device(ctranslate2.Device.cpu)
assert test_utils.array_equal(x, np.array(cpu_x))


@test_utils.require_cuda
def test_storageview_cuda():
import torch
Expand Down Expand Up @@ -74,6 +85,21 @@ def test_storageview_cuda():
_assert_same_array(s.__cuda_array_interface__, y.__cuda_array_interface__)


@test_utils.require_cuda
def test_storageview_cuda_to_device():
x = np.ones(10, dtype="float32")
# convert to cuda tensor
cuda_x = ctranslate2.StorageView.from_array(x).to_device(ctranslate2.Device.cuda)
assert cuda_x.device == "cuda"

# modify original tensor and convert back
x *= 2
cpu_x = np.array(cuda_x.to_device(ctranslate2.Device.cpu))

assert cpu_x.dtype == x.dtype
assert x.sum() == 2 * cpu_x.sum()


def test_storageview_conversion():
x = np.ones((2, 4), dtype=np.float32)
s = ctranslate2.StorageView.from_array(x)
Expand Down

0 comments on commit 4c7b956

Please sign in to comment.