Skip to content

Commit bcebc92

Browse files
partheagcf-owl-bot[bot]
andauthoredFeb 13, 2024
fix: resolve issue handling protobuf responses in rest streaming (#604)
* fix: resolve issue handling protobuf responses in rest streaming * raise ValueError if response_message_cls is not a subclass of proto.Message or google.protobuf.message.Message * remove response_type from pytest.mark.parametrize * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * add test for ValueError in response_iterator._grab() --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 82c3118 commit bcebc92

File tree

2 files changed

+196
-56
lines changed

2 files changed

+196
-56
lines changed
 

‎google/api_core/rest_streaming.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,31 @@
1616

1717
from collections import deque
1818
import string
19-
from typing import Deque
19+
from typing import Deque, Union
2020

21+
import proto
Has conversations. Original line has conversations.
2122
import requests
23+
import google.protobuf.message
24+
from google.protobuf.json_format import Parse
2225

2326

2427
class ResponseIterator:
2528
"""Iterator over REST API responses.
2629
2730
Args:
2831
response (requests.Response): An API response object.
29-
response_message_cls (Callable[proto.Message]): A proto
32+
response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
3033
class expected to be returned from an API.
34+
35+
Raises:
36+
ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
3137
"""
3238

33-
def __init__(self, response: requests.Response, response_message_cls):
39+
def __init__(
40+
self,
41+
response: requests.Response,
42+
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
43+
):
3444
self._response = response
3545
self._response_message_cls = response_message_cls
3646
# Inner iterator over HTTP response's content.
@@ -107,7 +117,14 @@ def __next__(self):
107117

108118
def _grab(self):
109119
# Add extra quotes to make json.loads happy.
110-
return self._response_message_cls.from_json(self._ready_objs.popleft())
120+
if issubclass(self._response_message_cls, proto.Message):
121+
return self._response_message_cls.from_json(self._ready_objs.popleft())
122+
elif issubclass(self._response_message_cls, google.protobuf.message.Message):
123+
return Parse(self._ready_objs.popleft(), self._response_message_cls())
124+
else:
125+
raise ValueError(
126+
"Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
127+
)
111128

112129
def __iter__(self):
113130
return self

‎tests/unit/test_rest_streaming.py

+175-52
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@
2424
import requests
2525

2626
from google.api_core import rest_streaming
27+
from google.api import http_pb2
28+
from google.api import httpbody_pb2
2729
from google.protobuf import duration_pb2
2830
from google.protobuf import timestamp_pb2
31+
from google.protobuf.json_format import MessageToJson
2932

3033

3134
__protobuf__ = proto.module(package=__name__)
@@ -98,7 +101,10 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes:
98101
# json.dumps returns a string surrounded with quotes that need to be stripped
99102
# in order to be an actual JSON.
100103
json_responses = [
101-
self._response_message_cls.to_json(r).strip('"') for r in responses
104+
self._response_message_cls.to_json(r).strip('"')
105+
if issubclass(self._response_message_cls, proto.Message)
106+
else MessageToJson(r).strip('"')
107+
for r in responses
102108
]
103109
logging.info(f"Sending JSON stream: {json_responses}")
104110
ret_val = "[{}]".format(",".join(json_responses))
@@ -114,103 +120,220 @@ def iter_content(self, *args, **kwargs):
114120
)
115121

116122

117-
@pytest.mark.parametrize("random_split", [False])
118-
def test_next_simple(random_split):
119-
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
123+
@pytest.mark.parametrize(
124+
"random_split,resp_message_is_proto_plus",
125+
[(False, True), (False, False)],
126+
)
127+
def test_next_simple(random_split, resp_message_is_proto_plus):
128+
if resp_message_is_proto_plus:
129+
response_type = EchoResponse
130+
responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
131+
else:
132+
response_type = httpbody_pb2.HttpBody
133+
responses = [
134+
httpbody_pb2.HttpBody(content_type="hello world"),
135+
httpbody_pb2.HttpBody(content_type="yes"),
136+
]
137+
120138
resp = ResponseMock(
121-
responses=responses, random_split=random_split, response_cls=EchoResponse
139+
responses=responses, random_split=random_split, response_cls=response_type
122140
)
123-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
141+
itr = rest_streaming.ResponseIterator(resp, response_type)
124142
assert list(itr) == responses
125143

