Skip to content

Commit

Permalink
Rebase changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Feb 22, 2024
1 parent e257894 commit cbb186c
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 13 deletions.
38 changes: 36 additions & 2 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from tqdm.autonotebook import tqdm

from collections.abc import Iterable
from typing import Union, List, Tuple, Optional, Dict, Any

from pinecone.config import ConfigBuilder
Expand All @@ -22,9 +21,10 @@
DeleteRequest,
UpdateRequest,
DescribeIndexStatsRequest,
ListResponse
)
from pinecone.core.client.api.data_plane_api import DataPlaneApi
from ..utils import get_user_agent, fix_tuple_length
from ..utils import get_user_agent
from .vector_factory import VectorFactory

__all__ = [
Expand Down Expand Up @@ -502,6 +502,40 @@ def describe_index_stats(
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
)

@validate_and_convert_errors
def list_paginated(
self,
prefix: Optional[str] = None,
limit: Optional[int] = None,
pagination_token: Optional[str] = None,
namespace: Optional[str] = None,
**kwargs
) -> ListResponse:
args_dict = self._parse_non_empty_args(
[
("prefix", prefix),
("limit", limit),
("namespace", namespace),
("pagination_token", pagination_token),
]
)
return self._vector_api.list(**args_dict, **kwargs)

@validate_and_convert_errors
def list(self, **kwargs):
limit = kwargs.get("limit", 100)
done = False
while not done:
results = self.list_paginated(**kwargs)
if len(results.vectors) > 0:
yield [v.id for v in results.vectors]

full_page = len(results.vectors) == limit
if results.pagination and full_page:
kwargs.update({"pagination_token": results.pagination.next})
else:
done = True

@staticmethod
def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
Expand Down
56 changes: 55 additions & 1 deletion pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
QueryResponse,
DescribeIndexStatsResponse,
)
from pinecone.models.list_response import (
ListResponse as SimpleListResponse,
Pagination
)
from pinecone.core.grpc.protos.vector_service_pb2 import (
Vector as GRPCVector,
QueryVector as GRPCQueryVector,
Expand All @@ -22,6 +26,8 @@
QueryRequest,
FetchRequest,
UpdateRequest,
ListRequest,
ListResponse,
DescribeIndexStatsRequest,
DeleteResponse,
UpdateResponse,
Expand All @@ -41,7 +47,6 @@ class SparseVectorTypedDict(TypedDict):
indices: List[int]
values: List[float]


class GRPCIndex(GRPCIndexBase):
"""A client for interacting with a Pinecone index via GRPC API."""

Expand Down Expand Up @@ -429,6 +434,55 @@ def update(
else:
return self._wrap_grpc_call(self.stub.Update, request, timeout=timeout)

def list_paginated(
self,
prefix: Optional[str] = None,
limit: Optional[int] = None,
pagination_token: Optional[str] = None,
namespace: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs
) -> SimpleListResponse:
args_dict = self._parse_non_empty_args(
[
("prefix", prefix),
("limit", limit),
("namespace", namespace),
("pagination_token", pagination_token),
]
)
request = ListRequest(**args_dict, **kwargs)
response = self._wrap_grpc_call(self.stub.List, request, timeout=timeout)

if response.pagination and response.pagination.next != '':
pagination = Pagination(next=response.pagination.next)
else:
pagination = None

return SimpleListResponse(
namespace=response.namespace,
vectors=response.vectors,
pagination=pagination,
)

def list(self, **kwargs):
limit = kwargs.get("limit", 100)
done = False
while not done:
try:
results = self.list_paginated(**kwargs)
except Exception as e:
raise e

if len(results.vectors) > 0:
yield [v.id for v in results.vectors]

full_page = len(results.vectors) == limit
if results.pagination and results.pagination.next and full_page:
kwargs.update({"pagination_token": results.pagination.next})
else:
done = True

def describe_index_stats(
self, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, **kwargs
) -> DescribeIndexStatsResponse:
Expand Down
9 changes: 9 additions & 0 deletions pinecone/models/list_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import NamedTuple, Optional, List

class Pagination(NamedTuple):
next: str

class ListResponse(NamedTuple):
namespace: str
vectors: List
pagination: Optional[Pagination]
30 changes: 20 additions & 10 deletions tests/integration/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import json
from ..helpers import get_environment_var, random_string
from .seed import setup_data
from .seed import setup_data, setup_list_data

# Test matrix needs to consider the following dimensions:
# - pod vs serverless
Expand Down Expand Up @@ -41,14 +41,18 @@ def spec():

@pytest.fixture(scope='session')
def index_name():
# return 'dataplane-lol'
return 'dataplane-' + random_string(20)

@pytest.fixture(scope='session')
def namespace():
# return 'banana'
return random_string(10)

@pytest.fixture(scope='session')
def list_namespace():
# return 'list-banana'
return random_string(10)

@pytest.fixture(scope='session')
def idx(client, index_name, index_host):
return client.Index(name=index_name, host=index_host)
Expand All @@ -57,27 +61,33 @@ def idx(client, index_name, index_host):
def index_host(index_name, metric, spec):
pc = build_client()
print('Creating index with name: ' + index_name)
pc.create_index(
name=index_name,
dimension=2,
metric=metric,
spec=spec
)
if index_name not in pc.list_indexes().names():
pc.create_index(
name=index_name,
dimension=2,
metric=metric,
spec=spec
)
description = pc.describe_index(name=index_name)
yield description.host

print('Deleting index with name: ' + index_name)
pc.delete_index(index_name, -1)

@pytest.fixture(scope='session', autouse=True)
def seed_data(idx, namespace, index_host):
def seed_data(idx, namespace, index_host, list_namespace):
print('Seeding data in host ' + index_host)

print('Seeding list data in namespace "' + list_namespace + '"')
setup_list_data(idx, list_namespace, True)

print('Seeding data in namespace "' + namespace + '"')
setup_data(idx, namespace, False)

print('Seeding data in namespace ""')
setup_data(idx, '', True)

print('Waiting a bit more to ensure freshness')
time.sleep(60)
time.sleep(120)

yield
12 changes: 12 additions & 0 deletions tests/integration/data/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,15 @@ def setup_data(idx, target_namespace, wait):

if wait:
poll_fetch_for_ids_in_namespace(idx, ids=['1', '2', '3', '4', '5', '6', '7', '8', '9'], namespace=target_namespace)

def setup_list_data(idx, target_namespace, wait):
# Upsert a bunch more stuff for testing list pagination
for i in range(0, 1000, 50):
idx.upsert(vectors=[
(str(i+d), embedding_values(2)) for d in range(50)
],
namespace=target_namespace
)

if wait:
poll_fetch_for_ids_in_namespace(idx, ids=['999'], namespace=target_namespace)
100 changes: 100 additions & 0 deletions tests/integration/data/test_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
class TestListPaginated:
def test_list_when_no_results(self, idx):
results = idx.list_paginated(namespace='no-results')
assert results != None
assert results.namespace == 'no-results'
assert len(results.vectors) == 0
# assert results.pagination == None

def test_list_no_args(self, idx):
results = idx.list_paginated()

assert results != None
assert len(results.vectors) == 9
assert results.namespace == ''
# assert results.pagination == None

def test_list_when_limit(self, idx, list_namespace):
results = idx.list_paginated(limit=10, namespace=list_namespace)

assert results != None
assert len(results.vectors) == 10
assert results.namespace == list_namespace
assert results.pagination != None
assert results.pagination.next != None
assert isinstance(results.pagination.next, str)
assert results.pagination.next != ''

def test_list_when_using_pagination(self, idx, list_namespace):
results = idx.list_paginated(
prefix='99', limit=5, namespace=list_namespace
)
next_results = idx.list_paginated(
prefix='99', limit=5, namespace=list_namespace, pagination_token=results.pagination.next
)
next_next_results = idx.list_paginated(
prefix='99', limit=5, namespace=list_namespace, pagination_token=next_results.pagination.next
)

assert results.namespace == list_namespace
assert len(results.vectors) == 5
assert [v.id for v in results.vectors] == ['99', '990', '991', '992', '993']
assert len(next_results.vectors) == 5
assert [v.id for v in next_results.vectors] == ['994', '995', '996', '997', '998']
assert len(next_next_results.vectors) == 1
assert [v.id for v in next_next_results.vectors] == ['999']
# assert next_next_results.pagination == None

class TestList:
def test_list_with_defaults(self, idx):
pages = []
page_sizes = []
page_count = 0
for ids in idx.list():
page_count += 1
assert ids != None
page_sizes.append(len(ids))
pages.append(ids)

assert page_count == 1
assert page_sizes == [9]

def test_list(self, idx, list_namespace):
results = idx.list(prefix='99', limit=20, namespace=list_namespace)

page_count = 0
for ids in results:
page_count += 1
assert ids != None
assert len(ids) == 11
assert ids == ['99', '990', '991', '992', '993', '994', '995', '996', '997', '998', '999']
assert page_count == 1

def test_list_when_no_results_for_prefix(self, idx, list_namespace):
page_count = 0
for ids in idx.list(prefix='no-results', namespace=list_namespace):
page_count += 1
assert page_count == 0

def test_list_when_no_results_for_namespace(self, idx):
page_count = 0
for ids in idx.list(prefix='99', namespace='no-results'):
page_count += 1
assert page_count == 0

def test_list_when_multiple_pages(self, idx, list_namespace):
pages = []
page_sizes = []
page_count = 0

for ids in idx.list(prefix='99', limit=5, namespace=list_namespace):
page_count += 1
assert ids != None
page_sizes.append(len(ids))
pages.append(ids)

assert page_count == 3
assert page_sizes == [5, 5, 1]
assert pages[0] == ['99', '990', '991', '992', '993']
assert pages[1] == ['994', '995', '996', '997', '998']
assert pages[2] == ['999']
16 changes: 16 additions & 0 deletions tests/integration/data/test_list_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from pinecone import PineconeException
import pytest

class TestListErrors:
def test_list_change_prefix_while_fetching_next_page(self, idx, list_namespace):
results = idx.list_paginated(prefix='99', limit=5, namespace=list_namespace)
with pytest.raises(PineconeException) as e:
idx.list_paginated(prefix='98', limit=5, pagination_token=results.pagination.next)
assert 'prefix' in str(e.value)

@pytest.mark.skip(reason='Bug filed')
def test_list_change_namespace_while_fetching_next_page(self, idx, namespace):
results = idx.list_paginated(limit=5, namespace=namespace)
with pytest.raises(PineconeException) as e:
idx.list_paginated(limit=5, namespace='new-namespace', pagination_token=results.pagination.next)
assert 'namespace' in str(e.value)

0 comments on commit cbb186c

Please sign in to comment.