Add jamba #29943
Conversation
…lerance. left padding numerical difference are accentuated by mamba layers
Reviewing!
Great work! 🔥 it's already super `transformers`-like!
- `tokenization_auto` needs to be updated to include which tokenizer `jamba` uses!
- a few code paths that would be nice to remove, BUT that would mean having to convert the checkpoints, while your naming choices are alright. That would be a bit annoying.
- great PR! 🤗
README.md
Outdated
@@ -397,6 +397,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (from Salesforce) released with the paper [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from <FILL INSTITUTION>) released with the paper [<FILL PAPER TITLE>](<FILL ARKIV LINK>) by <FILL AUTHORS>.
To fill!
docs/source/en/model_doc/jamba.md
Outdated
@@ -0,0 +1,129 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Suggested change:
- <!--Copyright 2022 The HuggingFace Team. All rights reserved.
+ <!--Copyright 2024 The HuggingFace Team. All rights reserved.
Jamba is a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and a total of 52B parameters across all experts. It supports a 256K context length, and can fit up to 140K tokens on a single 80GB GPU.

As depicted in the diagram below, Jamba's architecture features a blocks-and-layers approach that allows Jamba to integrate the Transformer and Mamba architectures. Each Jamba block contains either an attention or a Mamba layer, followed by a multi-layer perceptron (MLP), producing an overall ratio of one Transformer layer out of every eight total layers.

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/jamba_architecture.png"
alt="drawing" width="600"/>
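For intuition, here is a small sketch of the interleaving pattern described above; the period/offset values and variable names are illustrative assumptions, not the PR's actual configuration defaults.

```python
# Sketch of the attention/Mamba interleaving described above.
# attn_layer_period / attn_layer_offset values are assumptions for illustration.
num_hidden_layers = 32
attn_layer_period, attn_layer_offset = 8, 4

layers_block_type = [
    "attention" if i % attn_layer_period == attn_layer_offset else "mamba"
    for i in range(num_hidden_layers)
]
print(layers_block_type.count("attention"), "attention layers out of", num_hidden_layers)
# -> 4 attention layers out of 32, i.e. one Transformer layer per eight layers
```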
very nice 🔥
🙂
docs/source/en/model_doc/jamba.md
Outdated
You can run the model without the optimized Mamba kernels, but it is **not** recommended as it will result in significantly higher latencies. In order to do that, you'll need to specify `use_mamba_kernels=False` when loading the model.

### Run the model
Please note that, at the moment, `trust_remote_code=True` is required for running the new Jamba architecture.
Suggested change (remove this line):
- Please note that, at the moment, `trust_remote_code=True` is required for running the new Jamba architecture.
docs/source/en/model_doc/jamba.md
Outdated
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                              trust_remote_code=True)
Suggested change:
- model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
-                                               trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
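For context, a minimal loading sketch that combines the suggestion above with the `use_mamba_kernels` flag from the docs; the prompt and generation settings are arbitrary, and this assumes the checkpoint fits on the available hardware.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Minimal sketch: load Jamba and optionally disable the optimized Mamba kernels
# (not recommended, as noted above, since it is significantly slower).
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    use_mamba_kernels=False,  # drop this line to use the fast kernels
)

inputs = tokenizer("Jamba is a hybrid Transformer-Mamba model that", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```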
if self._attn_implementation == "flash_attention_2":
    # 2d mask is passed through the layers
    attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
    # output_attentions=True can not be supported when using SDPA, and we fall back on
    # the manual implementation that requires a 4D causal mask in all cases.
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )
else:
    # 4d mask is passed through the layers
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
        sliding_window=self.config.sliding_window,
    )
same comment about 4d mask! but it can be updated in another PR!
Here as well I feel I'm missing something.. what comment are we talking about?
Oh, probably about the `cache_positions` 🙂
It's about using `_update_causal_mask` from the gemma modelling code that simplifies the whole logic a lot!
Yeah, pushing for this, the `_prepare_4d_causal_attention_mask` and co. are just too scattered, and will be deprecated!
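For readers following the thread, a rough self-contained sketch of what these helpers produce (a 4D additive causal mask combined with a 2D padding mask); this is illustrative only and not the actual `_prepare_4d_causal_attention_mask` or `_update_causal_mask` implementation.

```python
import torch

def toy_causal_mask(padding_mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Toy: build a (batch, 1, seq, seq) additive mask from a (batch, seq) padding mask."""
    bsz, seq_len = padding_mask_2d.shape
    min_value = torch.finfo(dtype).min
    # causal part: position i may only attend to positions <= i
    causal = torch.full((seq_len, seq_len), min_value, dtype=dtype).triu(diagonal=1)
    mask = causal[None, None, :, :].expand(bsz, 1, seq_len, seq_len).clone()
    # padding part: padded (0) key positions are never attended to
    return mask.masked_fill(padding_mask_2d[:, None, None, :] == 0, min_value)

mask = toy_causal_mask(torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]]))
print(mask.shape)  # torch.Size([2, 1, 4, 4])
```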
(No need for the cache position; you can pass the `past_length`)
if calc_logits_for_entire_prompt:
    logits = self.lm_head(hidden_states)
else:
    logits = self.lm_head(hidden_states[..., -1:, :])
Mmmm could you explain the motivations behind this?
Sure, it's pretty much explained in the docstring for `calc_logits_for_entire_prompt` in `configuration_jamba.py`. For long sequences, the logits can take a lot of GPU memory, especially as they are saved in FP32. So for a prompt of 128K tokens, with our vocab size of 64K, the logits for the prompt alone take 32GB of GPU memory (128K × 64K × 4 bytes). The thing is that in order to generate from the model, we don't need all the prompt logits, just those of the last token. Anyway, the GenerationMixin takes only the logits of the last prompt token (`next_token_logits = outputs.logits[:, -1, :]` appears many times in `src/transformers/generation/utils.py`). So we want to save all this unnecessary memory and compute only the logits we need.
Honestly, we were a bit surprised to see that the standard in transformers is to calculate the logits for the entire prompt when generating. I understand that for relatively short prompts this doesn't add up to a lot of extra memory, but for long prompts it's a complete waste.
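A quick back-of-the-envelope check of the figures quoted above:

```python
# FP32 logits for the full prompt, using the numbers quoted above.
prompt_tokens = 128 * 1024
vocab_size = 64 * 1024
bytes_per_fp32 = 4

logits_bytes = prompt_tokens * vocab_size * bytes_per_fp32
print(f"{logits_bytes / 1024**3:.0f} GiB")  # 32 GiB for the prompt logits alone
```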
That's where I say, when we generate with transformers, we only pass a single input ids starting from the second forward pass. Which is why we never need this, the hidden states generated after the first forward pass are always of shape 1 in sequence length!
This can be safely removed
Ah sorry you mean the first forward as well. Not a fan of having code paths for this + as @gante said, assisted generation will fail.
Yeah, I was talking about the first forward pass.
RE assisted generation - you're right, and that's why we kept this as a config option. If a user wants to use assisted generation, they can set `calc_logits_for_entire_prompt` to True in the config and everything will work.
As you saw, part of the Jamba promise is to be able to fit long sequences (~140K) on a single 80GB GPU (with int8 weight quantization). If we calculate the logits for the entire prompt, that's not possible. That's the reason we feel that by default the entire prompt logits shouldn't be calculated during generation. If the user wants/needs that, they can do it by setting the appropriate attribute in the config.
Alright. Let's leave it for now; this will be streamlined to `generate`, as this can be a generation config argument set if you use `use_cache` and no `assisted_decoding`.
Will be deprecated in the near future, as this is mostly for inference.
Suggestion: if we make the flag an integer, e.g. `num_logits_to_keep: Optional[int]`, then it can easily become compatible with assisted generation:
if num_logits_to_keep is None:
    logits = self.lm_head(hidden_states)
else:
    logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
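A tiny self-contained illustration of that slicing (shapes are made up for the example; this is not the PR's code):

```python
import torch

batch, seq_len, hidden, vocab = 2, 16, 8, 32
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

num_logits_to_keep = 3  # e.g. the candidate length in assisted generation
logits = lm_head(hidden_states[..., -num_logits_to_keep:, :])
print(logits.shape)  # torch.Size([2, 3, 32]) -- only the last 3 positions
```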
@gante - just to make sure I understand: if we do that, we'll still need to modify the assisted generation code to set `num_logits_to_keep` to `candidate_length`, correct?
@@ -0,0 +1,830 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
Suggested change:
- # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
@require_torch
@unittest.skip("Update once we have a tiny Jamba model")
great TODO! Tiny logits would be awesome!
- expected loss to make sure we compute the same!
🙂
Only thing left here!
…rnorms and do it directly in the forward pass
…ions only if not None.
…ers return None as router_logits, and it is not concatenated to all_router_logits returned from JambaModel
…result.router_logits now holds results only for expert layers
Looks a lot cleaner
n_ctx (`int`, *optional*, defaults to 262144):
    This value doesn't have any real effect. The maximum sequence length that this model is intended to be
    used with. It can be used with longer sequences, but performance may degrade.
Ah, that's a very good point. Let's use `max_position_embeddings` and also include in the comment that it's used for evaluating! (it's not doing nothing!)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
Suggested change (remove):
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-     return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
if self.attention_layer_idx is not None and layer_idx == self.attention_layer_idx:
    self._seen_tokens += key_states.shape[-2]
We no longer use the `self._seen_tokens` arg, and rely on `cache_positions`; should simplify things.
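For context, a toy illustration of what cache positions are (absolute indices of the tokens in the current forward pass); purely illustrative, not this PR's code:

```python
import torch

past_seen_tokens = 5  # tokens already stored in the cache
new_tokens = 3        # tokens in the current forward pass
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
print(cache_position)  # tensor([5, 6, 7])
```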
"use_cache": kwargs.get("use_cache"), | ||
"attention_mask": attention_mask, | ||
"output_router_logits": output_router_logits, | ||
"calc_logits_for_entire_prompt": self.config.calc_logits_for_entire_prompt, |
If `use_cache`, this could always be set to `False` if there is no `self.generation_config.assistant`; could be handled here.
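A possible sketch of that idea as a standalone helper; the `generation_config.assistant` attribute name is taken from the comment above and may not be the final hook, so treat this as illustrative only.

```python
# Hypothetical helper mirroring the comment above: skip full-prompt logits when
# caching and no assistant model is involved (assisted generation still needs them).
def should_calc_logits_for_entire_prompt(use_cache: bool, has_assistant: bool, config_default: bool) -> bool:
    if use_cache and not has_assistant:
        return False
    return config_default

print(should_calc_logits_for_entire_prompt(use_cache=True, has_assistant=False, config_default=True))  # False
```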
…he model is too big to download (in docstring of JambaForCausalLM.forward)
past_seen_tokens = (
-   past_key_values.get_seq_length()
+   past_key_values.get_seq_length(self.config.layers_block_type.index("attention"))
    if isinstance(past_key_values, HybridMambaAttentionDynamicCache)
    else 0
Actually, we should assume cache positions are passed for this mode.
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
    """Returns the sequence length of the cached states. A layer index can be optionally passed."""
    if len(self.key_cache) <= layer_idx:
        return 0
    if self.layers_block_type[layer_idx] == "mamba":
        raise ValueError(
            "Can't return seq_length from Mamba layers cache as it doesn't have a sequence length dimension."
        )
    return self.key_cache[layer_idx].shape[-2]
would rather not have this, cache positions SHOULD be passed. That is here in llama for legacy
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
    raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
    raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
same for both
All the rest you added LGTM.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…_state (and make bool) and bugfix - it should be set to True after a finished forward pass of the entire model
…h size
2 small comments and should be good to go!
…. Adjust test (test_decoder_model_past_with_large_inputs) accordingly
🚀 Great work everyone!
* Add jamba arch
* apply "make fix-copies" changes
* fix link to model in JambaConfig docstring
* Add n_ctx in modeling file because repo-consistency wants that
* Add jamba to flash attention and sdpa documentation
* mamba dt_proj quant fix now works for LoRA as well
* override test_left_padding_compatibility and use a more permissive tolerance. left padding numerical difference are accentuated by mamba layers
* add jamba to tokenization auto
* fix comments of shape (PR #24 in the model page: https://huggingface.co/ai21labs/Jamba-v0.1/discussions/24)
* simple PR fixes
* remove unnecessary kwargs from JambaAttentionDecoderLayer and JambaMambaDecoderLayer
* remove the LoRA hack for the mamba dt_proj bias. It was solved in huggingface/peft#1530 (huggingface/peft#1530)
* Add copied comment on JambaMLP (it's the same as MixtralMLP)
* remove padding_mask warnings. It's not supported anymore
* fix docstring. Float instead of int
* A few more minor PR fixes
* (1) lowercase names for mamba layernorms (2) remove _apply_inner_layernorms and do it directly in the forward pass
* Return None attention weights from mamba layers. Append to all attentions only if not None.
* remove some leftover jamba archive lists
* Better separation between expert vs non-expert layers. non-expert layers return None as router_logits, and it is not concatenated to all_router_logits returned from JambaModel
* no need to take router_logits at config.expert_layer_offset anymore. result.router_logits now holds results only for expert layers
* Add Jamba paper on READMEs
* (1) rename n_ctx -> max_position_embeddings (2) don't use it in the modeling file since it's not needed (set it as an exception to check_config_attributes)
* Add copied from comment
* remove the code path for apply_inner_layernorms=False. Jamba always has the inner mamba layernorms
* clearer docstring for _convert_to_standard_cache
* style fixes
* Change calc_logits_for_entire_prompt (bool) to num_logits_to_keep (int). Adapt assisted decoding code to use it. Also small change in low memory beam search decoding path to support this new int value in model_inputs
* rename test so it still overrides what its meant to override
* draft
* oups
* nit
* remove more complexe logic
* fix names used in config
* fix fix fix
* style
* fix some more failing tests
* generate did not init the cache 🙃
* more small nits
* typo
* config.mamba_expand * config.hidden_size for the intermediate size of the mamba shapes
* fix init of pkv with torch.tensor()
* empty tensor
* fix some init issues
* stupid changes required by generate because it does not even support it's own DynamicCache class
* more fixes
* fix general assisted gen cache_position bug
* tests passing
* Add offsets and periods as SPECIAL_CASES_TO_ALLOW in check_config_attributes.py
* fix reorder_cache to reorder mamba states and override some more functions in HybridMambaAttentionDynamicCache
* no need to override test_past_key_values_format() and _check_past_key_values_for_generate() in tests anymore
* fix docstrings and typehints for past_key_values
* style fixes
* fix docs
* change typehint due to copy from Mixtral
* forgot import
* import order
* Add configuration_jamba and modeling_jamba to not_doctested because the model is too big to download (in docstring of JambaForCausalLM.forward)
* Add integration test with tiny tandom Jamba model on hub
* fix flash attention cache shapes
* bring back forgotten hidden states
* rename HybridMambaAttentionDynamicCache.seqlen_offset to has_previous_state (and make bool) and bugfix - it should be set to True after a finished forward pass of the entire model
* align integration test after modeling fixes
* bugfix - mamba can use precomputed states only of forward pass is on a single token
* bugfix - mamba can use precomputed states only if they match the batch size
* typo
* remove making _prepare_4d_causal_attention_mask a leaf function
* stop using past_seq_len.get_seq_length(). Use cache positions instead. Adjust test (test_decoder_model_past_with_large_inputs) accordingly

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
What does this PR do?
Add support for the Jamba architecture by AI21 Labs
Who can review?
@ArthurZucker @younesbelkada