
outlines.generate.choice generates tokens other than the provided choices - special tokens being added to tokenizer incorrectly? #893

aaronsnoswell opened this issue May 16, 2024 · 11 comments

@aaronsnoswell

Describe the issue as clearly as possible:

With some models, outlines.generate.choice generates answers that are not among the choices provided. This seems to occur only with certain models, and when it does, I also see a warning from HF transformers:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

An MWE is attached below:

Steps/code to reproduce the bug:

import torch
import outlines
from outlines import samplers

rng = torch.Generator(device="cuda")
rng.manual_seed(1337)

# Generated outputs match the provided choices
#model_path = "distilbert/distilgpt2"

# Generated outputs are not in the set of choices
# Also get a warning: "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained."
#model_path = "meta-llama/Meta-Llama-3-8B"
model_path = "EleutherAI/pythia-1b-deduped"

model = outlines.models.transformers(model_path, device="cuda")
model.model.half()

sampler = samplers.multinomial(1)
generator = outlines.generate.choice(model, ["-1", "0", "1"], sampler)

prompt = """Give me an integer ranging from -1 to 1 inclusive..."""

for i in range(10):
    answer = generator(prompt, rng=rng)
    print(answer)

Expected result:

# Something like the following
-1
0
0
-1
1
1
0
1
-1
1

Error message:

# The actual generated output varies based on the model, but e.g. with `EleutherAI/pythia-1b-deduped`, I get:
+/
.
.
+/
/
/
+/
+/
/
+/
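
For what it's worth, here is a quick check of whether the tokenizer vocabulary differs from the model's embedding-table size, which is my reading of what the transformers warning hints at (an assumption; this uses plain transformers without outlines):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "EleutherAI/pythia-1b-deduped"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# If these two numbers differ, token IDs beyond the tokenizer's vocabulary
# exist in the embedding table, i.e. a vocabulary/embedding mismatch.
print(len(tokenizer), model.config.vocab_size)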

Outlines/Python version information:

(brix) C:\Development>python -c "from outlines import _version; print(_version.version)"
0.0.41

(brix) C:\Development>python -c "import sys; print('Python', sys.version)"
Python 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 17:28:07) [MSC v.1916 64 bit (AMD64)]

(brix) C:\Development>pip freeze
accelerate==0.30.1
aiohttp==3.9.5
aiosignal==1.3.1
anaconda-anon-usage @ file:///C:/b/abs_c3w_h1zzjg/croot/anaconda-anon-usage_1710965204622/work
annotated-types==0.6.0
anyio==4.3.0
archspec @ file:///croot/archspec_1709217642129/work
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.2.0
autoflake==2.3.1
Babel==2.15.0
beautifulsoup4==4.12.3
black==24.4.2
bleach==6.1.0
blinker==1.8.2
boltons @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/boltons_1699480450092/work
Brotli @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/brotli-split_1699473013692/work
certifi @ file:///C:/b/abs_35d7n66oz9/croot/certifi_1707229248467/work/certifi
cffi @ file:///C:/b/abs_924gv1kxzj/croot/cffi_1700254355075/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click==8.1.7
cloudpickle==3.0.0
colorama @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/colorama_1699472650914/work
comm==0.2.2
conda @ file:///C:/b/abs_1e6dlkntna/croot/conda_1710772093015/work
conda-content-trust @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/conda-content-trust_1699553484152/work
conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1706733287605/work/src
conda-package-handling @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/conda-package-handling_1699480603217/work
conda_package_streaming @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/conda-package-streaming_1699475879769/work
cryptography @ file:///C:/b/abs_f5n93r0tun/croot/cryptography_1710350404202/work
datasets==2.19.1
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
distro @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/distro_1701796812765/work
dnspython==2.6.1
email_validator==2.1.1
executing==2.0.1
Faker==25.1.0
fakeredis==2.23.0
fastapi==0.111.0
fastapi-cli==0.0.3
fastjsonschema==2.19.1
filelock==3.14.0
Flask==3.0.3
Flask-Cors==4.0.1
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
h11==0.14.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.23.0
idna @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/idna_1699473483982/work
iniconfig==2.0.0
intel-openmp==2021.4.0
interegular==0.3.3
ipykernel==6.29.4
ipython==8.24.0
ipywidgets==8.1.2
isoduration==20.11.0
isort==5.13.2
itsdangerous==2.2.0
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
json5==0.9.25
jsonpatch @ file:///C:/b/abs_d3zr1enxou/croot/jsonpatch_1710807549298/work
jsonpointer==2.1
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.2.0
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
jupyterlab_widgets==3.0.10
lark==1.1.9
libmambapy @ file:///C:/b/abs_7dmjutgtwb/croot/mamba-split_1712091963973/work/libmambapy
llvmlite==0.42.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
mdurl==0.1.2
menuinst @ file:///C:/b/abs_099kybla52/croot/menuinst_1706732987063/work
mistune==3.0.2
mkl==2021.4.0
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
mypy-extensions==1.0.0
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
nltk==3.8.1
notebook_shim==0.2.4
numba==0.59.1
numpy==1.26.4
openai==1.28.1
orjson==3.10.3
outlines==0.0.41
overrides==7.7.0
packaging @ file:///C:/b/abs_cc1h2xfosn/croot/packaging_1710807447479/work
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
pathspec==0.12.1
pillow==10.2.0
platformdirs @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/platformdirs_1701797392447/work
pluggy==1.5.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
psutil==5.9.8
pure-eval==0.2.2
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycosat @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/pycosat_1699482932804/work
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pydantic==2.7.1
pydantic_core==2.18.2
pyflakes==3.2.0
Pygments==2.18.0
PySocks @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/pysocks_1699473336188/work
pytest==8.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
python-multipart==0.0.9
pytz==2024.1
pywin32==306
pywinpty==2.0.13
PyYAML==6.0.1
pyzmq==26.0.3
ranking_challenge==1.0.3
redis==5.0.4
referencing==0.35.1
regex==2024.5.10
requests @ file:///C:/b/abs_474vaa3x9e/croot/requests_1707355619957/work
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.1
ruamel.yaml @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/ruamel.yaml_1699483184324/work
safetensors==0.4.3
scikit-learn==1.4.2
scipy==1.13.0
Send2Trash==1.8.3
setuptools==68.2.2
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
sortedcontainers==2.4.0
soupsieve==2.5
stack-data==0.6.3
starlette==0.37.2
sympy==1.12
tbb==2021.11.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.3.0
tokenizers==0.19.1
torch==2.3.0+cu118
torchaudio==2.3.0+cu118
torchvision==0.18.0+cu118
tornado==6.4
tqdm @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/tqdm_1701808178601/work
traitlets==5.14.3
transformers==4.40.2
truststore @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/truststore_1701881385424/work
typer==0.12.3
types-python-dateutil==2.9.0.20240316
typing_extensions==4.11.0
tzdata==2024.1
ujson==5.9.0
uri-template==1.3.0
urllib3 @ file:///C:/b/abs_4etpfrkumr/croot/urllib3_1707770616184/work
uvicorn==0.29.0
watchfiles==0.21.0
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
websockets==12.0
Werkzeug==3.0.3
wheel==0.41.2
widgetsnbextension==4.0.10
win-inet-pton @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/win_inet_pton_1699472992992/work
xxhash==3.4.1
yarl==1.9.4
zstandard==0.19.0

Context for the issue:

No response

@aaronsnoswell
Author

At the suggestion of folks in the Discord, I tried cloning the main branch and using that instead of my pip-installed outlines.

I can confirm the bug still occurs there.

@aaronsnoswell
Author

(as in, this commit: 78852b0)

@isamu-isozaki
Contributor

I tried your code in the main branch using

pip uninstall outlines
pip install git+https://github.com/outlines-dev/outlines.git@main

and I got

-1
0
0
1
-1
-1
1
1
1
-1

@isamu-isozaki
Contributor

!python -c "from outlines import _version; print(_version.version)"
0.0.43.dev11+g78852b0
!python -c "import sys; print('Python', sys.version)"
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:40:08) [MSC v.1938 64 bit (AMD64)]
!pip freeze
accelerate==0.29.3
aiohttp==3.9.5
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.3.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
auto_gptq==0.7.1
Babel==2.14.0
beautifulsoup4==4.12.3
bleach==6.1.0
Brotli @ file:///D:/bld/brotli-split_1695989908365/work
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
cffi==1.16.0
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
comm==0.2.2
cramjam==2.8.3
dataclasses-json==0.6.4
datasets==2.19.0
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
distro==1.9.0
exceptiongroup==1.2.1
executing==2.0.1
exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu118-cp310-cp310-win_amd64.whl#sha256=5545d7ff9e31c0e7fb8667b36ac55c28c89c396438b9b7be287777ad33b9a157
fastapi==0.110.2
fastjsonschema==2.19.1
fastparquet==2024.2.0
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1712686151958/work
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
gekko==1.1.1
greenlet==3.0.3
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.0
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1713279365350/work
intel-openmp==2021.4.0
interegular==0.3.3
ipykernel==6.29.4
ipython==8.24.0
ipywidgets==8.1.2
isoduration==20.11.0
jedi==0.19.1
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1704966972576/work
joblib==1.4.2
json5==0.9.25
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.8
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
jupyterlab_widgets==3.0.10
langchain==0.1.16
langchain-community==0.0.34
langchain-core==0.1.46
langchain-text-splitters==0.0.1
langsmith==0.1.51
lark==1.1.9
llvmlite==0.42.0
MarkupSafe @ file:///D:/bld/markupsafe_1706900062361/work
marshmallow==3.21.1
matplotlib-inline==0.1.7
mistune==3.0.2
mkl==2021.4.0
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
multidict==6.0.5
multiprocess==0.70.16
mypy-extensions==1.0.0
nbclient==0.10.0
nbconvert==7.16.3
nbformat==5.10.4
nest-asyncio==1.6.0
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1712540363324/work
ninja==1.11.1.1
notebook==7.1.3
notebook_shim==0.2.4
numba==0.59.1
numpy @ file:///D:/bld/numpy_1707225570061/work/dist/numpy-1.26.4-cp310-cp310-win_amd64.whl#sha256=6761da75b1528684e6bf4dabdbdded9d1eb4d0e9b299482c7ce152cfb3155106
openai==1.23.6
orjson==3.10.1
outlines @ git+https://github.com/outlines-dev/outlines.git@78852b0169e7c4c6f3eaf6b2b2e6209e41edf98c
overrides==7.7.0
packaging==23.2
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
peft==0.10.0
pillow @ file:///D:/bld/pillow_1712154657455/work
platformdirs==4.2.1
prometheus_client==0.20.0
prompt-toolkit==3.0.43
psutil==5.9.8
pure-eval==0.2.2
pyairports==2.1.1
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycountry==23.12.11
pycparser==2.22
pydantic==2.7.1
pydantic_core==2.18.2
Pygments==2.17.2
PySocks @ file:///D:/bld/pysocks_1661604991356/work
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytz==2024.1
pywin32==306
pywinpty==2.0.13
PyYAML @ file:///D:/bld/pyyaml_1695373629531/work
pyzmq==26.0.2
qtconsole==5.5.1
QtPy==2.4.1
referencing==0.35.0
regex==2024.4.16
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rouge==1.0.1
rpds-py==0.18.0
safetensors==0.4.3
scipy==1.13.0
Send2Trash==1.8.3
sentencepiece==0.2.0
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
SQLAlchemy==2.0.29
sseclient==0.0.27
stack-data==0.6.3
starlette==0.37.2
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1684180539862/work
tbb==2021.12.0
tenacity==8.2.3
terminado==0.18.1
tiktoken==0.6.0
tinycss2==1.3.0
tokenizers==0.19.1
tomli==2.0.1
torch==2.3.0
torchaudio==2.3.0
torchvision==0.18.0
tornado==6.4
tqdm==4.66.2
traitlets==5.14.3
transformers @ git+https://github.com/huggingface/transformers@e0c3cee17085914bbe505c159beeb8ae39bc37dd
types-python-dateutil==2.9.0.20240316
typing-inspect==0.9.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1712329955671/work
tzdata==2024.1
uri-template==1.3.0
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1708239446578/work
uvicorn==0.29.0
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
websockets==12.0
widgetsnbextension==4.0.10
win-inet-pton @ file:///D:/bld/win_inet_pton_1667051142467/work
xxhash==3.4.1
yarl==1.9.4

@GurvanR

GurvanR commented May 17, 2024

Hello, I have the same issue of wrong token generation. I'm using the vLLM serve integration, which I installed with
pip install outlines[serve]

It works well with OPT models, but here is what I get with other models
(the generated tokens are shown in quotes; the expected tokens are 'A', 'B', 'C', or 'D'):

  • Aquila models: 'et' or ''
  • Baichuan: '<reserved_161>' '<reserved_258>'
  • Qwen: ImportError: This modeling file requires the following packages that were not found in your environment: tiktoken. Run pip install tiktoken
  • Mistral: ' -'
  • ChatGLM: AssertionError: is not a special token for GLMTokenizer
  • DeciLM:
    • 7B: ' -'
    • 6B: 'ie', 'ie', 'eg', 'eg', 'ber'

Note that all these models are working well with vLLM.

So my question is: how can I apply the same trick of installing main via the git+ command to the vLLM serve install?
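
Something like the following is what I have in mind (untested; I'm just assuming pip keeps the serve extras when installing from a git URL):

pip uninstall outlines
pip install "outlines[serve] @ git+https://github.com/outlines-dev/outlines.git@main"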

Thank you all!

@wjn0

wjn0 commented May 20, 2024

I'm still seeing this on 7863f8e with hf-internal-testing/tiny-random-LlamaForCausalLM.

Possibly relevant warning (but it is not resolved by manually setting pad_token_id to e.g. eos_token_id):

UserWarning: `pad_token_id` should be positive but got -1. This will cause errors when batch generating, if there is padding. Please set `pas_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values.
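
For reference, the manual setting I tried was along these lines (a sketch; model.model is the underlying transformers model exposed by the outlines wrapper, as in the scripts above):

import outlines

model = outlines.models.transformers("hf-internal-testing/tiny-random-LlamaForCausalLM")
# explicitly set pad_token_id on the wrapped model's generation config,
# as the warning suggests; this did not resolve the bad outputs for me
model.model.generation_config.pad_token_id = model.model.config.eos_token_id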

Repro script (modified from above):

import torch
import outlines
from outlines import samplers

rng = torch.Generator()
rng.manual_seed(1337)

# Generated outputs match the provided choices
#model_path = "distilbert/distilgpt2"

# Generated outputs are not in the set of choices
# Also get a warning: "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained."
#model_path = "meta-llama/Meta-Llama-3-8B"
#model_path = "EleutherAI/pythia-1b-deduped"
model_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"

model = outlines.models.transformers(model_path)
# model.model.half()

sampler = samplers.multinomial(1)
generator = outlines.generate.choice(model, [",", "\n"], sampler)

prompt = """Give me an integer ranging from -1 to 1 inclusive..."""

for i in range(10):
    answer = generator(prompt, rng=rng)
    print(answer)

I can further confirm that the above test script, unmodified, works fine with v0.0.39 (whether or not pad_token_id is manually set).
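
A temporary workaround, then, may simply be pinning the last release that behaves correctly (assuming nothing else in the environment requires a newer outlines):

pip install outlines==0.0.39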

@rlouf
Member

rlouf commented May 22, 2024

Pinging @lapp0

@lapp0
Collaborator

lapp0 commented May 22, 2024

I cannot reproduce; here's what I get:

-1
0
0
1
-1
-1
1
1
1
-1

Could you please check what output you get for just outlines.generate.text?

Code

import torch
import outlines
from outlines import samplers

rng = torch.Generator(device="cuda")
rng.manual_seed(1337)

model_path = "EleutherAI/pythia-1b-deduped"

model = outlines.models.transformers(model_path, device="cuda")
model.model.half()

sampler = samplers.multinomial(1)
generator = outlines.generate.text(model)

prompt = """Some numbers: -1, 0, 1, -1, 0, 1, -1, 0, 1,"""

answer = generator(prompt, rng=rng, max_tokens=30)
print(answer)

My outlines.generate.text() output:

 -1. But we have enough energy sets to cover the moon!" He trotted down the long hallway, his footsteps echoing.
 Est
1

@br3no
Copy link
Contributor

br3no commented May 28, 2024

I strongly believe this is an issue with the state-machine cache that was fixed with this PR: #911

@brandonwillard, what do you think?

@aaronsnoswell
Author

aaronsnoswell commented May 28, 2024 via email

Would it help for me to pull PR #911 and test at my end?

@brandonwillard
Contributor

> Would it help for me to pull PR #911 and test at my end?

It's already merged into main, so you can check that out and try it.
