Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure all command and service calls raise when disconnected #840

Merged
merged 1 commit into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 6 additions & 3 deletions aioesphomeapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ def cover_command(
tilt: float | None = None,
stop: bool = False,
) -> None:
connection = self._get_connection()
req = CoverCommandRequest(key=key)
apiv = self.api_version
if TYPE_CHECKING:
Expand All @@ -951,7 +952,7 @@ def cover_command(
elif position == 0.0:
req.legacy_command = LegacyCoverCommand.CLOSE
req.has_legacy_command = True
self._get_connection().send_message(req)
connection.send_message(req)

def fan_command(
self,
Expand Down Expand Up @@ -1058,6 +1059,7 @@ def climate_command( # pylint: disable=too-many-branches
custom_preset: str | None = None,
target_humidity: float | None = None,
) -> None:
connection = self._get_connection()
req = ClimateCommandRequest(key=key)
if mode is not None:
req.has_mode = True
Expand Down Expand Up @@ -1096,7 +1098,7 @@ def climate_command( # pylint: disable=too-many-branches
if target_humidity is not None:
req.has_target_humidity = True
req.target_humidity = target_humidity
self._get_connection().send_message(req)
connection.send_message(req)

def number_command(self, key: int, state: float) -> None:
self._get_connection().send_message(NumberCommandRequest(key=key, state=state))
Expand Down Expand Up @@ -1172,6 +1174,7 @@ def text_command(self, key: int, state: str) -> None:
def execute_service(
self, service: UserService, data: ExecuteServiceDataType
) -> None:
connection = self._get_connection()
req = ExecuteServiceRequest(key=service.key)
args = []
apiv = self.api_version
Expand All @@ -1196,7 +1199,7 @@ def execute_service(
# pylint: disable=no-member
req.args.extend(args)

self._get_connection().send_message(req)
connection.send_message(req)

def _request_image(self, *, single: bool = False, stream: bool = False) -> None:
self._get_connection().send_message(
Expand Down
52 changes: 52 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2280,3 +2280,55 @@ async def test_api_version_after_connection_closed(
assert client.api_version == APIVersion(1, 9)
await client.disconnect(force=True)
assert client.api_version is None


@pytest.mark.asyncio
async def test_calls_after_connection_closed(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test calls after connection close should raise APIConnectionError."""
client, connection, transport, protocol = api_client
assert client.api_version == APIVersion(1, 9)
await client.disconnect(force=True)
assert client.api_version is None
service = UserService(
name="my_service",
key=1,
args=[],
)
with pytest.raises(APIConnectionError):
client.execute_service(service, {})
for method in (
client.button_command,
client.climate_command,
client.cover_command,
client.fan_command,
client.light_command,
client.media_player_command,
client.siren_command,
):
with pytest.raises(APIConnectionError):
await method(1)

with pytest.raises(APIConnectionError):
await client.alarm_control_panel_command(1, AlarmControlPanelCommand.ARM_HOME)

with pytest.raises(APIConnectionError):
await client.date_command(1, 1, 1, 1)

with pytest.raises(APIConnectionError):
await client.lock_command(1, LockCommand.LOCK)

with pytest.raises(APIConnectionError):
await client.number_command(1, 1)

with pytest.raises(APIConnectionError):
await client.select_command(1, "1")

with pytest.raises(APIConnectionError):
await client.switch_command(1, True)

with pytest.raises(APIConnectionError):
await client.text_command(1, "1")