Skip to content

Commit

Permalink
Ensure all command and service calls raise when disconnected (#840)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Mar 10, 2024
1 parent a300909 commit eabc3d4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
9 changes: 6 additions & 3 deletions aioesphomeapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ def cover_command(
tilt: float | None = None,
stop: bool = False,
) -> None:
connection = self._get_connection()
req = CoverCommandRequest(key=key)
apiv = self.api_version
if TYPE_CHECKING:
Expand All @@ -951,7 +952,7 @@ def cover_command(
elif position == 0.0:
req.legacy_command = LegacyCoverCommand.CLOSE
req.has_legacy_command = True
self._get_connection().send_message(req)
connection.send_message(req)

def fan_command(
self,
Expand Down Expand Up @@ -1058,6 +1059,7 @@ def climate_command( # pylint: disable=too-many-branches
custom_preset: str | None = None,
target_humidity: float | None = None,
) -> None:
connection = self._get_connection()
req = ClimateCommandRequest(key=key)
if mode is not None:
req.has_mode = True
Expand Down Expand Up @@ -1096,7 +1098,7 @@ def climate_command( # pylint: disable=too-many-branches
if target_humidity is not None:
req.has_target_humidity = True
req.target_humidity = target_humidity
self._get_connection().send_message(req)
connection.send_message(req)

def number_command(self, key: int, state: float) -> None:
self._get_connection().send_message(NumberCommandRequest(key=key, state=state))
Expand Down Expand Up @@ -1172,6 +1174,7 @@ def text_command(self, key: int, state: str) -> None:
def execute_service(
self, service: UserService, data: ExecuteServiceDataType
) -> None:
connection = self._get_connection()
req = ExecuteServiceRequest(key=service.key)
args = []
apiv = self.api_version
Expand All @@ -1196,7 +1199,7 @@ def execute_service(
# pylint: disable=no-member
req.args.extend(args)

self._get_connection().send_message(req)
connection.send_message(req)

def _request_image(self, *, single: bool = False, stream: bool = False) -> None:
self._get_connection().send_message(
Expand Down
52 changes: 52 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2280,3 +2280,55 @@ async def test_api_version_after_connection_closed(
assert client.api_version == APIVersion(1, 9)
await client.disconnect(force=True)
assert client.api_version is None


@pytest.mark.asyncio
async def test_calls_after_connection_closed(
api_client: tuple[
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
],
) -> None:
"""Test calls after connection close should raise APIConnectionError."""
client, connection, transport, protocol = api_client
assert client.api_version == APIVersion(1, 9)
await client.disconnect(force=True)
assert client.api_version is None
service = UserService(
name="my_service",
key=1,
args=[],
)
with pytest.raises(APIConnectionError):
client.execute_service(service, {})
for method in (
client.button_command,
client.climate_command,
client.cover_command,
client.fan_command,
client.light_command,
client.media_player_command,
client.siren_command,
):
with pytest.raises(APIConnectionError):
await method(1)

with pytest.raises(APIConnectionError):
await client.alarm_control_panel_command(1, AlarmControlPanelCommand.ARM_HOME)

with pytest.raises(APIConnectionError):
await client.date_command(1, 1, 1, 1)

with pytest.raises(APIConnectionError):
await client.lock_command(1, LockCommand.LOCK)

with pytest.raises(APIConnectionError):
await client.number_command(1, 1)

with pytest.raises(APIConnectionError):
await client.select_command(1, "1")

with pytest.raises(APIConnectionError):
await client.switch_command(1, True)

with pytest.raises(APIConnectionError):
await client.text_command(1, "1")

0 comments on commit eabc3d4

Please sign in to comment.