Skip to content

Commit

Permalink
Allow clients to tag requests with a source_tag (#324)
Browse files Browse the repository at this point in the history
## Problem
Clients need a way to include a `source_tag` that identifies the source
of their requests.

## Solution
Allow clients to pass a `source_tag` argument to the client constructor.
When set, the tag is normalized and appended to the User-Agent header to identify the traffic source.

Example:
```python
from pinecone import Pinecone

pc = Pinecone(api_key='foo', source_tag='bar')

pc.list_indexes()
```

This would cause the user-agent to get a value like:
```
User-Agent: python-client-3.1.0 (urllib3:2.0.7); source_tag=bar
```

## Type of Change

- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

## Testing

- [ ] Tests are passing
- [ ] Verify source_tag included in user-agent in control plane and data
plane (REST and gRPC)
  • Loading branch information
ssmith-pc committed Mar 22, 2024
1 parent ed8c2ab commit 4fd2d20
Show file tree
Hide file tree
Showing 13 changed files with 137 additions and 22 deletions.
5 changes: 4 additions & 1 deletion pinecone/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pinecone.config.openapi import OpenApiConfigFactory
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
from pinecone.utils import normalize_host
from pinecone.utils.constants import SOURCE_TAG

class Config(NamedTuple):
api_key: str = ""
Expand All @@ -14,6 +15,7 @@ class Config(NamedTuple):
ssl_ca_certs: Optional[str] = None
ssl_verify: Optional[bool] = None
additional_headers: Optional[Dict[str, str]] = {}
source_tag: Optional[str] = None

class ConfigBuilder:
"""
Expand Down Expand Up @@ -46,13 +48,14 @@ def build(
api_key = api_key or kwargs.pop("api_key", None) or os.getenv("PINECONE_API_KEY")
host = host or kwargs.pop("host", None)
host = normalize_host(host)
source_tag = kwargs.pop(SOURCE_TAG, None)

if not api_key:
raise PineconeConfigurationError("You haven't specified an Api-Key.")
if not host:
raise PineconeConfigurationError("You haven't specified a host.")

return Config(api_key, host, proxy_url, proxy_headers, ssl_ca_certs, ssl_verify, additional_headers)
return Config(api_key, host, proxy_url, proxy_headers, ssl_ca_certs, ssl_verify, additional_headers, source_tag)

@staticmethod
def build_openapi_config(
Expand Down
1 change: 1 addition & 0 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,5 +633,6 @@ def Index(self, name: str = '', host: str = '', **kwargs):
api_key=api_key,
pool_threads=pt,
openapi_config=openapi_config,
source_tag=self.config.source_tag,
**kwargs
)
4 changes: 3 additions & 1 deletion pinecone/grpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .utils import _generate_request_id
from .config import GRPCClientConfig
from pinecone.utils.constants import MAX_MSG_SIZE, REQUEST_ID, CLIENT_VERSION
from pinecone.utils.user_agent import get_user_agent_grpc
from pinecone.exceptions import PineconeException

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,7 +78,8 @@ def __init__(
}
)

self._channel = channel or self._gen_channel()
options = {"grpc.primary_user_agent": get_user_agent_grpc(config)}
self._channel = channel or self._gen_channel(options=options)
self.stub = self.stub_class(self._channel)

@property
Expand Down
17 changes: 7 additions & 10 deletions pinecone/grpc/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,10 @@ def Index(self, name: str = '', host: str = '', **kwargs):
if name == '' and host == '':
raise ValueError("Either name or host must be specified")

if host != '':
# Use host if it is provided
config = ConfigBuilder.build(api_key=self.config.api_key, host=host)
return GRPCIndex(index_name=name, config=config, **kwargs)

if name != '':
# Otherwise, get host url from describe_index using the index name
index_host = self.index_host_store.get_host(self.index_api, self.config, name)
config = ConfigBuilder.build(api_key=self.config.api_key, host=index_host)
return GRPCIndex(index_name=name, config=config, **kwargs)
# Use host if it is provided, otherwise get host from describe_index
index_host = host or self.index_host_store.get_host(self.index_api, self.config, name)

config = ConfigBuilder.build(api_key=self.config.api_key,
host=index_host,
source_tag=self.config.source_tag)
return GRPCIndex(index_name=name, config=config, **kwargs)
2 changes: 2 additions & 0 deletions pinecone/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ class NodeType(str, enum.Enum):

REQUIRED_VECTOR_FIELDS = {"id", "values"}
OPTIONAL_VECTOR_FIELDS = {"sparse_values", "metadata"}

SOURCE_TAG = "source_tag"
2 changes: 1 addition & 1 deletion pinecone/utils/setup_openapi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def setup_openapi_client(api_klass, config, openapi_config, pool_threads):
configuration=openapi_config,
pool_threads=pool_threads
)
api_client.user_agent = get_user_agent()
api_client.user_agent = get_user_agent(config)
extra_headers = config.additional_headers or {}
for key, value in extra_headers.items():
api_client.set_default_header(key, value)
Expand Down
26 changes: 23 additions & 3 deletions pinecone/utils/user_agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,29 @@
import urllib3

from .version import __version__
from .constants import SOURCE_TAG
import re

def get_user_agent():
client_id = f"python-client-{__version__}"
def _build_source_tag_field(source_tag):
    """Return the ``source_tag=<tag>`` field for the user-agent string.

    The tag is normalized before use: it is lowercased, every character
    outside ``[a-z0-9_ ]`` is removed, surrounding spaces are dropped, and
    each remaining run of spaces becomes a single underscore.
    """
    allowed_only = re.sub(r'[^a-z0-9_ ]', '', source_tag.lower())
    # split() with no separator ignores leading/trailing whitespace and
    # treats any run of spaces as one delimiter, so this both trims and
    # condenses in a single step.
    normalized = "_".join(allowed_only.split())
    return f"{SOURCE_TAG}={normalized}"

def _get_user_agent(client_id, config):
    """Build the user-agent string for the given client identifier.

    The result has the form ``"<client_id> (urllib3:<version>)"``. When
    ``config.source_tag`` is set (truthy), ``"; source_tag=<normalized tag>"``
    is appended.

    Args:
        client_id: Prefix identifying the client flavor and version.
        config: Object exposing a ``source_tag`` attribute (may be None).

    Returns:
        The assembled user-agent string.
    """
    user_agent_details = {"urllib3": urllib3.__version__}
    details = ", ".join(f"{k}:{v}" for k, v in user_agent_details.items())
    user_agent = f"{client_id} ({details})"
    # Only append the field when a tag was configured; the previous early
    # return made this step unreachable.
    if config.source_tag:
        user_agent += f"; {_build_source_tag_field(config.source_tag)}"
    return user_agent

def get_user_agent(config):
    """Return the user-agent string for the REST client built from `config`."""
    rest_client_id = f"python-client-{__version__}"
    return _get_user_agent(rest_client_id, config)

def get_user_agent_grpc(config):
    """Return the user-agent string for the gRPC client built from `config`."""
    grpc_client_id = f"python-client[grpc]-{__version__}"
    return _get_user_agent(grpc_client_id, config)
11 changes: 11 additions & 0 deletions tests/unit/test_config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def test_build_merges_key_and_host_when_openapi_config_provided(self):
assert config.host == "https://my-controller-host"
assert config.additional_headers == {}

def test_build_with_source_tag(self):
config = ConfigBuilder.build(
api_key="my-api-key",
host="https://my-controller-host",
source_tag="my-source-tag",
)
assert config.api_key == "my-api-key"
assert config.host == "https://my-controller-host"
assert config.additional_headers == {}
assert config.source_tag == "my-source-tag"

def test_build_errors_when_no_api_key_is_present(self):
with pytest.raises(PineconeConfigurationError) as e:
ConfigBuilder.build()
Expand Down
12 changes: 11 additions & 1 deletion tests/unit/test_control.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from pinecone import Pinecone, PodSpec, ServerlessSpec
import re
from pinecone import ConfigBuilder, Pinecone, PodSpec, ServerlessSpec
from pinecone.core.client.models import IndexList, IndexModel
from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
Expand Down Expand Up @@ -41,6 +42,15 @@ def test_overwrite_useragent(self):
assert p.index_api.api_client.default_headers['User-Agent'] == 'test-user-agent'
assert len(p.index_api.api_client.default_headers) == 1

def test_set_source_tag_in_useragent(self):
p = Pinecone(api_key="123-456-789", source_tag="test_source_tag")
assert re.search(r"source_tag=test_source_tag", p.index_api.api_client.user_agent) is not None

def test_set_source_tag_in_useragent_via_config(self):
config = ConfigBuilder.build(api_key='YOUR_API_KEY', host='https://my-host', source_tag='my_source_tag')
p = Pinecone(config=config)
assert re.search(r"source_tag=my_source_tag", p.index_api.api_client.user_agent) is not None

@pytest.mark.parametrize("timeout_value, describe_index_responses, expected_describe_index_calls, expected_sleep_calls", [
# When timeout=None, describe_index is called until ready
(None, [{ "status": {"ready": False}}, {"status": {"ready": True}}], 2, 1),
Expand Down
16 changes: 14 additions & 2 deletions tests/unit/test_index_initialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from pinecone import Pinecone
import re
from pinecone import ConfigBuilder, Pinecone

class TestIndexClientInitialization():
@pytest.mark.parametrize(
Expand Down Expand Up @@ -50,4 +51,15 @@ def test_overwrite_useragent(self):
)
assert len(index._vector_api.api_client.default_headers) == 1
assert 'User-Agent' in index._vector_api.api_client.default_headers
assert index._vector_api.api_client.default_headers['User-Agent'] == 'test-user-agent'
assert index._vector_api.api_client.default_headers['User-Agent'] == 'test-user-agent'

def test_set_source_tag(self):
pc = Pinecone(api_key="123-456-789", source_tag="test_source_tag")
index = pc.Index(host='myhost')
assert re.search(r"source_tag=test_source_tag", pc.index_api.api_client.user_agent) is not None

def test_set_source_tag_via_config(self):
config = ConfigBuilder.build(api_key='YOUR_API_KEY', host='https://my-host', source_tag='my_source_tag')
pc = Pinecone(config=config)
index = pc.Index(host='myhost')
assert re.search(r"source_tag=my_source_tag", pc.index_api.api_client.user_agent) is not None
11 changes: 11 additions & 0 deletions tests/unit/utils/test_setup_openapi_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import re
from pinecone.config import ConfigBuilder
from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi
from pinecone.utils.setup_openapi_client import setup_openapi_client

class TestSetupOpenAPIClient():
    def test_setup_openapi_client(self):
        """Placeholder: no assertions are made about setup_openapi_client yet."""
        # TODO: build a config with ConfigBuilder.build(api_key=..., host=...),
        # call setup_openapi_client(ManageIndexesApi, config, openapi_config,
        # pool_threads) -- note that openapi_config is a required argument --
        # and assert on the user agent of the resulting client.
45 changes: 42 additions & 3 deletions tests/unit/utils/test_user_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,47 @@
import re
from pinecone.utils.user_agent import get_user_agent
from pinecone.utils.user_agent import get_user_agent, get_user_agent_grpc
from pinecone.config import ConfigBuilder

class TestUserAgent():
    """Unit tests for the REST and gRPC user-agent builders."""

    def test_user_agent(self):
        config = ConfigBuilder.build(api_key="my-api-key", host="https://my-controller-host")
        useragent = get_user_agent(config)
        assert re.search(r"python-client-\d+\.\d+\.\d+", useragent) is not None
        assert re.search(r"urllib3:\d+\.\d+\.\d+", useragent) is not None

    def test_user_agent_with_source_tag(self):
        config = ConfigBuilder.build(api_key="my-api-key", host="https://my-controller-host", source_tag="my_source_tag")
        useragent = get_user_agent(config)
        assert re.search(r"python-client-\d+\.\d+\.\d+", useragent) is not None
        assert re.search(r"urllib3:\d+\.\d+\.\d+", useragent) is not None
        assert re.search(r"source_tag=my_source_tag", useragent) is not None

    def test_source_tag_is_normalized(self):
        # The source tag field is the last thing appended to the user agent,
        # so anchoring with `$` pins the exact normalized value and catches
        # any characters that survive normalization.
        cases = [
            ("my source tag!!!!", r"source_tag=my_source_tag$"),
            ("My Source Tag", r"source_tag=my_source_tag$"),
            (" My Source Tag 123 ", r"source_tag=my_source_tag_123$"),
            (" My Source Tag 123 #### !! ", r"source_tag=my_source_tag_123$"),
        ]
        for raw_tag, expected_pattern in cases:
            config = ConfigBuilder.build(api_key="my-api-key", host="https://my-controller-host", source_tag=raw_tag)
            useragent = get_user_agent(config)
            assert re.search(expected_pattern, useragent) is not None, raw_tag

    def test_user_agent_grpc(self):
        config = ConfigBuilder.build(api_key="my-api-key", host="https://my-controller-host")
        useragent = get_user_agent_grpc(config)
        assert re.search(r"python-client\[grpc\]-\d+\.\d+\.\d+", useragent) is not None
        assert re.search(r"urllib3:\d+\.\d+\.\d+", useragent) is not None

    def test_user_agent_grpc_with_source_tag(self):
        config = ConfigBuilder.build(api_key="my-api-key", host="https://my-controller-host", source_tag="my_source_tag")
        useragent = get_user_agent_grpc(config)
        assert re.search(r"python-client\[grpc\]-\d+\.\d+\.\d+", useragent) is not None
        assert re.search(r"urllib3:\d+\.\d+\.\d+", useragent) is not None
        assert re.search(r"source_tag=my_source_tag", useragent) is not None
7 changes: 7 additions & 0 deletions tests/unit_grpc/test_grpc_index_initialization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from pinecone import ConfigBuilder

class TestGRPCIndexInitialization:
def test_init_with_default_config(self):
Expand Down Expand Up @@ -85,3 +87,8 @@ def test_config_passed_when_target_by_host(self):
# Unset fields still get default values
assert index.grpc_client_config.reuse_channel == True
assert index.grpc_client_config.conn_timeout == 1

def test_config_passes_source_tag_when_set(self):
pc = PineconeGRPC(api_key='YOUR_API_KEY', source_tag='my_source_tag')
index = pc.Index(name='my-index', host='host')
assert re.search(r"source_tag=my_source_tag", pc.index_api.api_client.user_agent) is not None

0 comments on commit 4fd2d20

Please sign in to comment.