126144

127-
@pytest.mark.parametrize("random_split", [True, False])
128-
def test_next_nested(random_split):
129-
responses = [
130-
Song(title="some song", composer=Composer(given_name="some name")),
131-
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
132-
]
145+
@pytest.mark.parametrize(
146+
"random_split,resp_message_is_proto_plus",
147+
[
148+
(True, True),
149+
(False, True),
150+
(True, False),
151+
(False, False),
152+
],
153+
)
154+
def test_next_nested(random_split, resp_message_is_proto_plus):
155+
if resp_message_is_proto_plus:
156+
response_type = Song
157+
responses = [
158+
Song(title="some song", composer=Composer(given_name="some name")),
159+
Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
160+
]
161+
else:
162+
# Although `http_pb2.HttpRule`` is used in the response, any response message
163+
# can be used which meets this criteria for the test of having a nested field.
164+
response_type = http_pb2.HttpRule
165+
responses = [
166+
http_pb2.HttpRule(
167+
selector="some selector",
168+
custom=http_pb2.CustomHttpPattern(kind="some kind"),
169+
),
170+
http_pb2.HttpRule(
171+
selector="another selector",
172+
custom=http_pb2.CustomHttpPattern(path="some path"),
173+
),
174+
]
133175
resp = ResponseMock(
134-
responses=responses, random_split=random_split, response_cls=Song
176+
responses=responses, random_split=random_split, response_cls=response_type
135177
)
136-
itr = rest_streaming.ResponseIterator(resp, Song)
178+
itr = rest_streaming.ResponseIterator(resp, response_type)
137179
assert list(itr) == responses
138180

139181

140-
@pytest.mark.parametrize("random_split", [True, False])
141-
def test_next_stress(random_split):
182+
@pytest.mark.parametrize(
183+
"random_split,resp_message_is_proto_plus",
184+
[
185+
(True, True),
186+
(False, True),
187+
(True, False),
188+
(False, False),
189+
],
190+
)
191+
def test_next_stress(random_split, resp_message_is_proto_plus):
142192
n = 50
143-
responses = [
144-
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
145-
for i in range(n)
146-
]
193+
if resp_message_is_proto_plus:
194+
response_type = Song
195+
responses = [
196+
Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
197+
for i in range(n)
198+
]
199+
else:
200+
response_type = http_pb2.HttpRule
201+
responses = [
202+
http_pb2.HttpRule(
203+
selector="selector_%d" % i,
204+
custom=http_pb2.CustomHttpPattern(path="path_%d" % i),
205+
)
206+
for i in range(n)
207+
]
147208
resp = ResponseMock(
148-
responses=responses, random_split=random_split, response_cls=Song
209+
responses=responses, random_split=random_split, response_cls=response_type
149210
)
150-
itr = rest_streaming.ResponseIterator(resp, Song)
211+
itr = rest_streaming.ResponseIterator(resp, response_type)
151212
assert list(itr) == responses
152213

153214

