Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Add Mixtral] Adds support for the Mixtral MoE #27942

Merged
merged 111 commits into from
Dec 11, 2023
Merged
Changes from 1 commit
Commits
Show all changes
111 commits
Select commit Hold shift + click to select a range
15ef543
up
younesbelkada Dec 8, 2023
3367d25
up
younesbelkada Dec 8, 2023
f9da444
test
younesbelkada Dec 8, 2023
f59eacc
logits ok
younesbelkada Dec 8, 2023
7e0968a
up
younesbelkada Dec 8, 2023
0bfcd75
up
younesbelkada Dec 8, 2023
6b84e42
few fixes
younesbelkada Dec 8, 2023
2896e2f
conversion script
younesbelkada Dec 8, 2023
92d143f
up
younesbelkada Dec 8, 2023
d3261c1
nits
ArthurZucker Dec 8, 2023
407f8a8
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 8, 2023
65bbd30
nits
ArthurZucker Dec 8, 2023
6afc8f3
update
ArthurZucker Dec 8, 2023
bfef811
Merge branch 'main' into add-mixtral-alternative
younesbelkada Dec 8, 2023
7a54d1a
nuke
younesbelkada Dec 8, 2023
b9f3fc0
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
younesbelkada Dec 8, 2023
f8513e8
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 8, 2023
0d31424
more updates
ArthurZucker Dec 8, 2023
c8987cb
nites
ArthurZucker Dec 8, 2023
d82c8ee
fix many issues
younesbelkada Dec 9, 2023
ccc6011
nit
younesbelkada Dec 9, 2023
356d484
scatter
ArthurZucker Dec 9, 2023
e858c01
nit
younesbelkada Dec 9, 2023
82037ca
nuke megablocks
younesbelkada Dec 9, 2023
e66d1a9
nits
ArthurZucker Dec 9, 2023
49eb7f0
fix conversion script
younesbelkada Dec 9, 2023
ffc8463
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 9, 2023
82e4a1b
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 9, 2023
0b1ca52
nit
younesbelkada Dec 9, 2023
3616d3b
remove
ArthurZucker Dec 9, 2023
6d73a58
nits
ArthurZucker Dec 9, 2023
4c1fbf3
nit
younesbelkada Dec 9, 2023
a922710
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 9, 2023
1abf6bd
update
ArthurZucker Dec 9, 2023
b938a30
oupsssss
ArthurZucker Dec 9, 2023
12ddba9
change
younesbelkada Dec 9, 2023
445e6e6
nits device
ArthurZucker Dec 9, 2023
1e83d0e
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 9, 2023
263310f
nits
ArthurZucker Dec 9, 2023
b2bedb1
fixup
ArthurZucker Dec 9, 2023
4afd7e4
update
ArthurZucker Dec 9, 2023
c0e6dfd
merge
ArthurZucker Dec 9, 2023
dd33a59
add copied from
ArthurZucker Dec 9, 2023
0de7081
fix the copy mentions
ArthurZucker Dec 9, 2023
48945de
update tests
ArthurZucker Dec 9, 2023
d927baf
more fixes
ArthurZucker Dec 9, 2023
7402aca
nits
ArthurZucker Dec 9, 2023
54bee10
conversion script
younesbelkada Dec 9, 2023
1ae98dc
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
younesbelkada Dec 9, 2023
8bf257f
Merge branch 'main' of https://github.com/huggingface/transformers in…
ArthurZucker Dec 9, 2023
01b2969
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 9, 2023
80c593e
add parts of the readme
ArthurZucker Dec 9, 2023
dab227f
Update tests/models/mixtral/test_modeling_mixtral.py
younesbelkada Dec 9, 2023
ca1f7d0
new test + conversion script
younesbelkada Dec 9, 2023
e4237a3
Apply suggestions from code review
younesbelkada Dec 9, 2023
0c04bc3
Apply suggestions from code review
younesbelkada Dec 9, 2023
bd7c786
fix
younesbelkada Dec 9, 2023
b2b8e01
fix copies
younesbelkada Dec 9, 2023
badceae
fix copies
younesbelkada Dec 9, 2023
11a4db9
ooops
younesbelkada Dec 9, 2023
419ddb3
fix config
younesbelkada Dec 9, 2023
bbbd1b2
Apply suggestions from code review
younesbelkada Dec 9, 2023
2b23c47
fix nits
younesbelkada Dec 9, 2023
76a65e6
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
younesbelkada Dec 9, 2023
18caab8
nit
younesbelkada Dec 9, 2023
a00ad3a
add copies
younesbelkada Dec 9, 2023
657dd95
add batched tests
younesbelkada Dec 9, 2023
a092648
docs
younesbelkada Dec 10, 2023
67e8e03
fix flash attention
younesbelkada Dec 10, 2023
72542dd
let's add more verbose
younesbelkada Dec 10, 2023
d3f5abb
add correct outputs
ArthurZucker Dec 10, 2023
e900e36
support router ouptus
ArthurZucker Dec 10, 2023
68b8b41
ignore copies where needed
ArthurZucker Dec 10, 2023
ded6028
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 10, 2023
1ceb940
fix
ArthurZucker Dec 10, 2023
38eef46
cat list if list is given for now
ArthurZucker Dec 10, 2023
8d3f83f
nits
ArthurZucker Dec 10, 2023
ee5f3e9
Update docs/source/en/model_doc/mixtral.md
younesbelkada Dec 10, 2023
e834e89
finish router refactoring
ArthurZucker Dec 10, 2023
1b6358e
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 10, 2023
54f2a48
fix forward
ArthurZucker Dec 10, 2023
19be169
fix expected values
ArthurZucker Dec 10, 2023
5c929df
nits
ArthurZucker Dec 10, 2023
872ee24
fixup
ArthurZucker Dec 10, 2023
eaa2a5f
fix
younesbelkada Dec 10, 2023
703672d
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 10, 2023
10760c1
fix bug
younesbelkada Dec 10, 2023
3499c98
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
younesbelkada Dec 10, 2023
19e4aea
fix
ArthurZucker Dec 10, 2023
9bc7d5b
fix dtype mismatch
younesbelkada Dec 10, 2023
290f621
fix
ArthurZucker Dec 10, 2023
1f411a1
Merge branch 'add-mixtral-alternative' of https://github.com/huggingf…
ArthurZucker Dec 10, 2023
9ebf661
grrr grrr I support item assignment
ArthurZucker Dec 10, 2023
23abc46
fix CI
younesbelkada Dec 10, 2023
6549f48
docs
younesbelkada Dec 11, 2023
39e38ed
Merge branch 'main' of https://github.com/huggingface/transformers in…
ArthurZucker Dec 11, 2023
e4b84bc
fixup
ArthurZucker Dec 11, 2023
80d49aa
remove some copied form
ArthurZucker Dec 11, 2023
fbde97b
fix weird diff
younesbelkada Dec 11, 2023
20091dc
skip doctest fast on the config and modeling
ArthurZucker Dec 11, 2023
adc7113
mark that is supports flash attention in the doc
ArthurZucker Dec 11, 2023
c6ddca8
update
ArthurZucker Dec 11, 2023
3f62433
Update src/transformers/models/mixtral/modeling_mixtral.py
ArthurZucker Dec 11, 2023
6c6df4e
Update docs/source/en/model_doc/mixtral.md
ArthurZucker Dec 11, 2023
d4e826f
revert router logits config issue
ArthurZucker Dec 11, 2023
d17b756
update doc accordingly
ArthurZucker Dec 11, 2023
e86facd
Update src/transformers/models/mixtral/convert_mixtral_weights_to_hf.py
younesbelkada Dec 11, 2023
2c85405
nits
ArthurZucker Dec 11, 2023
bb88c76
use torch testing asssert close
ArthurZucker Dec 11, 2023
6624e9c
fixup
ArthurZucker Dec 11, 2023
c26aaa4
doc nits
ArthurZucker Dec 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,11 +1226,6 @@ def forward(
logits = self.lm_head(hidden_states)
logits = logits.float()

if return_dict and output_router_logits:
router_logits = outputs.router_logits
else:
router_logits = outputs[-1]

loss = None
if labels is not None:
# Shift so that tokens < n predict n
Expand All @@ -1246,7 +1241,7 @@ def forward(

aux_loss = None
if output_router_logits:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting output_router_logits = True should automatically add the aux_loss

aux_loss = load_balancing_loss_func(router_logits, self.num_experts, self.num_experts_per_tok)
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss

Expand Down