Add Fireworks partner packages
benjibc committed Feb 18, 2024
1 parent d7c26c8 commit 86eeeca
Showing 23 changed files with 2,016 additions and 0 deletions.
1 change: 1 addition & 0 deletions libs/partners/fireworks/.gitignore
@@ -0,0 +1 @@
__pycache__
21 changes: 21 additions & 0 deletions libs/partners/fireworks/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
59 changes: 59 additions & 0 deletions libs/partners/fireworks/Makefile
@@ -0,0 +1,59 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
poetry run pytest $(TEST_FILE)

tests:
poetry run pytest $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/fireworks --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_fireworks
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
poetry run ruff .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff --select I $(PYTHON_FILES)
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
poetry run codespell --toml pyproject.toml

spell_fix:
poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_fireworks -name '*.py')
poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'
3 changes: 3 additions & 0 deletions libs/partners/fireworks/README.md
@@ -0,0 +1,3 @@
# LangChain-Fireworks

This is the partner package for integrating Fireworks.ai models with LangChain.
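
A minimal quick-start sketch for the README context, assuming the package is installed and FIREWORKS_API_KEY is set in the environment; the model name is purely illustrative:

    from langchain_fireworks import Fireworks

    # Illustrative model id; any Fireworks-hosted completion model works here.
    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
    print(llm.invoke("Tell me a one-line joke about fireworks."))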
9 changes: 9 additions & 0 deletions libs/partners/fireworks/langchain_fireworks/__init__.py
@@ -0,0 +1,9 @@
from langchain_fireworks.embeddings import FireworksEmbeddings
from langchain_fireworks.llms import Fireworks
from langchain_fireworks.version import __version__

__all__ = [
"__version__",
"Fireworks",
"FireworksEmbeddings",
]
50 changes: 50 additions & 0 deletions libs/partners/fireworks/langchain_fireworks/embeddings.py
@@ -0,0 +1,50 @@
import os
from typing import Any, Dict, List

import fireworks # type: ignore
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str


class FireworksEmbeddings(BaseModel, Embeddings):
"""FireworksEmbeddings embedding model.
Example:
.. code-block:: python
from langchain_fireworks import FireworksEmbeddings
model = FireworksEmbeddings(
model='nomic-ai/nomic-embed-text-v1.5'
)
"""

_client: fireworks.Fireworks
fireworks_api_key: SecretStr = convert_to_secret_str("")
model: str

@root_validator()
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate environment variables."""
fireworks_api_key = convert_to_secret_str(
values.get("fireworks_api_key") or os.getenv("FIREWORKS_API_KEY") or ""
)
values["fireworks_api_key"] = fireworks_api_key

# note this sets it globally for module
# there isn't currently a way to pass it into client
fireworks.api_key = fireworks_api_key.get_secret_value()
values["_client"] = fireworks.Fireworks()
return values

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
return [
i.embedding
for i in self._client.embeddings.create(input=texts, model=self.model).data
]

def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
return self.embed_documents([text])[0]
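
A short usage sketch for the class above, assuming a valid FIREWORKS_API_KEY in the environment and using the model name from the docstring example:

    from langchain_fireworks import FireworksEmbeddings

    embeddings = FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")

    # embed_documents returns one vector per input text.
    doc_vectors = embeddings.embed_documents(["hello world", "goodbye world"])

    # embed_query delegates to embed_documents and returns the single vector.
    query_vector = embeddings.embed_query("hello world")

    print(len(doc_vectors), len(query_vector))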
205 changes: 205 additions & 0 deletions libs/partners/fireworks/langchain_fireworks/llms.py
@@ -0,0 +1,205 @@
"""Wrapper around Fireworks AI's Completion API."""

import logging
from typing import Any, Dict, List, Optional

import requests
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from langchain_fireworks.version import __version__

logger = logging.getLogger(__name__)


class Fireworks(LLM):
"""LLM models from `Fireworks`.
To use, you'll need an API key which you can find here:
https://fireworks.ai This can be passed in as init param
``fireworks_api_key`` or set as environment variable ``FIREWORKS_API_KEY``.
Fireworks AI API reference: https://readme.fireworks.ai/
"""

base_url: str = "https://api.fireworks.ai/inference/v1"
"""Base inference API URL."""
fireworks_api_key: SecretStr
"""Fireworks AI API key. Get it here: https://fireworks.ai"""
model: str
"""Model name. Available models listed here:
https://readme.fireworks.ai/
"""
temperature: Optional[float] = None
"""Model temperature."""
top_p: Optional[float] = None
"""Used to dynamically adjust the number of choices for each predicted token based
on the cumulative probabilities. A value of 1 will always yield the same
output. A temperature less than 1 favors more correctness and is appropriate
for question answering or summarization. A value greater than 1 introduces more
randomness in the output.
"""
top_k: Optional[int] = None
"""Used to limit the number of choices for the next predicted word or token. It
specifies the maximum number of tokens to consider at each step, based on their
probability of occurrence. This technique helps to speed up the generation
process and can improve the quality of the generated text by focusing on the
most likely options.
"""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
repetition_penalty: Optional[float] = None
"""A number that controls the diversity of generated text by reducing the
likelihood of repeated sequences. Higher values decrease repetition.
"""
logprobs: Optional[int] = None
"""An integer that specifies how many top token log probabilities are included in
the response for each token generation step.
"""

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["fireworks_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
)
return values

@property
def _llm_type(self) -> str:
"""Return type of model."""
return "fireworks"

def _format_output(self, output: dict) -> str:
return output["output"]["choices"][0]["text"]

@staticmethod
def get_user_agent() -> str:
return f"langchain-fireworks/{__version__}"

@property
def default_params(self) -> Dict[str, Any]:
return {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"max_tokens": self.max_tokens,
"repetition_penalty": self.repetition_penalty,
}

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Fireworks's text generation endpoint.
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model..
"""

headers = {
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
"Content-Type": "application/json",
}
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = {
**self.default_params,
"prompt": prompt,
"stop": stop_to_use,
**kwargs,
}

# filter None values to not pass them to the http payload
payload = {k: v for k, v in payload.items() if v is not None}
response = requests.post(url=self.base_url, json=payload, headers=headers)

if response.status_code >= 500:
raise Exception(f"Fireworks Server: Error {response.status_code}")
elif response.status_code >= 400:
raise ValueError(f"Fireworks received an invalid payload: {response.text}")
elif response.status_code != 200:
raise Exception(
f"Fireworks returned an unexpected response with status "
f"{response.status_code}: {response.text}"
)

data = response.json()
if data.get("status") != "finished":
err_msg = data.get("error", "Undefined Error")
raise Exception(err_msg)

output = self._format_output(data)

return output

async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Fireworks model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
"""
headers = {
"Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
"Content-Type": "application/json",
}
stop_to_use = stop[0] if stop and len(stop) == 1 else stop
payload: Dict[str, Any] = {
**self.default_params,
"prompt": prompt,
"stop": stop_to_use,
**kwargs,
}

# filter None values to not pass them to the http payload
payload = {k: v for k, v in payload.items() if v is not None}
async with ClientSession() as session:
async with session.post(
self.base_url, json=payload, headers=headers
) as response:
if response.status >= 500:
raise Exception(f"Fireworks Server: Error {response.status}")
                elif response.status >= 400:
                    # aiohttp's ``response.text`` is a coroutine, so it must be
                    # awaited before being interpolated into the error message.
                    response_text = await response.text()
                    raise ValueError(
                        f"Fireworks received an invalid payload: {response_text}"
                    )
                elif response.status != 200:
                    response_text = await response.text()
                    raise Exception(
                        f"Fireworks returned an unexpected response with status "
                        f"{response.status}: {response_text}"
                    )

response_json = await response.json()

if response_json.get("status") != "finished":
err_msg = response_json.get("error", "Undefined Error")
raise Exception(err_msg)

output = self._format_output(response_json)
return output
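
A sketch of how the completion wrapper above might be configured and called, assuming FIREWORKS_API_KEY is set in the environment; the model name is illustrative, and unset parameters are filtered out of the HTTP payload as in _call above:

    from langchain_fireworks import Fireworks

    llm = Fireworks(
        model="accounts/fireworks/models/mixtral-8x7b-instruct",  # illustrative model id
        temperature=0.7,
        max_tokens=256,
        # top_p, top_k and repetition_penalty are left as None, so they are
        # dropped from the request payload before it is sent.
    )

    # A single stop sequence is collapsed to a plain string before the request.
    print(llm.invoke("List three uses of embeddings:", stop=["\n\n"]))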
Empty file.
8 changes: 8 additions & 0 deletions libs/partners/fireworks/langchain_fireworks/version.py
@@ -0,0 +1,8 @@
"""Main entrypoint into package."""
from importlib import metadata

try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""
