Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Configuring SSL proxy via openapi_config object #321

Merged
merged 6 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 5 additions & 5 deletions pinecone/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def build(
if not host:
raise PineconeConfigurationError("You haven't specified a host.")

openapi_config = (
openapi_config
or kwargs.pop("openapi_config", None)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or kwargs.pop("openapi_config", None) wasn't doing anything, since openapi_config is a named param and the key should never appear in kwargs.

or OpenApiConfigFactory.build(api_key=api_key, host=host)
)
if openapi_config:
openapi_config.host = host
openapi_config.api_key = {"ApiKeyAuth": api_key}
Comment on lines +52 to +53
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the user provides this object with some configuration in it, we want to merge the api key and host settings into it rather than building a fresh object.

else:
openapi_config = OpenApiConfigFactory.build(api_key=api_key, host=host)

return Config(api_key, host, openapi_config, additional_headers)
36 changes: 22 additions & 14 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from pinecone.config import PineconeConfig, Config

from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi
from pinecone.core.client.api_client import ApiClient
from pinecone.utils import get_user_agent, normalize_host
from pinecone.utils import normalize_host, setup_openapi_client
from pinecone.core.client.models import (
CreateCollectionRequest,
CreateIndexRequest,
Expand Down Expand Up @@ -85,25 +84,20 @@ def __init__(
or share with Pinecone support. **Be very careful with this option, as it will print out
your API key** which forms part of a required authentication header. Default: `false`
"""
if config or kwargs.get("config"):
configKwarg = config or kwargs.get("config")
if not isinstance(configKwarg, Config):
if config:
if not isinstance(config, Config):
raise TypeError("config must be of type pinecone.config.Config")
else:
self.config = configKwarg
Comment on lines -88 to -93
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was more cleanup since config is a named param that should never appear in kwargs

self.config = config
else:
self.config = PineconeConfig.build(api_key=api_key, host=host, additional_headers=additional_headers, **kwargs)

self.pool_threads = pool_threads

if index_api:
self.index_api = index_api
else:
api_client = ApiClient(configuration=self.config.openapi_config, pool_threads=self.pool_threads)
api_client.user_agent = get_user_agent()
extra_headers = self.config.additional_headers or {}
for key, value in extra_headers.items():
api_client.set_default_header(key, value)
self.index_api = ManageIndexesApi(api_client)
self.index_api = setup_openapi_client(ManageIndexesApi, self.config, pool_threads)

self.index_host_store = IndexHostStore()
""" @private """
Expand Down Expand Up @@ -521,12 +515,26 @@ def Index(self, name: str = '', host: str = '', **kwargs):
raise ValueError("Either name or host must be specified")

pt = kwargs.pop('pool_threads', None) or self.pool_threads
api_key = self.config.api_key
openapi_config = self.config.openapi_config

if host != '':
# Use host url if it is provided
return Index(api_key=self.config.api_key, host=normalize_host(host), pool_threads=pt, **kwargs)
return Index(
host=normalize_host(host),
api_key=api_key,
pool_threads=pt,
openapi_config=openapi_config,
**kwargs
)

if name != '':
# Otherwise, get host url from describe_index using the index name
index_host = self.index_host_store.get_host(self.index_api, self.config, name)
return Index(api_key=self.config.api_key, host=index_host, pool_threads=pt, **kwargs)
return Index(
host=index_host,
api_key=api_key,
pool_threads=pt,
openapi_config=openapi_config,
**kwargs
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be out of the scope for this change since you're just adding the openapi_config param, but it looks like these two blocks are largely identical, differing only by how the host param is init'd

26 changes: 11 additions & 15 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ListResponse
)
from pinecone.core.client.api.data_plane_api import DataPlaneApi
from ..utils import get_user_agent
from ..utils import setup_openapi_client
from .vector_factory import VectorFactory

__all__ = [
Expand Down Expand Up @@ -75,27 +75,23 @@ def __init__(
host: str,
pool_threads: Optional[int] = 1,
additional_headers: Optional[Dict[str, str]] = {},
openapi_config = None,
**kwargs
):
self._config = ConfigBuilder.build(api_key=api_key, host=host, **kwargs)

api_client = ApiClient(configuration=self._config.openapi_config,
pool_threads=pool_threads)

# Configure request headers
api_client.user_agent = get_user_agent()
extra_headers = additional_headers or {}
for key, value in extra_headers.items():
api_client.set_default_header(key, value)

self._api_client = api_client
self._vector_api = DataPlaneApi(api_client=api_client)
self._config = ConfigBuilder.build(
api_key=api_key,
host=host,
additional_headers=additional_headers,
openapi_config=openapi_config,
**kwargs
)
self._vector_api = setup_openapi_client(DataPlaneApi, self._config, pool_threads)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self._api_client.close()
self._vector_api.api_client.close()

@validate_and_convert_errors
def upsert(
Expand Down
3 changes: 2 additions & 1 deletion pinecone/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
from .deprecation_notice import warn_deprecated
from .fix_tuple_length import fix_tuple_length
from .convert_to_list import convert_to_list
from .normalize_host import normalize_host
from .normalize_host import normalize_host
from .setup_openapi_client import setup_openapi_client
14 changes: 14 additions & 0 deletions pinecone/utils/setup_openapi_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pinecone.core.client.api_client import ApiClient
from .user_agent import get_user_agent

def setup_openapi_client(api_klass, config, pool_threads):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: _klass?

api_client = ApiClient(
configuration=config.openapi_config,
pool_threads=pool_threads
)
api_client.user_agent = get_user_agent()
extra_headers = config.additional_headers or {}
for key, value in extra_headers.items():
api_client.set_default_header(key, value)
client = api_klass(api_client)
return client
4 changes: 4 additions & 0 deletions tests/integration/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def build_client():
from pinecone import Pinecone
return Pinecone(api_key=api_key(), additional_headers={'sdk-test-suite': 'pinecone-python-client'})

@pytest.fixture(scope='session')
def api_key_fixture():
return api_key()

@pytest.fixture(scope='session')
def client():
return build_client()
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/data/test_openapi_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
import os

from pinecone import Pinecone
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
from urllib3 import make_headers

@pytest.mark.skipif(os.getenv('USE_GRPC') != 'false', reason='Only test when using REST')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. 👍

class TestIndexOpenapiConfig:
def test_passing_openapi_config(self, api_key_fixture, index_host):
oai_config = OpenApiConfiguration.get_default_copy()
p = Pinecone(api_key=api_key_fixture, openapi_config=oai_config)
assert p.config.api_key == api_key_fixture
p.list_indexes() # should not throw

index = p.Index(host=index_host)
assert index._config.api_key == api_key_fixture
index.describe_index_stats()
23 changes: 22 additions & 1 deletion tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pytest
import os

from urllib3 import make_headers

class TestConfig:
@pytest.fixture(autouse=True)
def run_before_and_after_tests(tmpdir):
Expand Down Expand Up @@ -84,5 +86,24 @@ def test_config_pool_threads(self):
pc = Pinecone(api_key="test-api-key", host="test-controller-host", pool_threads=10)
assert pc.index_api.api_client.pool_threads == 10
idx = pc.Index(host='my-index-host', name='my-index-name')
assert idx._api_client.pool_threads == 10
assert idx._vector_api.api_client.pool_threads == 10

def test_config_when_openapi_config_is_passed_merges_api_key(self):
oai_config = OpenApiConfiguration()
pc = Pinecone(api_key='asdf', openapi_config=oai_config)
assert pc.config.openapi_config.api_key == {'ApiKeyAuth': 'asdf'}

def test_ssl_config_passed_to_index_client(self):
oai_config = OpenApiConfiguration()
oai_config.ssl_ca_cert = 'path/to/cert'
proxy_headers = make_headers(proxy_basic_auth='asdf')
oai_config.proxy_headers = proxy_headers

pc = Pinecone(api_key='key', openapi_config=oai_config)

assert pc.config.openapi_config.ssl_ca_cert == 'path/to/cert'
assert pc.config.openapi_config.proxy_headers == proxy_headers

idx = pc.Index(host='host')
assert idx._vector_api.api_client.configuration.ssl_ca_cert == 'path/to/cert'
assert idx._vector_api.api_client.configuration.proxy_headers == proxy_headers
36 changes: 36 additions & 0 deletions tests/unit/test_config_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
from pinecone.config import ConfigBuilder
from pinecone import PineconeConfigurationError

class TestConfigBuilder:
def test_build_simple(self):
config = ConfigBuilder.build(api_key="my-api-key", host="https://my-controller-host")
assert config.api_key == "my-api-key"
assert config.host == "https://my-controller-host"
assert config.additional_headers == {}
assert config.openapi_config.host == "https://my-controller-host"
assert config.openapi_config.api_key == {"ApiKeyAuth": "my-api-key"}

def test_build_merges_key_and_host_when_openapi_config_provided(self):
config = ConfigBuilder.build(
api_key="my-api-key",
host="https://my-controller-host",
openapi_config=OpenApiConfiguration()
)
assert config.api_key == "my-api-key"
assert config.host == "https://my-controller-host"
assert config.additional_headers == {}
assert config.openapi_config.host == "https://my-controller-host"
assert config.openapi_config.api_key == {"ApiKeyAuth": "my-api-key"}

def test_build_errors_when_no_api_key_is_present(self):
with pytest.raises(PineconeConfigurationError) as e:
ConfigBuilder.build()
assert str(e.value) == "You haven't specified an Api-Key."

def test_build_errors_when_no_host_is_present(self):
with pytest.raises(PineconeConfigurationError) as e:
ConfigBuilder.build(api_key='my-api-key')
assert str(e.value) == "You haven't specified a host."
14 changes: 10 additions & 4 deletions tests/unit/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pinecone import Pinecone, PodSpec, ServerlessSpec
from pinecone.core.client.models import IndexList, IndexModel
from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration

import time

@pytest.fixture
Expand Down Expand Up @@ -107,25 +109,29 @@ def test_list_indexes_returns_iterable(self, mocker, index_list_response):
response = p.list_indexes()
assert [i.name for i in response] == ["index1", "index2", "index3"]

def test_api_key_and_openapi_config(self, mocker):
p = Pinecone(api_key="123", openapi_config=OpenApiConfiguration.get_default_copy())
assert p.config.api_key == "123"

class TestIndexConfig:
def test_default_pool_threads(self):
pc = Pinecone(api_key="123-456-789")
index = pc.Index(host='my-host.svg.pinecone.io')
assert index._api_client.pool_threads == 1
assert index._vector_api.api_client.pool_threads == 1

def test_pool_threads_when_indexapi_passed(self):
pc = Pinecone(api_key="123-456-789", pool_threads=2, index_api=ManageIndexesApi())
index = pc.Index(host='my-host.svg.pinecone.io')
assert index._api_client.pool_threads == 2
assert index._vector_api.api_client.pool_threads == 2

def test_target_index_with_pool_threads_inherited(self):
pc = Pinecone(api_key="123-456-789", pool_threads=10, foo='bar')
index = pc.Index(host='my-host.svg.pinecone.io')
assert index._api_client.pool_threads == 10
assert index._vector_api.api_client.pool_threads == 10

def test_target_index_with_pool_threads_kwarg(self):
pc = Pinecone(api_key="123-456-789", pool_threads=10)
index = pc.Index(host='my-host.svg.pinecone.io', pool_threads=5)
assert index._api_client.pool_threads == 5
assert index._vector_api.api_client.pool_threads == 5


22 changes: 11 additions & 11 deletions tests/unit/test_index_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ class TestIndexClientInitialization():
def test_no_additional_headers_leaves_useragent_only(self, additional_headers):
pc = Pinecone(api_key='YOUR_API_KEY')
index = pc.Index(host='myhost', additional_headers=additional_headers)
assert len(index._api_client.default_headers) == 1
assert 'User-Agent' in index._api_client.default_headers
assert 'python-client-' in index._api_client.default_headers['User-Agent']
assert len(index._vector_api.api_client.default_headers) == 1
assert 'User-Agent' in index._vector_api.api_client.default_headers
assert 'python-client-' in index._vector_api.api_client.default_headers['User-Agent']

def test_additional_headers_one_additional(self):
pc = Pinecone(api_key='YOUR_API_KEY')
index = pc.Index(
host='myhost',
additional_headers={'test-header': 'test-header-value'}
)
assert 'test-header' in index._api_client.default_headers
assert len(index._api_client.default_headers) == 2
assert 'test-header' in index._vector_api.api_client.default_headers
assert len(index._vector_api.api_client.default_headers) == 2

def test_multiple_additional_headers(self):
pc = Pinecone(api_key='YOUR_API_KEY')
Expand All @@ -34,9 +34,9 @@ def test_multiple_additional_headers(self):
'test-header2': 'test-header-value2'
}
)
assert 'test-header' in index._api_client.default_headers
assert 'test-header2' in index._api_client.default_headers
assert len(index._api_client.default_headers) == 3
assert 'test-header' in index._vector_api.api_client.default_headers
assert 'test-header2' in index._vector_api.api_client.default_headers
assert len(index._vector_api.api_client.default_headers) == 3

def test_overwrite_useragent(self):
# This doesn't seem like a common use case, but we may want to allow this
Expand All @@ -48,6 +48,6 @@ def test_overwrite_useragent(self):
'User-Agent': 'test-user-agent'
}
)
assert len(index._api_client.default_headers) == 1
assert 'User-Agent' in index._api_client.default_headers
assert index._api_client.default_headers['User-Agent'] == 'test-user-agent'
assert len(index._vector_api.api_client.default_headers) == 1
assert 'User-Agent' in index._vector_api.api_client.default_headers
assert index._vector_api.api_client.default_headers['User-Agent'] == 'test-user-agent'