
Add jamba #29943

Merged: 78 commits merged into huggingface:main on Apr 18, 2024


tomeras91
Contributor

What does this PR do?

Add support for the Jamba architecture by AI21 Labs

Who can review?

@ArthurZucker @younesbelkada

@ArthurZucker
Collaborator

Reviewing !

Collaborator

ArthurZucker left a comment


Great work! 🔥 It's already very transformers-like!

  • tokenization_auto needs to be updated to include which tokenizer jamba uses!
  • a few code paths would be nice to remove, BUT that would mean converting the checkpoints, which would be a bit annoying. Your naming choices are alright.
  • great PR! 🤗

README.md Outdated
@@ -397,6 +397,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[ImageGPT](https://huggingface.co/docs/transformers/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
1. **[Informer](https://huggingface.co/docs/transformers/model_doc/informer)** (from Beihang University, UC Berkeley, Rutgers University, SEDD Company) released with the paper [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) by Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
1. **[InstructBLIP](https://huggingface.co/docs/transformers/model_doc/instructblip)** (from Salesforce) released with the paper [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
1. **[Jamba](https://huggingface.co/docs/transformers/main/model_doc/jamba)** (from <FILL INSTITUTION>) released with the paper [<FILL PAPER TITLE>](<FILL ARKIV LINK>) by <FILL AUTHORS>.
Collaborator


To fill!

@@ -0,0 +1,129 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Collaborator


Suggested change
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Comment on lines +27 to +32
Jamba is a pretrained, mixture-of-experts (MoE) generative text model, with 12B active parameters and a total of 52B parameters across all experts. It supports a 256K context length and can fit up to 140K tokens on a single 80GB GPU.

As depicted in the diagram below, Jamba's architecture features a blocks-and-layers approach that allows Jamba to successfully integrate the Transformer and Mamba architectures. Each Jamba block contains either an attention or a Mamba layer, followed by a multi-layer perceptron (MLP), producing an overall ratio of one Transformer layer out of every eight total layers.

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/jamba_architecture.png"
alt="drawing" width="600"/>
Collaborator


very nice 🔥

Contributor Author


🙂
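The one-in-eight interleaving of attention and Mamba layers described above can be written down as a per-layer schedule. This is a hypothetical helper, not Jamba's configuration code: the names `attn_layer_period` and `attn_layer_offset`, the default offset of 4 and the 32-layer depth are assumptions chosen for illustration.

```python
from typing import List


def layer_schedule(num_layers: int, attn_layer_period: int = 8, attn_layer_offset: int = 4) -> List[str]:
    """Return the mixer type ("attention" or "mamba") used by each layer index.

    One layer in every `attn_layer_period` layers is attention; all others are Mamba.
    Parameter names and default values are illustrative assumptions.
    """
    return [
        "attention" if i % attn_layer_period == attn_layer_offset else "mamba"
        for i in range(num_layers)
    ]


if __name__ == "__main__":
    schedule = layer_schedule(num_layers=32)
    # In this sketch every layer, whatever its mixer, is followed by an MLP.
    for idx, mixer in enumerate(schedule):
        print(f"layer {idx:2d}: {mixer:<9} -> MLP")
    print("attention layers:", [i for i, t in enumerate(schedule) if t == "attention"])
```

With these values, layers 4, 12, 20 and 28 are attention layers, which gives the one-in-eight ratio; where the mixture-of-experts MLPs go is left out of the sketch.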

You can run the model without the optimized Mamba kernels, but this is **not** recommended, as it results in significantly higher latency. To do so, specify `use_mamba_kernels=False` when loading the model.

### Run the model
Please note that, at the moment, `trust_remote_code=True` is required for running the new Jamba architecture.
Collaborator


Suggested change
Please note that, at the moment, `trust_remote_code=True` is required for running the new Jamba architecture.

Comment on lines 56 to 57
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True)
Collaborator


Suggested change
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")

Comment on lines 1693 to 1713
if self._attn_implementation == "flash_attention_2":
    # 2d mask is passed through the layers
    attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions:
    # output_attentions=True can not be supported when using SDPA, and we fall back on
    # the manual implementation that requires a 4D causal mask in all cases.
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )
else:
    # 4d mask is passed through the layers
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
        sliding_window=self.config.sliding_window,
    )
Collaborator


same comment about 4d mask! but it can be updated in another PR!

Contributor Author


Here as well I feel I'm missing something... which comment are we talking about?

Contributor Author


Oh probably about the cache_positions 🙂

Collaborator


It's about using _update_causal_mask from the Gemma modeling code, which simplifies the whole logic a lot!

Collaborator


Yeah, pushing for this: _prepare_4d_causal_attention_mask and the related helpers are just too scattered, and will be deprecated!

Collaborator


(No need for the cache position; you can pass the past_length.)
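The idea behind a single `_update_causal_mask`-style helper is to derive the 4D mask from the 2D padding mask and the past length in one place, instead of branching across several `_prepare_4d_*` functions. A minimal NumPy sketch under stated assumptions: additive masking (0 = attend, dtype minimum = blocked), no sliding window, and a hypothetical function name. It is not the transformers implementation.

```python
import numpy as np


def build_4d_causal_mask(attention_mask_2d: np.ndarray, past_length: int,
                         dtype=np.float32) -> np.ndarray:
    """Combine a 2D padding mask with causality into an additive 4D mask.

    attention_mask_2d: (batch, past_length + q_len), 1 = real token, 0 = padding.
    Returns (batch, 1, q_len, kv_len): 0 where attention is allowed, the dtype's
    minimum value where it is blocked.
    """
    _, kv_len = attention_mask_2d.shape
    q_len = kv_len - past_length
    if q_len <= 0:
        raise ValueError("the 2D mask must cover the new tokens as well as the past")
    # Absolute positions of the new (query) tokens, i.e. what cache positions would hold.
    query_pos = np.arange(past_length, kv_len)[:, None]            # (q_len, 1)
    key_pos = np.arange(kv_len)[None, :]                            # (1, kv_len)
    causal_ok = key_pos <= query_pos                                # (q_len, kv_len)
    padding_ok = attention_mask_2d.astype(bool)[:, None, None, :]   # (batch, 1, 1, kv_len)
    allowed = causal_ok[None, None, :, :] & padding_ok              # (batch, 1, q_len, kv_len)
    blocked_value = np.finfo(dtype).min
    return np.where(allowed, 0.0, blocked_value).astype(dtype)


if __name__ == "__main__":
    # One left-padded sequence with no past: the padded key (index 0) is blocked for every query.
    mask = build_4d_causal_mask(np.array([[0, 1, 1, 1]]), past_length=0)
    print((mask[0, 0] == 0).astype(int))
```

Rows that belong to padding queries end up fully blocked here; real attention backends may need special handling for such rows, which this sketch omits.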

Comment on lines 1876 to 1879
if calc_logits_for_entire_prompt:
    logits = self.lm_head(hidden_states)
else:
    logits = self.lm_head(hidden_states[..., -1:, :])
Collaborator


Mmmm could you explain the motivations behind this?

Contributor Author


Sure.
It's mostly explained in the docstring for calc_logits_for_entire_prompt in configuration_jamba.py. For long sequences, the logits can take a lot of GPU memory, especially since they are stored in FP32. For a prompt of 128K tokens with our vocabulary size of 64K, the prompt logits alone take 32GB of GPU memory (128K × 64K × 4 bytes). But to generate from the model we don't need all the prompt logits, only those of the last token. In any case, GenerationMixin uses only the logits of the last prompt token (next_token_logits = outputs.logits[:, -1, :] appears many times in src/transformers/generation/utils.py). So we want to avoid all this unnecessary memory and compute only the logits we need.
Honestly, we were a bit surprised to see that the standard in transformers is to calculate the logits for the entire prompt when generating. I understand that for relatively short prompts this doesn't add up to much extra memory, but for long prompts it's a complete waste.
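The 32GB figure follows directly from the shape of the logits tensor. A small sketch of the arithmetic, assuming "128K" and "64K" mean 128 × 1024 and 64 × 1024, FP32 logits (4 bytes each) and batch size 1:

```python
def logits_memory_bytes(seq_len: int, vocab_size: int, bytes_per_value: int = 4, batch_size: int = 1) -> int:
    """Memory needed to materialize a (batch, seq_len, vocab) logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_value


GiB = 2**30
prompt_len = 128 * 1024  # "128K" prompt tokens (binary K assumed)
vocab = 64 * 1024        # "64K" vocabulary (binary K assumed)

full_prompt = logits_memory_bytes(prompt_len, vocab)  # logits for every prompt position
last_token = logits_memory_bytes(1, vocab)            # logits for the last position only
print(f"all prompt positions: {full_prompt / GiB:.1f} GiB")  # 32.0 GiB
print(f"last position only:   {last_token / 1024:.0f} KiB")  # 256 KiB
```

The cost grows linearly with prompt length, which is why it only becomes noticeable for very long prompts.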

Collaborator

ArthurZucker, Apr 1, 2024


That's my point: when we generate with transformers, we only pass a single input id starting from the second forward pass. That's why we never need this: after the first forward pass, the hidden states always have a sequence length of 1!
This can be safely removed

Collaborator

ArthurZucker, Apr 1, 2024


Ah sorry, you mean the first forward pass as well. I'm not a fan of having code paths for this, and as @gante said, assisted generation will fail.

Contributor Author


Yeah I was talking about the first forward pass.
RE assisted generation - you're right and that's why we kept this as a config option. If a user wants to use assisted generation, they can set calc_logits_for_entire_prompt to True in the config and everything will work.
As you saw, part of the Jamba promise is to be able to fit long sequences (~140K) on a single 80GB GPU (with int8 weight quantization). If we calculate the logits for the entire prompt, that's not possible. That's why we feel that, by default, entire-prompt logits shouldn't be calculated during generation. If users want or need them, they can enable this by setting the appropriate attribute in the config.

Collaborator


Alright. Let's leave it for now. This will be streamlined into generate, since it can be a generation config argument set when you use use_cache without assisted decoding.
It will be deprecated in the near future, as this is mostly for inference.

Member


Suggestion: if we make the flag an integer, e.g. num_logits_to_keep: Optional[int], then it can easily become compatible with assisted generation

if num_logits_to_keep is None:
   logits = self.lm_head(hidden_states)
else:
   logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
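The integer variant suggested above can be sketched with NumPy standing in for `lm_head` (a plain matrix multiply). The function name and the `(vocab, hidden)` weight layout are assumptions for illustration, not the merged modeling code.

```python
from typing import Optional

import numpy as np


def compute_logits(hidden_states: np.ndarray, lm_head_weight: np.ndarray,
                   num_logits_to_keep: Optional[int] = None) -> np.ndarray:
    """Project hidden states to vocabulary logits, optionally for trailing positions only.

    hidden_states: (batch, seq_len, hidden); lm_head_weight: (vocab, hidden).
    Returns (batch, kept_positions, vocab).
    """
    if num_logits_to_keep is None:
        selected = hidden_states  # logits for every position
    else:
        selected = hidden_states[..., -num_logits_to_keep:, :]  # trailing positions only
    return selected @ lm_head_weight.T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    hidden = rng.standard_normal((2, 10, 8))  # batch 2, 10 positions, hidden size 8
    weight = rng.standard_normal((32, 8))     # toy vocabulary of 32
    full = compute_logits(hidden, weight)
    last = compute_logits(hidden, weight, num_logits_to_keep=1)
    print(full.shape, last.shape)
    print(np.allclose(full[:, -1:], last))
```

One detail to watch: in Python `x[..., -0:, :]` selects every position, so with this slicing a value of 0 behaves like `None`. Keeping `None` as the sentinel (or checking for 0 explicitly) avoids surprises.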

Contributor Author


@gante - just to make sure I understand: If we do that, we'll still need to modify the assisted generation code to set num_logits_to_keep to candidate_length correct?
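The count can be reasoned out directly: if the target model scores k draft tokens in one forward pass, the logits at the last prompt position predict the first candidate, the logits at each candidate predict the next one, and the logits at the last candidate give one extra token. That is k + 1 positions, one more than the candidate length. A pure-Python sketch of greedy verification under that reasoning; the function is hypothetical and is not transformers' assisted-decoding code.

```python
from typing import List, Tuple


def verify_candidates(candidate_tokens: List[int], target_predictions: List[int]) -> Tuple[int, int]:
    """Greedy check of draft tokens against the target model.

    candidate_tokens: the k tokens proposed by the assistant model.
    target_predictions: the target model's argmax token at each of the last k + 1
        input positions (last prompt token, then the k candidates).
    Returns (number of accepted candidates, next token to append).
    """
    k = len(candidate_tokens)
    if len(target_predictions) != k + 1:
        raise ValueError("need predictions for exactly k + 1 positions")
    accepted = 0
    # Prediction i is the target's choice for the token that candidate i proposes.
    while accepted < k and target_predictions[accepted] == candidate_tokens[accepted]:
        accepted += 1
    # Either a correction at the first mismatch or, if all k match, the bonus token
    # that comes from the (k + 1)-th position.
    return accepted, target_predictions[accepted]


if __name__ == "__main__":
    print(verify_candidates([5, 7, 9], [5, 7, 2, 4]))  # (2, 2): third draft token rejected
    print(verify_candidates([5, 7, 9], [5, 7, 9, 4]))  # (3, 4): all accepted, plus a bonus token
```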

@@ -0,0 +1,830 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
Collaborator


Suggested change
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.



@require_torch
@unittest.skip("Update once we have a tiny Jamba model")
Collaborator


great TODO! Tiny logits would be awesome, plus the expected loss to make sure we compute the same!

Contributor Author


🙂

Collaborator


Only thing left here!

Collaborator

ArthurZucker left a comment


Looks a lot cleaner

Comment on lines 91 to 93
n_ctx (`int`, *optional*, defaults to 262144):
This value doesn't have any real effect. The maximum sequence length that this model is intended to be
used with. It can be used with longer sequences, but performance may degrade.
Collaborator


Ah, that's a very good point. Let's use max_position_embeddings and also mention in the docstring that it's used for evaluation! (It's not doing nothing!)

Comment on lines 259 to 260
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
Collaborator


Suggested change
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

Comment on lines 752 to 753
if self.attention_layer_idx is not None and layer_idx == self.attention_layer_idx:
    self._seen_tokens += key_states.shape[-2]
Collaborator


we no longer use the self._seen_tokens attribute; we rely on cache_positions instead, which should simplify things
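Cache positions are the absolute indices of the tokens processed in the current forward pass; in Llama-style models they are built with a pattern like `torch.arange(past_seen_tokens, past_seen_tokens + seq_len)`. A dependency-free sketch, with hypothetical helper names, of why a separate running counter becomes redundant once positions are passed around:

```python
from typing import List


def cache_positions(past_length: int, num_new_tokens: int) -> List[int]:
    """Absolute positions of the tokens processed in the current forward pass."""
    return list(range(past_length, past_length + num_new_tokens))


def simulate_generation(prompt_len: int, decode_steps: int) -> List[List[int]]:
    """Cache positions for one prefill pass followed by single-token decoding passes."""
    if prompt_len < 1:
        raise ValueError("prompt must contain at least one token")
    all_positions = []
    past_length, num_new = 0, prompt_len
    for _ in range(decode_steps + 1):
        positions = cache_positions(past_length, num_new)
        all_positions.append(positions)
        # The cache length after this pass follows from the positions alone,
        # so no separate counter such as _seen_tokens has to be maintained.
        past_length, num_new = positions[-1] + 1, 1
    return all_positions


if __name__ == "__main__":
    print(simulate_generation(prompt_len=4, decode_steps=2))  # [[0, 1, 2, 3], [4], [5]]
```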

"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
"calc_logits_for_entire_prompt": self.config.calc_logits_for_entire_prompt,
Collaborator


If use_cache is set, this could always be set to False when there is no self.generation_config.assistant. That could be handled here.

tests/models/jamba/test_modeling_jamba.py (resolved)
Comment on lines 1420 to 1423
past_seen_tokens = (
past_key_values.get_seq_length()
past_key_values.get_seq_length(self.config.layers_block_type.index("attention"))
if isinstance(past_key_values, HybridMambaAttentionDynamicCache)
else 0
Collaborator


actually, we should assume cache positions are passed for this mode

Comment on lines 722 to 728
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
    """Returns the sequence length of the cached states. A layer index can be optionally passed."""
    if len(self.key_cache) <= layer_idx:
        return 0
    if self.layers_block_type[layer_idx] == "mamba":
        raise ValueError("Can't return seq_length from Mamba layers cache as it doesn't have a sequence length dimension.")
    return self.key_cache[layer_idx].shape[-2]
Collaborator


I would rather not have this; cache positions SHOULD be passed. This exists in Llama only for legacy reasons.
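The lookup being debated here (find an attention layer, because Mamba layers have no sequence axis to measure) can be illustrated with a toy cache. The class below is hypothetical and much simpler than HybridMambaAttentionDynamicCache; its shapes and argument names are assumptions. The reviewer's preferred alternative is to pass cache positions, so the model never has to ask the cache for its length.

```python
from typing import List, Optional

import numpy as np


class ToyHybridCache:
    """Hypothetical, simplified cache for a model mixing attention and Mamba layers.

    Attention layers store keys of shape (batch, heads, seq_len, head_dim) that grow
    along the sequence axis; Mamba layers store a fixed-size state with no such axis.
    """

    def __init__(self, layers_block_type: List[str], batch: int,
                 heads: int = 2, head_dim: int = 4, state_dim: int = 8):
        self.layers_block_type = layers_block_type
        self.key_cache = [
            np.zeros((batch, heads, 0, head_dim)) if t == "attention" else None
            for t in layers_block_type
        ]
        self.ssm_states = [
            np.zeros((batch, state_dim)) if t == "mamba" else None
            for t in layers_block_type
        ]

    def append_keys(self, layer_idx: int, keys: np.ndarray) -> None:
        """Concatenate new keys along the sequence axis of an attention layer."""
        self.key_cache[layer_idx] = np.concatenate([self.key_cache[layer_idx], keys], axis=-2)

    def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
        # Default to the first attention layer: Mamba states carry no sequence length.
        if layer_idx is None:
            layer_idx = self.layers_block_type.index("attention")
        if self.layers_block_type[layer_idx] == "mamba":
            raise ValueError("Mamba layers have no sequence-length dimension")
        return self.key_cache[layer_idx].shape[-2]


if __name__ == "__main__":
    cache = ToyHybridCache(["mamba", "attention", "mamba"], batch=1)
    cache.append_keys(1, np.zeros((1, 2, 5, 4)))  # five tokens cached in the attention layer
    print(cache.get_seq_length())  # 5
```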

Comment on lines 743 to 748
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
    raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
    raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
Collaborator


same for both

@ArthurZucker
Collaborator

All the rest you added LGTM

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

tomeras91 marked this pull request as draft on April 17, 2024 13:22
tomeras91 marked this pull request as ready for review on April 17, 2024 15:25
Collaborator

ArthurZucker left a comment


2 small comments and should be good to go!

src/transformers/models/jamba/modeling_jamba.py (outdated, resolved)
src/transformers/models/jamba/modeling_jamba.py (outdated, resolved)
src/transformers/models/jamba/modeling_jamba.py (outdated, resolved)
Collaborator

ArthurZucker left a comment


🚀 Great work everyone!

ArthurZucker merged commit 3f20877 into huggingface:main on Apr 18, 2024
23 checks passed
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
* Add jamba arch

* apply "make fix-copies" changes

* fix link to model in JambaConfig docstring

* Add n_ctx in modeling file because repo-consistency wants that

* Add jamba to flash attention and sdpa documentation

* mamba dt_proj quant fix now works for LoRA as well

* override test_left_padding_compatibility and use a more permissive tolerance. left padding numerical differences are accentuated by mamba layers

* add jamba to tokenization auto

* fix comments of shape (PR #24 in the model page: https://huggingface.co/ai21labs/Jamba-v0.1/discussions/24)

* simple PR fixes

* remove unnecessary kwargs from JambaAttentionDecoderLayer and JambaMambaDecoderLayer

* remove the LoRA hack for the mamba dt_proj bias. It was solved in huggingface/peft#1530 (huggingface/peft#1530)

* Add copied comment on JambaMLP (it's the same as MixtralMLP)

* remove padding_mask warnings. It's not supported anymore

* fix docstring. Float instead of int

* A few more minor PR fixes

* (1) lowercase names for mamba layernorms (2) remove _apply_inner_layernorms and do it directly in the forward pass

* Return None attention weights from mamba layers. Append to all attentions only if not None.

* remove some leftover jamba archive lists

* Better separation between expert vs non-expert layers. non-expert layers return None as router_logits, and it is not concatenated to all_router_logits returned from JambaModel

* no need to take router_logits at config.expert_layer_offset anymore. result.router_logits now holds results only for expert layers

* Add Jamba paper on READMEs

* (1) rename n_ctx -> max_position_embeddings (2) don't use it in the modeling file since it's not needed (set it as an exception to check_config_attributes)

* Add copied from comment

* remove the code path for apply_inner_layernorms=False. Jamba always has the inner mamba layernorms

* clearer docstring for _convert_to_standard_cache

* style fixes

* Change calc_logits_for_entire_prompt (bool) to num_logits_to_keep (int). Adapt assisted decoding code to use it. Also a small change in the low-memory beam search decoding path to support this new int value in model_inputs

* rename test so it still overrides what it's meant to override

* draft

* oups

* nit

* remove more complex logic

* fix names used in config

* fix fix fix

* style

* fix some more failing tests

* generate did not init the cache 🙃

* more small nits

* typo

* config.mamba_expand * config.hidden_size for the intermediate size of the mamba shapes

* fix init of pkv with torch.tensor()

* empty tensor

* fix some init issues

* stupid changes required by generate because it does not even support its own DynamicCache class

* more fixes

* fix general assisted gen cache_position bug

* tests passing

* Add offsets and periods as SPECIAL_CASES_TO_ALLOW in check_config_attributes.py

* fix reorder_cache to reorder mamba states and override some more functions in HybridMambaAttentionDynamicCache

* no need to override test_past_key_values_format() and _check_past_key_values_for_generate() in tests anymore

* fix docstrings and typehints for past_key_values

* style fixes

* fix docs

* change typehint due to copy from Mixtral

* forgot import

* import order

* Add configuration_jamba and modeling_jamba to not_doctested because the model is too big to download (in docstring of JambaForCausalLM.forward)

* Add integration test with tiny random Jamba model on hub

* fix flash attention cache shapes

* bring back forgotten hidden states

* rename HybridMambaAttentionDynamicCache.seqlen_offset to has_previous_state (and make bool) and bugfix - it should be set to True after a finished forward pass of the entire model

* align integration test after modeling fixes

* bugfix - mamba can use precomputed states only if the forward pass is on a single token

* bugfix - mamba can use precomputed states only if they match the batch size

* typo

* remove making _prepare_4d_causal_attention_mask a leaf function

* stop using past_seq_len.get_seq_length(). Use cache positions instead. Adjust test (test_decoder_model_past_with_large_inputs) accordingly

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
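The commits that rename seqlen_offset to has_previous_state and restrict when Mamba may use precomputed states combine into one gating condition. A minimal sketch of that condition, with a hypothetical function name and arguments; it restates the commit messages rather than reproducing the merged code.

```python
def can_use_cached_mamba_state(has_previous_state: bool, seq_len: int,
                               cached_batch_size: int, batch_size: int) -> bool:
    """Decide whether a Mamba layer may continue from its cached conv/SSM state.

    Hypothetical helper: the state is reused only if a full forward pass has already
    populated it, the current pass processes exactly one token, and the cached state
    was built for the same batch size.
    """
    return has_previous_state and seq_len == 1 and cached_batch_size == batch_size


if __name__ == "__main__":
    # Prefill: nothing cached yet, so the whole sequence is processed from scratch.
    print(can_use_cached_mamba_state(False, seq_len=12, cached_batch_size=2, batch_size=2))
    # Single-token decoding step with a matching cache: the cached state can be continued.
    print(can_use_cached_mamba_state(True, seq_len=1, cached_batch_size=2, batch_size=2))
    # Multi-token pass: per the commit message, precomputed states are not used.
    print(can_use_cached_mamba_state(True, seq_len=4, cached_batch_size=2, batch_size=2))
```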

4 participants