24
24
import requests
25
25
26
26
from google .api_core import rest_streaming
27
+ from google .api import http_pb2
28
+ from google .api import httpbody_pb2
27
29
from google .protobuf import duration_pb2
28
30
from google .protobuf import timestamp_pb2
31
+ from google .protobuf .json_format import MessageToJson
29
32
30
33
31
34
__protobuf__ = proto .module (package = __name__ )
@@ -98,7 +101,10 @@ def _parse_responses(self, responses: List[proto.Message]) -> bytes:
98
101
# json.dumps returns a string surrounded with quotes that need to be stripped
99
102
# in order to be an actual JSON.
100
103
json_responses = [
101
- self ._response_message_cls .to_json (r ).strip ('"' ) for r in responses
104
+ self ._response_message_cls .to_json (r ).strip ('"' )
105
+ if issubclass (self ._response_message_cls , proto .Message )
106
+ else MessageToJson (r ).strip ('"' )
107
+ for r in responses
102
108
]
103
109
logging .info (f"Sending JSON stream: { json_responses } " )
104
110
ret_val = "[{}]" .format ("," .join (json_responses ))
@@ -114,103 +120,220 @@ def iter_content(self, *args, **kwargs):
114
120
)
115
121
116
122
117
- @pytest .mark .parametrize ("random_split" , [False ])
118
- def test_next_simple (random_split ):
119
- responses = [EchoResponse (content = "hello world" ), EchoResponse (content = "yes" )]
123
+ @pytest .mark .parametrize (
124
+ "random_split,resp_message_is_proto_plus" ,
125
+ [(False , True ), (False , False )],
126
+ )
127
+ def test_next_simple (random_split , resp_message_is_proto_plus ):
128
+ if resp_message_is_proto_plus :
129
+ response_type = EchoResponse
130
+ responses = [EchoResponse (content = "hello world" ), EchoResponse (content = "yes" )]
131
+ else :
132
+ response_type = httpbody_pb2 .HttpBody
133
+ responses = [
134
+ httpbody_pb2 .HttpBody (content_type = "hello world" ),
135
+ httpbody_pb2 .HttpBody (content_type = "yes" ),
136
+ ]
137
+
120
138
resp = ResponseMock (
121
- responses = responses , random_split = random_split , response_cls = EchoResponse
139
+ responses = responses , random_split = random_split , response_cls = response_type
122
140
)
123
- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
141
+ itr = rest_streaming .ResponseIterator (resp , response_type )
124
142
assert list (itr ) == responses
125
143
126
144
127
- @pytest .mark .parametrize ("random_split" , [True , False ])
128
- def test_next_nested (random_split ):
129
- responses = [
130
- Song (title = "some song" , composer = Composer (given_name = "some name" )),
131
- Song (title = "another song" , date_added = datetime .datetime (2021 , 12 , 17 )),
132
- ]
145
+ @pytest .mark .parametrize (
146
+ "random_split,resp_message_is_proto_plus" ,
147
+ [
148
+ (True , True ),
149
+ (False , True ),
150
+ (True , False ),
151
+ (False , False ),
152
+ ],
153
+ )
154
+ def test_next_nested (random_split , resp_message_is_proto_plus ):
155
+ if resp_message_is_proto_plus :
156
+ response_type = Song
157
+ responses = [
158
+ Song (title = "some song" , composer = Composer (given_name = "some name" )),
159
+ Song (title = "another song" , date_added = datetime .datetime (2021 , 12 , 17 )),
160
+ ]
161
+ else :
162
+ # Although `http_pb2.HttpRule`` is used in the response, any response message
163
+ # can be used which meets this criteria for the test of having a nested field.
164
+ response_type = http_pb2 .HttpRule
165
+ responses = [
166
+ http_pb2 .HttpRule (
167
+ selector = "some selector" ,
168
+ custom = http_pb2 .CustomHttpPattern (kind = "some kind" ),
169
+ ),
170
+ http_pb2 .HttpRule (
171
+ selector = "another selector" ,
172
+ custom = http_pb2 .CustomHttpPattern (path = "some path" ),
173
+ ),
174
+ ]
133
175
resp = ResponseMock (
134
- responses = responses , random_split = random_split , response_cls = Song
176
+ responses = responses , random_split = random_split , response_cls = response_type
135
177
)
136
- itr = rest_streaming .ResponseIterator (resp , Song )
178
+ itr = rest_streaming .ResponseIterator (resp , response_type )
137
179
assert list (itr ) == responses
138
180
139
181
140
- @pytest .mark .parametrize ("random_split" , [True , False ])
141
- def test_next_stress (random_split ):
182
+ @pytest .mark .parametrize (
183
+ "random_split,resp_message_is_proto_plus" ,
184
+ [
185
+ (True , True ),
186
+ (False , True ),
187
+ (True , False ),
188
+ (False , False ),
189
+ ],
190
+ )
191
+ def test_next_stress (random_split , resp_message_is_proto_plus ):
142
192
n = 50
143
- responses = [
144
- Song (title = "title_%d" % i , composer = Composer (given_name = "name_%d" % i ))
145
- for i in range (n )
146
- ]
193
+ if resp_message_is_proto_plus :
194
+ response_type = Song
195
+ responses = [
196
+ Song (title = "title_%d" % i , composer = Composer (given_name = "name_%d" % i ))
197
+ for i in range (n )
198
+ ]
199
+ else :
200
+ response_type = http_pb2 .HttpRule
201
+ responses = [
202
+ http_pb2 .HttpRule (
203
+ selector = "selector_%d" % i ,
204
+ custom = http_pb2 .CustomHttpPattern (path = "path_%d" % i ),
205
+ )
206
+ for i in range (n )
207
+ ]
147
208
resp = ResponseMock (
148
- responses = responses , random_split = random_split , response_cls = Song
209
+ responses = responses , random_split = random_split , response_cls = response_type
149
210
)
150
- itr = rest_streaming .ResponseIterator (resp , Song )
211
+ itr = rest_streaming .ResponseIterator (resp , response_type )
151
212
assert list (itr ) == responses
152
213
153
214
154
- @pytest .mark .parametrize ("random_split" , [True , False ])
155
- def test_next_escaped_characters_in_string (random_split ):
156
- composer_with_relateds = Composer ()
157
- relateds = ["Artist A" , "Artist B" ]
158
- composer_with_relateds .relateds = relateds
159
-
160
- responses = [
161
- Song (title = 'ti"tle\n foo\t bar{}' , composer = Composer (given_name = "name\n \n \n " )),
162
- Song (
163
- title = '{"this is weird": "totally"}' , composer = Composer (given_name = "\\ {}\\ " )
164
- ),
165
- Song (title = '\\ {"key": ["value",]}\\ ' , composer = composer_with_relateds ),
166
- ]
215
+ @pytest .mark .parametrize (
216
+ "random_split,resp_message_is_proto_plus" ,
217
+ [
218
+ (True , True ),
219
+ (False , True ),
220
+ (True , False ),
221
+ (False , False ),
222
+ ],
223
+ )
224
+ def test_next_escaped_characters_in_string (random_split , resp_message_is_proto_plus ):
225
+ if resp_message_is_proto_plus :
226
+ response_type = Song
227
+ composer_with_relateds = Composer ()
228
+ relateds = ["Artist A" , "Artist B" ]
229
+ composer_with_relateds .relateds = relateds
230
+
231
+ responses = [
232
+ Song (
233
+ title = 'ti"tle\n foo\t bar{}' , composer = Composer (given_name = "name\n \n \n " )
234
+ ),
235
+ Song (
236
+ title = '{"this is weird": "totally"}' ,
237
+ composer = Composer (given_name = "\\ {}\\ " ),
238
+ ),
239
+ Song (title = '\\ {"key": ["value",]}\\ ' , composer = composer_with_relateds ),
240
+ ]
241
+ else :
242
+ response_type = http_pb2 .Http
243
+ responses = [
244
+ http_pb2 .Http (
245
+ rules = [
246
+ http_pb2 .HttpRule (
247
+ selector = 'ti"tle\n foo\t bar{}' ,
248
+ custom = http_pb2 .CustomHttpPattern (kind = "name\n \n \n " ),
249
+ )
250
+ ]
251
+ ),
252
+ http_pb2 .Http (
253
+ rules = [
254
+ http_pb2 .HttpRule (
255
+ selector = '{"this is weird": "totally"}' ,
256
+ custom = http_pb2 .CustomHttpPattern (kind = "\\ {}\\ " ),
257
+ )
258
+ ]
259
+ ),
260
+ http_pb2 .Http (
261
+ rules = [
262
+ http_pb2 .HttpRule (
263
+ selector = '\\ {"key": ["value",]}\\ ' ,
264
+ custom = http_pb2 .CustomHttpPattern (kind = "\\ {}\\ " ),
265
+ )
266
+ ]
267
+ ),
268
+ ]
167
269
resp = ResponseMock (
168
- responses = responses , random_split = random_split , response_cls = Song
270
+ responses = responses , random_split = random_split , response_cls = response_type
169
271
)
170
- itr = rest_streaming .ResponseIterator (resp , Song )
272
+ itr = rest_streaming .ResponseIterator (resp , response_type )
171
273
assert list (itr ) == responses
172
274
173
275
174
- def test_next_not_array ():
276
+ @pytest .mark .parametrize ("response_type" , [EchoResponse , httpbody_pb2 .HttpBody ])
277
+ def test_next_not_array (response_type ):
175
278
with patch .object (
176
279
ResponseMock , "iter_content" , return_value = iter ('{"hello": 0}' )
177
280
) as mock_method :
178
-
179
- resp = ResponseMock (responses = [], response_cls = EchoResponse )
180
- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
281
+ resp = ResponseMock (responses = [], response_cls = response_type )
282
+ itr = rest_streaming .ResponseIterator (resp , response_type )
181
283
with pytest .raises (ValueError ):
182
284
next (itr )
183
285
mock_method .assert_called_once ()
184
286
185
287
186
- def test_cancel ():
288
+ @pytest .mark .parametrize ("response_type" , [EchoResponse , httpbody_pb2 .HttpBody ])
289
+ def test_cancel (response_type ):
187
290
with patch .object (ResponseMock , "close" , return_value = None ) as mock_method :
188
- resp = ResponseMock (responses = [], response_cls = EchoResponse )
189
- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
291
+ resp = ResponseMock (responses = [], response_cls = response_type )
292
+ itr = rest_streaming .ResponseIterator (resp , response_type )
190
293
itr .cancel ()
191
294
mock_method .assert_called_once ()
192
295
193
296
194
- def test_check_buffer ():
297
+ @pytest .mark .parametrize (
298
+ "response_type,return_value" ,
299
+ [
300
+ (EchoResponse , bytes ('[{"content": "hello"}, {' , "utf-8" )),
301
+ (httpbody_pb2 .HttpBody , bytes ('[{"content_type": "hello"}, {' , "utf-8" )),
302
+ ],
303
+ )
304
+ def test_check_buffer (response_type , return_value ):
195
305
with patch .object (
196
306
ResponseMock ,
197
307
"_parse_responses" ,
198
- return_value = bytes ( '[{"content": "hello"}, {' , "utf-8" ) ,
308
+ return_value = return_value ,
199
309
):
200
- resp = ResponseMock (responses = [], response_cls = EchoResponse )
201
- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
310
+ resp = ResponseMock (responses = [], response_cls = response_type )
311
+ itr = rest_streaming .ResponseIterator (resp , response_type )
202
312
with pytest .raises (ValueError ):
203
313
next (itr )
204
314
next (itr )
205
315
206
316
207
- def test_next_html ():
317
+ @pytest .mark .parametrize ("response_type" , [EchoResponse , httpbody_pb2 .HttpBody ])
318
+ def test_next_html (response_type ):
208
319
with patch .object (
209
320
ResponseMock , "iter_content" , return_value = iter ("<!DOCTYPE html><html></html>" )
210
321
) as mock_method :
211
-
212
- resp = ResponseMock (responses = [], response_cls = EchoResponse )
213
- itr = rest_streaming .ResponseIterator (resp , EchoResponse )
322
+ resp = ResponseMock (responses = [], response_cls = response_type )
323
+ itr = rest_streaming .ResponseIterator (resp , response_type )
214
324
with pytest .raises (ValueError ):
215
325
next (itr )
216
326
mock_method .assert_called_once ()
327
+
328
+
329
+ def test_invalid_response_class ():
330
+ class SomeClass :
331
+ pass
332
+
333
+ resp = ResponseMock (responses = [], response_cls = SomeClass )
334
+ response_iterator = rest_streaming .ResponseIterator (resp , SomeClass )
335
+ with pytest .raises (
336
+ ValueError ,
337
+ match = "Response message class must be a subclass of proto.Message or google.protobuf.message.Message" ,
338
+ ):
339
+ response_iterator ._grab ()