Skip to content

Commit

Permalink
feat: Support reserved_ip_ranges for VPC network in Ray on Vertex clu…
Browse files Browse the repository at this point in the history
…ster

chore: Update ray prediction tests for forward compatibility

PiperOrigin-RevId: 670628417
yinghsienwu authored and copybara-github committed Sep 3, 2024
1 parent 4a528c6 commit 36a56b9
Showing 7 changed files with 45 additions and 5 deletions.
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/vertex_ray/cluster_init.py
Original file line number Diff line number Diff line change
@@ -61,6 +61,7 @@ def create_ray_cluster(
enable_metrics_collection: Optional[bool] = True,
enable_logging: Optional[bool] = True,
psc_interface_config: Optional[resources.PscIConfig] = None,
reserved_ip_ranges: Optional[List[str]] = None,
labels: Optional[Dict[str, str]] = None,
) -> str:
"""Create a ray cluster on the Vertex AI.
@@ -126,6 +127,11 @@ def create_ray_cluster(
enable_metrics_collection: Enable Ray metrics collection for visualization.
enable_logging: Enable exporting Ray logs to Cloud Logging.
psc_interface_config: PSC-I config.
reserved_ip_ranges: A list of names for the reserved IP ranges under
the VPC network that can be used for this cluster. If set, we will
deploy the cluster within the provided IP ranges. Otherwise, the
cluster is deployed to any IP ranges under the provided VPC network.
Example: ["vertex-ai-ip-range"].
labels:
The labels with user-defined metadata to organize Ray cluster.
@@ -325,6 +331,7 @@ def create_ray_cluster(
labels=labels,
resource_runtime_spec=resource_runtime_spec,
psc_interface_config=gapic_psc_interface_config,
reserved_ip_ranges=reserved_ip_ranges,
)

location = initializer.global_config.location
Original file line number Diff line number Diff line change
@@ -43,7 +43,10 @@
import xgboost

except ModuleNotFoundError as mnfe:
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
if ray.__version__ == "2.9.3":
raise ModuleNotFoundError("XGBoost isn't installed.") from mnfe
else:
xgboost = None


def register_xgboost(
1 change: 1 addition & 0 deletions google/cloud/aiplatform/vertex_ray/util/_gapic_utils.py
Original file line number Diff line number Diff line change
@@ -150,6 +150,7 @@ def persistent_resource_to_cluster(
cluster = Cluster(
cluster_resource_name=persistent_resource.name,
network=persistent_resource.network,
reserved_ip_ranges=persistent_resource.reserved_ip_ranges,
state=persistent_resource.state.name,
labels=persistent_resource.labels,
dashboard_address=dashboard_address,
6 changes: 6 additions & 0 deletions google/cloud/aiplatform/vertex_ray/util/resources.py
Original file line number Diff line number Diff line change
@@ -117,6 +117,11 @@ class Cluster:
managed in the Vertex API service. For Ray Job API, VPC network is
not required because cluster connection can be accessed through
dashboard address.
reserved_ip_ranges: A list of names for the reserved IP ranges under
the VPC network that can be used for this cluster. If set, we will
deploy the cluster within the provided IP ranges. Otherwise, the
cluster is deployed to any IP ranges under the provided VPC network.
Example: ["vertex-ai-ip-range"].
service_account: Service account to be used for running Ray programs on
the cluster.
state: Describes the cluster state (defined in PersistentResource.State).
@@ -140,6 +145,7 @@ class Cluster:

cluster_resource_name: str = None
network: str = None
reserved_ip_ranges: List[str] = None
service_account: str = None
state: PersistentResource.State = None
python_version: str = None
1 change: 1 addition & 0 deletions tests/unit/vertex_ray/test_cluster_init.py
Original file line number Diff line number Diff line change
@@ -384,6 +384,7 @@ def test_create_ray_cluster_2_pools_custom_images_success(
head_node_type=tc.ClusterConstants.TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
worker_node_types=tc.ClusterConstants.TEST_WORKER_NODE_TYPES_2_POOLS_CUSTOM_IMAGE,
network=tc.ProjectConstants.TEST_VPC_NETWORK,
reserved_ip_ranges=["vertex-dedicated-range"],
cluster_name=tc.ClusterConstants.TEST_VERTEX_RAY_PR_ID,
)

12 changes: 12 additions & 0 deletions tests/unit/vertex_ray/test_constants.py
Original file line number Diff line number Diff line change
@@ -51,12 +51,17 @@
from google.cloud.aiplatform_v1beta1.types.service_networking import (
PscInterfaceConfig,
)
import ray
import pytest


rovminversion = pytest.mark.skipif(
sys.version_info > (3, 10), reason="Requires python3.10 or lower"
)
# TODO(b/363340317)
xgbversion = pytest.mark.skipif(
ray.__version__ != "2.9.3", reason="Requires xgboost 1.7 or higher"
)


@dataclasses.dataclass(frozen=True)
@@ -347,6 +352,7 @@ class ClusterConstants:
),
psc_interface_config=None,
network=ProjectConstants.TEST_VPC_NETWORK,
reserved_ip_ranges=["vertex-dedicated-range"],
)
# Responses
TEST_RESOURCE_POOL_2.replica_count = 1
@@ -366,6 +372,7 @@ class ClusterConstants:
network_attachment=TEST_PSC_NETWORK_ATTACHMENT
),
network=None,
reserved_ip_ranges=None,
resource_runtime=ResourceRuntime(
access_uris={
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
@@ -386,6 +393,7 @@ class ClusterConstants:
),
),
network=ProjectConstants.TEST_VPC_NETWORK,
reserved_ip_ranges=["vertex-dedicated-range"],
resource_runtime=ResourceRuntime(
access_uris={
"RAY_DASHBOARD_URI": TEST_VERTEX_RAY_DASHBOARD_ADDRESS,
@@ -399,6 +407,7 @@ class ClusterConstants:
python_version="3.10",
ray_version="2.9",
network=ProjectConstants.TEST_VPC_NETWORK,
reserved_ip_ranges=None,
service_account=None,
state="RUNNING",
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
@@ -412,6 +421,7 @@ class ClusterConstants:
python_version="3.10",
ray_version="2.9",
network="",
reserved_ip_ranges="",
service_account=None,
state="RUNNING",
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS,
@@ -424,6 +434,7 @@ class ClusterConstants:
TEST_CLUSTER_CUSTOM_IMAGE = Cluster(
cluster_resource_name=TEST_VERTEX_RAY_PR_ADDRESS,
network=ProjectConstants.TEST_VPC_NETWORK,
reserved_ip_ranges=["vertex-dedicated-range"],
service_account=None,
state="RUNNING",
head_node_type=TEST_HEAD_NODE_TYPE_2_POOLS_CUSTOM_IMAGE,
@@ -438,6 +449,7 @@ class ClusterConstants:
python_version="3.10",
ray_version="2.9",
network="",
reserved_ip_ranges="",
service_account=ProjectConstants.TEST_SERVICE_ACCOUNT,
state="RUNNING",
head_node_type=TEST_HEAD_NODE_TYPE_1_POOL,
18 changes: 14 additions & 4 deletions tests/unit/vertex_ray/test_ray_prediction.py
Original file line number Diff line number Diff line change
@@ -41,7 +41,6 @@
import numpy as np
import pytest
import ray
from ray.train import xgboost as ray_xgboost
import tensorflow as tf
import torch
import xgboost
@@ -90,9 +89,14 @@ def ray_sklearn_checkpoint():

@pytest.fixture()
def ray_xgboost_checkpoint():
model = test_prediction_utils.get_xgboost_model()
checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster())
return checkpoint
if ray.__version__ == "2.9.3":
from ray.train import xgboost as ray_xgboost

model = test_prediction_utils.get_xgboost_model()
checkpoint = ray_xgboost.XGBoostCheckpoint.from_model(model.get_booster())
return checkpoint
else:
return None


@pytest.fixture()
@@ -374,6 +378,7 @@ def test_register_sklearnartifact_uri_not_gcs_uri_raise_error(
assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*")

# XGBoost Tests
@tc.xgbversion
@tc.rovminversion
def test_convert_checkpoint_to_xgboost_raise_exception(
self, ray_checkpoint_from_dict
@@ -392,6 +397,7 @@ def test_convert_checkpoint_to_xgboost_raise_exception(
"ray.train.xgboost.XGBoostCheckpoint .*"
)

@tc.xgbversion
def test_convert_checkpoint_to_xgboost_model_succeed(
self, ray_xgboost_checkpoint
) -> None:
@@ -406,6 +412,7 @@ def test_convert_checkpoint_to_xgboost_model_succeed(
y_pred = model.predict(xgboost.DMatrix(np.array([[1, 2]])))
assert y_pred[0] is not None

@tc.xgbversion
def test_register_xgboost_succeed(
self,
ray_xgboost_checkpoint,
@@ -429,6 +436,7 @@ def test_register_xgboost_succeed(
pickle_dump.assert_called_once()
gcs_utils_upload_to_gcs.assert_called_once()

@tc.xgbversion
def test_register_xgboost_initialized_succeed(
self,
ray_xgboost_checkpoint,
@@ -455,6 +463,7 @@ def test_register_xgboost_initialized_succeed(
pickle_dump.assert_called_once()
gcs_utils_upload_to_gcs.assert_called_once()

@tc.xgbversion
def test_register_xgboostartifact_uri_is_none_raise_error(
self, ray_xgboost_checkpoint
) -> None:
@@ -467,6 +476,7 @@ def test_register_xgboostartifact_uri_is_none_raise_error(
)
assert ve.match(regexp=r".*'artifact_uri' should " "start with 'gs://'.*")

@tc.xgbversion
def test_register_xgboostartifact_uri_not_gcs_uri_raise_error(
self, ray_xgboost_checkpoint
) -> None:

0 comments on commit 36a56b9

Please sign in to comment.