154-
@pytest.mark.parametrize("random_split", [True, False])
155-
def test_next_escaped_characters_in_string(random_split):
156-
composer_with_relateds = Composer()
157-
relateds = ["Artist A", "Artist B"]
158-
composer_with_relateds.relateds = relateds
159-
160-
responses = [
161-
Song(title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")),
162-
Song(
163-
title='{"this is weird": "totally"}', composer=Composer(given_name="\\{}\\")
164-
),
165-
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
166-
]
215+
@pytest.mark.parametrize(
216+
"random_split,resp_message_is_proto_plus",
217+
[
218+
(True, True),
219+
(False, True),
220+
(True, False),
221+
(False, False),
222+
],
223+
)
224+
def test_next_escaped_characters_in_string(random_split, resp_message_is_proto_plus):
225+
if resp_message_is_proto_plus:
226+
response_type = Song
227+
composer_with_relateds = Composer()
228+
relateds = ["Artist A", "Artist B"]
229+
composer_with_relateds.relateds = relateds
230+
231+
responses = [
232+
Song(
233+
title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
234+
),
235+
Song(
236+
title='{"this is weird": "totally"}',
237+
composer=Composer(given_name="\\{}\\"),
238+
),
239+
Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
240+
]
241+
else:
242+
response_type = http_pb2.Http
243+
responses = [
244+
http_pb2.Http(
245+
rules=[
246+
http_pb2.HttpRule(
247+
selector='ti"tle\nfoo\tbar{}',
248+
custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"),
249+
)
250+
]
251+
),
252+
http_pb2.Http(
253+
rules=[
254+
http_pb2.HttpRule(
255+
selector='{"this is weird": "totally"}',
256+
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
257+
)
258+
]
259+
),
260+
http_pb2.Http(
261+
rules=[
262+
http_pb2.HttpRule(
263+
selector='\\{"key": ["value",]}\\',
264+
custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
265+
)
266+
]
267+
),
268+
]
167269
resp = ResponseMock(
168-
responses=responses, random_split=random_split, response_cls=Song
270+
responses=responses, random_split=random_split, response_cls=response_type
169271
)
170-
itr = rest_streaming.ResponseIterator(resp, Song)
272+
itr = rest_streaming.ResponseIterator(resp, response_type)
171273
assert list(itr) == responses
172274

173275

174-
def test_next_not_array():
276+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
277+
def test_next_not_array(response_type):
175278
with patch.object(
176279
ResponseMock, "iter_content", return_value=iter('{"hello": 0}')
177280
) as mock_method:
178-
179-
resp = ResponseMock(responses=[], response_cls=EchoResponse)
180-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
281+
resp = ResponseMock(responses=[], response_cls=response_type)
282+
itr = rest_streaming.ResponseIterator(resp, response_type)
181283
with pytest.raises(ValueError):
182284
next(itr)
183285
mock_method.assert_called_once()
184286

185287

186-
def test_cancel():
288+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
289+
def test_cancel(response_type):
187290
with patch.object(ResponseMock, "close", return_value=None) as mock_method:
188-
resp = ResponseMock(responses=[], response_cls=EchoResponse)
189-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
291+
resp = ResponseMock(responses=[], response_cls=response_type)
292+
itr = rest_streaming.ResponseIterator(resp, response_type)
190293
itr.cancel()
191294
mock_method.assert_called_once()
192295

193296

194-
def test_check_buffer():
297+
@pytest.mark.parametrize(
298+
"response_type,return_value",
299+
[
300+
(EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
301+
(httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
302+
],
303+
)
304+
def test_check_buffer(response_type, return_value):
195305
with patch.object(
196306
ResponseMock,
197307
"_parse_responses",
198-
return_value=bytes('[{"content": "hello"}, {', "utf-8"),
308+
return_value=return_value,
199309
):
200-
resp = ResponseMock(responses=[], response_cls=EchoResponse)
201-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
310+
resp = ResponseMock(responses=[], response_cls=response_type)
311+
itr = rest_streaming.ResponseIterator(resp, response_type)
202312
with pytest.raises(ValueError):
203313
next(itr)
204314
next(itr)
205315

206316

207-
def test_next_html():
317+
@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
318+
def test_next_html(response_type):
208319
with patch.object(
209320
ResponseMock, "iter_content", return_value=iter("<!DOCTYPE html><html></html>")
210321
) as mock_method:
211-
212-
resp = ResponseMock(responses=[], response_cls=EchoResponse)
213-
itr = rest_streaming.ResponseIterator(resp, EchoResponse)
322+
resp = ResponseMock(responses=[], response_cls=response_type)
323+
itr = rest_streaming.ResponseIterator(resp, response_type)
214324
with pytest.raises(ValueError):
215325
next(itr)
216326
mock_method.assert_called_once()
327+
328+
329+
def test_invalid_response_class():
330+
class SomeClass:
331+
pass
332+
333+
resp = ResponseMock(responses=[], response_cls=SomeClass)
334+
response_iterator = rest_streaming.ResponseIterator(resp, SomeClass)
335+
with pytest.raises(
336+
ValueError,
337+
match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
338+
):
339+
response_iterator._grab()

0 commit comments

Comments
 (0)
Please sign in to comment.