Skip to content

Commit 1b7bb6d

Browse files
ohmayrgcf-owl-bot[bot]
andauthoredSep 18, 2024··
feat: add support for asynchronous rest streaming (#686)
* duplicating file to base * restore original file * duplicate file to async * restore original file * duplicate test file for async * restore test file * feat: add support for asynchronous rest streaming * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * fix naming issue * fix import module name * pull auth feature branch * revert setup file * address PR comments * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * run black * address PR comments * update nox coverage * address PR comments * fix nox session name in workflow * use https for remote repo * add context manager methods * address PR comments * update auth error versions * update import error --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent e542124 commit 1b7bb6d

8 files changed

+679
-128
lines changed
 

‎.github/workflows/unittest.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-latest
1212
strategy:
1313
matrix:
14-
option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps"]
14+
option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps", "_with_auth_aio"]
1515
python:
1616
- "3.7"
1717
- "3.8"
+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Helpers for server-side streaming in REST."""
16+
17+
from collections import deque
18+
import string
19+
from typing import Deque, Union
20+
import types
21+
22+
import proto
23+
import google.protobuf.message
24+
from google.protobuf.json_format import Parse
25+
26+
27+
class BaseResponseIterator:
28+
"""Base Iterator over REST API responses. This class should not be used directly.
29+
30+
Args:
31+
response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
32+
class expected to be returned from an API.
33+
34+
Raises:
35+
ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
36+
"""
37+
38+
def __init__(
39+
self,
40+
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
41+
):
42+
self._response_message_cls = response_message_cls
43+
# Contains a list of JSON responses ready to be sent to user.
44+
self._ready_objs: Deque[str] = deque()
45+
# Current JSON response being built.
46+
self._obj = ""
47+
# Keeps track of the nesting level within a JSON object.
48+
self._level = 0
49+
# Keeps track whether HTTP response is currently sending values
50+
# inside of a string value.
51+
self._in_string = False
52+
# Whether an escape symbol "\" was encountered.
53+
self._escape_next = False
54+
55+
self._grab = types.MethodType(self._create_grab(), self)
56+
57+
def _process_chunk(self, chunk: str):
58+
if self._level == 0:
59+
if chunk[0] != "[":
60+
raise ValueError(
61+
"Can only parse array of JSON objects, instead got %s" % chunk
62+
)
63+
for char in chunk:
64+
if char == "{":
65+
if self._level == 1:
66+
# Level 1 corresponds to the outermost JSON object
67+
# (i.e. the one we care about).
68+
self._obj = ""
69+
if not self._in_string:
70+
self._level += 1
71+
self._obj += char
72+
elif char == "}":
73+
self._obj += char
74+
if not self._in_string:
75+
self._level -= 1
76+
if not self._in_string and self._level == 1:
77+
self._ready_objs.append(self._obj)
78+
elif char == '"':
79+
# Helps to deal with an escaped quotes inside of a string.
80+
if not self._escape_next:
81+
self._in_string = not self._in_string
82+
self._obj += char
83+
elif char in string.whitespace:
84+
if self._in_string:
85+
self._obj += char
86+
elif char == "[":
87+
if self._level == 0:
88+
self._level += 1
89+
else:
90+
self._obj += char
91+
elif char == "]":
92+
if self._level == 1:
93+
self._level -= 1
94+
else:
95+
self._obj += char
96+
else:
97+
self._obj += char
98+
self._escape_next = not self._escape_next if char == "\\" else False
99+
100+
def _create_grab(self):
101+
if issubclass(self._response_message_cls, proto.Message):
102+
103+
def grab(this):
104+
return this._response_message_cls.from_json(
105+
this._ready_objs.popleft(), ignore_unknown_fields=True
106+
)
107+
108+
return grab
109+
elif issubclass(self._response_message_cls, google.protobuf.message.Message):
110+
111+
def grab(this):
112+
return Parse(this._ready_objs.popleft(), this._response_message_cls())
113+
114+
return grab
115+
else:
116+
raise ValueError(
117+
"Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
118+
)

‎google/api_core/rest_streaming.py

+8-74
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,15 @@
1414

1515
"""Helpers for server-side streaming in REST."""
1616

17-
from collections import deque
18-
import string
19-
from typing import Deque, Union
17+
from typing import Union
2018

2119
import proto
2220
import requests
2321
import google.protobuf.message
24-
from google.protobuf.json_format import Parse
22+
from google.api_core._rest_streaming_base import BaseResponseIterator
2523

2624

27-
class ResponseIterator:
25+
class ResponseIterator(BaseResponseIterator):
2826
"""Iterator over REST API responses.
2927
3028
Args:
@@ -33,7 +31,8 @@ class ResponseIterator:
3331
class expected to be returned from an API.
3432
3533
Raises:
36-
ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
34+
ValueError:
35+
- If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
3736
"""
3837

3938
def __init__(
@@ -42,68 +41,16 @@ def __init__(
4241
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
4342
):
4443
self._response = response
45-
self._response_message_cls = response_message_cls
4644
# Inner iterator over HTTP response's content.
4745
self._response_itr = self._response.iter_content(decode_unicode=True)
48-
# Contains a list of JSON responses ready to be sent to user.
49-
self._ready_objs: Deque[str] = deque()
50-
# Current JSON response being built.
51-
self._obj = ""
52-
# Keeps track of the nesting level within a JSON object.
53-
self._level = 0
54-
# Keeps track whether HTTP response is currently sending values
55-
# inside of a string value.
56-
self._in_string = False
57-
# Whether an escape symbol "\" was encountered.
58-
self._escape_next = False
46+
super(ResponseIterator, self).__init__(
47+
response_message_cls=response_message_cls
48+
)
5949

6050
def cancel(self):
6151
"""Cancel existing streaming operation."""
6252
self._response.close()
6353

64-
def _process_chunk(self, chunk: str):
65-
if self._level == 0:
66-
if chunk[0] != "[":
67-
raise ValueError(
68-
"Can only parse array of JSON objects, instead got %s" % chunk
69-
)
70-
for char in chunk:
71-
if char == "{":
72-
if self._level == 1:
73-
# Level 1 corresponds to the outermost JSON object
74-
# (i.e. the one we care about).
75-
self._obj = ""
76-
if not self._in_string:
77-
self._level += 1
78-
self._obj += char
79-
elif char == "}":
80-
self._obj += char
81-
if not self._in_string:
82-
self._level -= 1
83-
if not self._in_string and self._level == 1:
84-
self._ready_objs.append(self._obj)
85-
elif char == '"':
86-
# Helps to deal with an escaped quotes inside of a string.
87-
if not self._escape_next:
88-
self._in_string = not self._in_string
89-
self._obj += char
90-
elif char in string.whitespace:
91-
if self._in_string:
92-
self._obj += char
93-
elif char == "[":
94-
if self._level == 0:
95-
self._level += 1
96-
else:
97-
self._obj += char
98-
elif char == "]":
99-
if self._level == 1:
100-
self._level -= 1
101-
else:
102-
self._obj += char
103-
else:
104-
self._obj += char
105-
self._escape_next = not self._escape_next if char == "\\" else False
106-
10754
def __next__(self):
10855
while not self._ready_objs:
10956
try:
@@ -115,18 +62,5 @@ def __next__(self):
11562
raise e
11663
return self._grab()
11764

118-
def _grab(self):
119-
# Add extra quotes to make json.loads happy.
120-
if issubclass(self._response_message_cls, proto.Message):
121-
return self._response_message_cls.from_json(
122-
self._ready_objs.popleft(), ignore_unknown_fields=True
123-
)
124-
elif issubclass(self._response_message_cls, google.protobuf.message.Message):
125-
return Parse(self._ready_objs.popleft(), self._response_message_cls())
126-
else:
127-
raise ValueError(
128-
"Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
129-
)
130-
13165
def __iter__(self):
13266
return self
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Helpers for asynchronous server-side streaming in REST."""
16+
17+
from typing import Union
18+
19+
import proto
20+
21+
try:
22+
import google.auth.aio.transport
23+
except ImportError as e: # pragma: NO COVER
24+
raise ImportError(
25+
"google-auth>=2.35.0 is required to use asynchronous rest streaming."
26+
) from e
27+
28+
import google.protobuf.message
29+
from google.api_core._rest_streaming_base import BaseResponseIterator
30+
31+
32+
class AsyncResponseIterator(BaseResponseIterator):
33+
"""Asynchronous Iterator over REST API responses.
34+
35+
Args:
36+
response (google.auth.aio.transport.Response): An API response object.
37+
response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
38+
class expected to be returned from an API.
39+
40+
Raises:
41+
ValueError:
42+
- If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
43+
"""
44+
45+
def __init__(
46+
self,
47+
response: google.auth.aio.transport.Response,
48+
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
49+
):
50+
self._response = response
51+
self._chunk_size = 1024
52+
self._response_itr = self._response.content().__aiter__()
53+
super(AsyncResponseIterator, self).__init__(
54+
response_message_cls=response_message_cls
55+
)
56+
57+
async def __aenter__(self):
58+
return self
59+
60+
async def cancel(self):
61+
"""Cancel existing streaming operation."""
62+
await self._response.close()
63+
64+
async def __anext__(self):
65+
while not self._ready_objs:
66+
try:
67+
chunk = await self._response_itr.__anext__()
68+
chunk = chunk.decode("utf-8")
69+
self._process_chunk(chunk)
70+
except StopAsyncIteration as e:
71+
if self._level > 0:
72+
raise ValueError("i Unfinished stream: %s" % self._obj)
73+
raise e
74+
except ValueError as e:
75+
raise e
76+
return self._grab()
77+
78+
def __aiter__(self):
79+
return self
80+
81+
async def __aexit__(self, exc_type, exc, tb):
82+
"""Cancel existing async streaming operation."""
83+
await self._response.close()

‎noxfile.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"unit",
3939
"unit_grpc_gcp",
4040
"unit_wo_grpc",
41+
"unit_with_auth_aio",
4142
"cover",
4243
"pytype",
4344
"mypy",
@@ -109,7 +110,7 @@ def install_prerelease_dependencies(session, constraints_path):
109110
session.install(*other_deps)
110111

111112

112-
def default(session, install_grpc=True, prerelease=False):
113+
def default(session, install_grpc=True, prerelease=False, install_auth_aio=False):
113114
"""Default unit test session.
114115
115116
This is intended to be run **without** an interpreter set, so
@@ -144,6 +145,11 @@ def default(session, install_grpc=True, prerelease=False):
144145
f"{constraints_dir}/constraints-{session.python}.txt",
145146
)
146147

148+
if install_auth_aio:
149+
session.install(
150+
"google-auth @ git+https://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad"
151+
)
152+
147153
# Print out package versions of dependencies
148154
session.run(
149155
"python", "-c", "import google.protobuf; print(google.protobuf.__version__)"
@@ -229,6 +235,12 @@ def unit_wo_grpc(session):
229235
default(session, install_grpc=False)
230236

231237

238+
@nox.session(python=PYTHON_VERSIONS)
239+
def unit_with_auth_aio(session):
240+
"""Run the unit test suite with google.auth.aio installed"""
241+
default(session, install_auth_aio=True)
242+
243+
232244
@nox.session(python=DEFAULT_PYTHON_VERSION)
233245
def lint_setup_py(session):
234246
"""Verify that setup.py is valid (including RST check)."""
+378
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,378 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# TODO: set random.seed explicitly in each test function.
16+
# See related issue: https://github.com/googleapis/python-api-core/issues/689.
17+
18+
import pytest # noqa: I202
19+
import mock
20+
21+
import datetime
22+
import logging
23+
import random
24+
import time
25+
from typing import List, AsyncIterator
26+
27+
import proto
28+
29+
try:
30+
from google.auth.aio.transport import Response
31+
32+
AUTH_AIO_INSTALLED = True
33+
except ImportError:
34+
AUTH_AIO_INSTALLED = False
35+
36+
if not AUTH_AIO_INSTALLED: # pragma: NO COVER
37+
pytest.skip(
38+
"google-auth>=2.35.0 is required to use asynchronous rest streaming.",
39+
allow_module_level=True,
40+
)
41+
42+
from google.api_core import rest_streaming_async
43+
from google.api import http_pb2
44+
from google.api import httpbody_pb2
45+
46+
47+
from ..helpers import Composer, Song, EchoResponse, parse_responses
48+
49+
50+
__protobuf__ = proto.module(package=__name__)
51+
SEED = int(time.time())
52+
logging.info(f"Starting async rest streaming tests with random seed: {SEED}")
53+
random.seed(SEED)
54+
55+
56+
async def mock_async_gen(data, chunk_size=1):
57+
for i in range(0, len(data)): # pragma: NO COVER
58+
chunk = data[i : i + chunk_size]
59+
yield chunk.encode("utf-8")
60+
61+
62+
class ResponseMock(Response):
63+
class _ResponseItr(AsyncIterator[bytes]):
64+
def __init__(self, _response_bytes: bytes, random_split=False):
65+
self._responses_bytes = _response_bytes
66+
self._idx = 0
67+
self._random_split = random_split
68+
69+
def __aiter__(self):
70+
return self
71+
72+
async def __anext__(self):
73+
if self._idx >= len(self._responses_bytes):
74+
raise StopAsyncIteration
75+
if self._random_split:
76+
n = random.randint(1, len(self._responses_bytes[self._idx :]))
77+
else:
78+
n = 1
79+
x = self._responses_bytes[self._idx : self._idx + n]
80+
self._idx += n
81+
return x
82+
83+
def __init__(
84+
self,
85+
responses: List[proto.Message],
86+
response_cls,
87+
random_split=False,
88+
):
89+
self._responses = responses
90+
self._random_split = random_split
91+
self._response_message_cls = response_cls
92+
93+
def _parse_responses(self):
94+
return parse_responses(self._response_message_cls, self._responses)
95+
96+
@property
97+
async def headers(self):
98+
raise NotImplementedError()
99+
100+
@property
101+
async def status_code(self):
102+
raise NotImplementedError()
103+
104+
async def close(self):
105+
raise NotImplementedError()
106+
107+
async def content(self, chunk_size=None):
108+
itr = self._ResponseItr(
109+
self._parse_responses(), random_split=self._random_split
110+
)
111+
async for chunk in itr:
112+
yield chunk
113+
114+
async def read(self):
115+
raise NotImplementedError()
116+
117+
118+
@pytest.mark.asyncio
119+
@pytest.mark.parametrize(
120+
"random_split,resp_message_is_proto_plus",
121+
[(False, True), (False, False)],
122+
)
123+
async def test_next_simple(random_split, resp_message_is_proto_plus):
124+
if resp_message_is_proto_plus:
125+
response_type = EchoResponse
126+
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
127+
else:
128+
response_type = httpbody_pb2.HttpBody
129+
responses = [
130+
httpbody_pb2.HttpBody(content_type="hello world"),
131+
httpbody_pb2.HttpBody(content_type="yes"),
132+
]
133+
134+
resp = ResponseMock(
135+
responses=responses, random_split=random_split, response_cls=response_type
136+
)
137+
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
138+
idx = 0
139+
async for response in itr:
140+
assert response == responses[idx]
141+
idx += 1
142+
143+
144+
@pytest.mark.asyncio
145+
@pytest.mark.parametrize(
146+
"random_split,resp_message_is_proto_plus",
147+
[
148+
(True, True),
149+
(False, True),
150+
(True, False),
151+
(False, False),
152+
],
153+
)
154+
async def test_next_nested(random_split, resp_message_is_proto_plus):
155+
if resp_message_is_proto_plus:
156+
response_type = Song
157+
responses = [
158+
Song(title="some song", composer=Composer(given_name="some name")),
159+
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
160+
]
161+
else:
162+
# Although `http_pb2.HttpRule`` is used in the response, any response message
163+
# can be used which meets this criteria for the test of having a nested field.
164+
response_type = http_pb2.HttpRule
165+
responses = [
166+
http_pb2.HttpRule(
167+
selector="some selector",
168+
custom=http_pb2.CustomHttpPattern(kind="some kind"),
169+
),
170+
http_pb2.HttpRule(
171+
selector="another selector",
172+
custom=http_pb2.CustomHttpPattern(path="some path"),
173+
),
174+
]
175+
resp = ResponseMock(
176+
responses=responses, random_split=random_split, response_cls=response_type
177+
)
178+
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
179+
idx = 0
180+
async for response in itr:
181+
assert response == responses[idx]
182+
idx += 1
183+
assert idx == len(responses)
184+
185+
186+
@pytest.mark.asyncio
187+
@pytest.mark.parametrize(
188+
"random_split,resp_message_is_proto_plus",
189+
[
190+
(True, True),
191+
(False, True),
192+
(True, False),
193+
(False, False),
194+
],
195+
)
196+
async def test_next_stress(random_split, resp_message_is_proto_plus):
197+
n = 50
198+
if resp_message_is_proto_plus:
199+
response_type = Song
200+
responses = [
201+
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
202+
for i in range(n)
203+
]
204+
else:
205+
response_type = http_pb2.HttpRule
206+
responses = [
207+
http_pb2.HttpRule(
208+
selector="selector_%d" % i,
209+
custom=http_pb2.CustomHttpPattern(path="path_%d" % i),
210+
)
211+
for i in range(n)
212+
]
213+
resp = ResponseMock(
214+
responses=responses, random_split=random_split, response_cls=response_type
215+
)
216+
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
217+
idx = 0
218+
async for response in itr:
219+
assert response == responses[idx]
220+
idx += 1
221+
assert idx == n
222+
223+
224+
@pytest.mark.asyncio
225+
@pytest.mark.parametrize(
226+
"random_split,resp_message_is_proto_plus",
227+
[
228+
(True, True),
229+
(False, True),
230+
(True, False),
231+
(False, False),
232+
],
233+
)
234+
async def test_next_escaped_characters_in_string(
235+
random_split, resp_message_is_proto_plus
236+
):
237+
if resp_message_is_proto_plus:
238+
response_type = Song
239+
composer_with_relateds = Composer()
240+
relateds = ["Artist A", "Artist B"]
241+
composer_with_relateds.relateds = relateds
242+
243+
responses = [
244+
Song(
245+
title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
246+
),
247+
Song(
248+
title='{"this is weird": "totally"}',
249+
composer=Composer(given_name="\\{}\\"),
250+
),
251+
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
252+
]
253+
else:
254+
response_type = http_pb2.Http
255+
responses = [
256+
http_pb2.Http(
257+
rules=[
258+
http_pb2.HttpRule(
259+
selector='ti"tle\nfoo\tbar{}',
260+
custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"),
261+
)
262+
]
263+
),
264+
http_pb2.Http(
265+
rules=[
266+
http_pb2.HttpRule(
267+
selector='{"this is weird": "totally"}',
268+
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
269+
)
270+
]
271+
),
272+
http_pb2.Http(
273+
rules=[
274+
http_pb2.HttpRule(
275+
selector='\\{"key": ["value",]}\\',
276+
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
277+
)
278+
]
279+
),
280+
]
281+
resp = ResponseMock(
282+
responses=responses, random_split=random_split, response_cls=response_type
283+
)
284+
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
285+
idx = 0
286+
async for response in itr:
287+
assert response == responses[idx]
288+
idx += 1
289+
assert idx == len(responses)
290+
291+
292+
@pytest.mark.asyncio
293+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
294+
async def test_next_not_array(response_type):
295+
296+
data = '{"hello": 0}'
297+
with mock.patch.object(
298+
ResponseMock, "content", return_value=mock_async_gen(data)
299+
) as mock_method:
300+
resp = ResponseMock(responses=[], response_cls=response_type)
301+
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
302+
with pytest.raises(ValueError):
303+
await itr.__anext__()
304+
mock_method.assert_called_once()
305+
306+
307+
@pytest.mark.asyncio
308+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
309+
async def test_cancel(response_type):
310+
with mock.patch.object(
311+
ResponseMock, "close", new_callable=mock.AsyncMock
312+
) as mock_method:
313+
resp = ResponseMock(responses=[], response_cls=response_type)
314+
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
315+
await itr.cancel()
316+
mock_method.assert_called_once()
317+
318+
319+
@pytest.mark.asyncio
320+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
321+
async def test_iterator_as_context_manager(response_type):
322+
with mock.patch.object(
323+
ResponseMock, "close", new_callable=mock.AsyncMock
324+
) as mock_method:
325+
resp = ResponseMock(responses=[], response_cls=response_type)
326+
async with rest_streaming_async.AsyncResponseIterator(resp, response_type):
327+
pass
328+
mock_method.assert_called_once()
329+
330+
331+
@pytest.mark.asyncio
332+
@pytest.mark.parametrize(
333+
"response_type,return_value",
334+
[
335+
(EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
336+
(httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
337+
],
338+
)
339+
async def test_check_buffer(response_type, return_value):
340+
with mock.patch.object(
341+
ResponseMock,
342+
"_parse_responses",
343+
return_value=return_value,
344+
):
345+
resp = ResponseMock(responses=[], response_cls=response_type)
346+
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
347+
with pytest.raises(ValueError):
348+
await itr.__anext__()
349+
await itr.__anext__()
350+
351+
352+
@pytest.mark.asyncio
353+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
354+
async def test_next_html(response_type):
355+
356+
data = "<!DOCTYPE html><html></html>"
357+
with mock.patch.object(
358+
ResponseMock, "content", return_value=mock_async_gen(data)
359+
) as mock_method:
360+
resp = ResponseMock(responses=[], response_cls=response_type)
361+
362+
itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
363+
with pytest.raises(ValueError):
364+
await itr.__anext__()
365+
mock_method.assert_called_once()
366+
367+
368+
@pytest.mark.asyncio
369+
async def test_invalid_response_class():
370+
class SomeClass:
371+
pass
372+
373+
resp = ResponseMock(responses=[], response_cls=SomeClass)
374+
with pytest.raises(
375+
ValueError,
376+
match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
377+
):
378+
rest_streaming_async.AsyncResponseIterator(resp, SomeClass)

‎tests/helpers.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Helpers for tests"""
16+
17+
import logging
18+
from typing import List
19+
20+
import proto
21+
22+
from google.protobuf import duration_pb2
23+
from google.protobuf import timestamp_pb2
24+
from google.protobuf.json_format import MessageToJson
25+
26+
27+
class Genre(proto.Enum):
28+
GENRE_UNSPECIFIED = 0
29+
CLASSICAL = 1
30+
JAZZ = 2
31+
ROCK = 3
32+
33+
34+
class Composer(proto.Message):
35+
given_name = proto.Field(proto.STRING, number=1)
36+
family_name = proto.Field(proto.STRING, number=2)
37+
relateds = proto.RepeatedField(proto.STRING, number=3)
38+
indices = proto.MapField(proto.STRING, proto.STRING, number=4)
39+
40+
41+
class Song(proto.Message):
42+
composer = proto.Field(Composer, number=1)
43+
title = proto.Field(proto.STRING, number=2)
44+
lyrics = proto.Field(proto.STRING, number=3)
45+
year = proto.Field(proto.INT32, number=4)
46+
genre = proto.Field(Genre, number=5)
47+
is_five_mins_longer = proto.Field(proto.BOOL, number=6)
48+
score = proto.Field(proto.DOUBLE, number=7)
49+
likes = proto.Field(proto.INT64, number=8)
50+
duration = proto.Field(duration_pb2.Duration, number=9)
51+
date_added = proto.Field(timestamp_pb2.Timestamp, number=10)
52+
53+
54+
class EchoResponse(proto.Message):
55+
content = proto.Field(proto.STRING, number=1)
56+
57+
58+
def parse_responses(response_message_cls, all_responses: List[proto.Message]) -> bytes:
59+
# json.dumps returns a string surrounded with quotes that need to be stripped
60+
# in order to be an actual JSON.
61+
json_responses = [
62+
(
63+
response_message_cls.to_json(response).strip('"')
64+
if issubclass(response_message_cls, proto.Message)
65+
else MessageToJson(response).strip('"')
66+
)
67+
for response in all_responses
68+
]
69+
logging.info(f"Sending JSON stream: {json_responses}")
70+
ret_val = "[{}]".format(",".join(json_responses))
71+
return bytes(ret_val, "utf-8")

‎tests/unit/test_rest_streaming.py

+7-52
Original file line numberDiff line numberDiff line change
@@ -26,48 +26,16 @@
2626
from google.api_core import rest_streaming
2727
from google.api import http_pb2
2828
from google.api import httpbody_pb2
29-
from google.protobuf import duration_pb2
30-
from google.protobuf import timestamp_pb2
31-
from google.protobuf.json_format import MessageToJson
29+
30+
from ..helpers import Composer, Song, EchoResponse, parse_responses
3231

3332

3433
__protobuf__ = proto.module(package=__name__)
3534
SEED = int(time.time())
36-
logging.info(f"Starting rest streaming tests with random seed: {SEED}")
35+
logging.info(f"Starting sync rest streaming tests with random seed: {SEED}")
3736
random.seed(SEED)
3837

3938

40-
class Genre(proto.Enum):
41-
GENRE_UNSPECIFIED = 0
42-
CLASSICAL = 1
43-
JAZZ = 2
44-
ROCK = 3
45-
46-
47-
class Composer(proto.Message):
48-
given_name = proto.Field(proto.STRING, number=1)
49-
family_name = proto.Field(proto.STRING, number=2)
50-
relateds = proto.RepeatedField(proto.STRING, number=3)
51-
indices = proto.MapField(proto.STRING, proto.STRING, number=4)
52-
53-
54-
class Song(proto.Message):
55-
composer = proto.Field(Composer, number=1)
56-
title = proto.Field(proto.STRING, number=2)
57-
lyrics = proto.Field(proto.STRING, number=3)
58-
year = proto.Field(proto.INT32, number=4)
59-
genre = proto.Field(Genre, number=5)
60-
is_five_mins_longer = proto.Field(proto.BOOL, number=6)
61-
score = proto.Field(proto.DOUBLE, number=7)
62-
likes = proto.Field(proto.INT64, number=8)
63-
duration = proto.Field(duration_pb2.Duration, number=9)
64-
date_added = proto.Field(timestamp_pb2.Timestamp, number=10)
65-
66-
67-
class EchoResponse(proto.Message):
68-
content = proto.Field(proto.STRING, number=1)
69-
70-
7139
class ResponseMock(requests.Response):
7240
class _ResponseItr:
7341
def __init__(self, _response_bytes: bytes, random_split=False):
@@ -97,27 +65,15 @@ def __init__(
9765
self._random_split = random_split
9866
self._response_message_cls = response_cls
9967

100-
def _parse_responses(self, responses: List[proto.Message]) -> bytes:
101-
# json.dumps returns a string surrounded with quotes that need to be stripped
102-
# in order to be an actual JSON.
103-
json_responses = [
104-
(
105-
self._response_message_cls.to_json(r).strip('"')
106-
if issubclass(self._response_message_cls, proto.Message)
107-
else MessageToJson(r).strip('"')
108-
)
109-
for r in responses
110-
]
111-
logging.info(f"Sending JSON stream: {json_responses}")
112-
ret_val = "[{}]".format(",".join(json_responses))
113-
return bytes(ret_val, "utf-8")
68+
def _parse_responses(self):
69+
return parse_responses(self._response_message_cls, self._responses)
11470

11571
def close(self):
11672
raise NotImplementedError()
11773

11874
def iter_content(self, *args, **kwargs):
11975
return self._ResponseItr(
120-
self._parse_responses(self._responses),
76+
self._parse_responses(),
12177
random_split=self._random_split,
12278
)
12379

@@ -333,9 +289,8 @@ class SomeClass:
333289
pass
334290

335291
resp = ResponseMock(responses=[], response_cls=SomeClass)
336-
response_iterator = rest_streaming.ResponseIterator(resp, SomeClass)
337292
with pytest.raises(
338293
ValueError,
339294
match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
340295
):
341-
response_iterator._grab()
296+
rest_streaming.ResponseIterator(resp, SomeClass)

0 commit comments

Comments
 (0)
Please sign in to comment.