Add support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO) #262
Conversation
src/together/utils/files.py
Outdated
filtered_messages.append(
    {column: message[column] for column in REQUIRED_COLUMNS_MESSAGE}
)
Hm, I'm not sure if filtering files when they are uploaded is the right solution: this will require users to reupload their data whenever we support a new field for messages (for example, function calling)
Agree, removed the filtering part from the function
src/together/utils/files.py
Outdated
)

if not isinstance(example["preferred_output"], list):
    raise ValueError(
All of these should be InvalidFileFormatError
Fixed
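To illustrate the request above, here is a minimal sketch of raising a dedicated format error instead of a bare ValueError. The InvalidFileFormatError class below is a local stand-in with an assumed constructor (message, line_number), and check_preferred_output is a hypothetical helper; the library's real class and validators may look different.

```python
from typing import Any, Dict


class InvalidFileFormatError(Exception):
    """Local stand-in for the library's error type (assumed signature)."""

    def __init__(self, message: str, line_number: int) -> None:
        super().__init__(message)
        self.message = message
        self.line_number = line_number


def check_preferred_output(example: Dict[str, Any], idx: int) -> None:
    """Hypothetical check: raise if `preferred_output` is not a list."""
    if not isinstance(example.get("preferred_output"), list):
        raise InvalidFileFormatError(
            message=f"Line {idx + 1}: `preferred_output` must be a list",
            line_number=idx + 1,
        )


# A caller can catch the specific error type and report the offending line.
try:
    check_preferred_output({"preferred_output": "not a list"}, idx=0)
except InvalidFileFormatError as err:
    caught = (err.line_number, err.message)
```

A dedicated exception type lets the file checker distinguish malformed data from programming errors, which a generic ValueError would blur.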
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
…CE_OPENAI
src/together/types/finetune.py
Outdated
Training method type for SFT training
"""

method: str = "sft"
- method: str = "sft"
+ method: Literal["sft"] = "sft"
Added
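For context, a minimal sketch of what the Literal annotation buys. It uses plain dataclasses as stand-ins for the real model classes, and the "dpo" value is an assumption made for illustration. Dataclasses do not enforce Literal at runtime, but a static type checker will flag any other string, and the allowed values can be read back from the type.

```python
from dataclasses import dataclass
from typing import Literal, get_args

# Type aliases for the allowed method names ("dpo" is assumed here).
SFTMethod = Literal["sft"]
DPOMethod = Literal["dpo"]


@dataclass
class TrainingMethodSFT:
    """Stand-in for the SFT training method type."""

    method: SFTMethod = "sft"


@dataclass
class TrainingMethodDPO:
    """Stand-in for the DPO training method type."""

    method: DPOMethod = "dpo"


# The permitted values can be recovered from the Literal types themselves.
allowed_methods = set(get_args(SFTMethod)) | set(get_args(DPOMethod))
```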
src/together/utils/files.py
Outdated

has_weights = False
# Check for weights in messages
if _has_weights(messages):
Why did you make this into a separate function? Why not inline it?
It can even be like
has_weights = any("weight" in message for message in messages)
Fixed
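A small sketch of how the inlined expression behaves, using made-up sample messages (the message contents here are illustrative only):

```python
# Sample conversation; only the second message carries a "weight" key.
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!", "weight": 0},
]

# True as soon as any message dict contains the "weight" key.
has_weights = any("weight" in message for message in messages)

# With only the first message, no dict has the key, so the result is False.
has_weights_first_only = any("weight" in message for message in messages[:1])

# any() over an empty iterable is False, so empty conversations are safe.
has_weights_empty = any("weight" in message for message in [])
```

Note that the check is on key presence, so a message with a weight of 0 still counts as having a weight.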
src/together/utils/files.py
Outdated
)
previous_role = message["role"]

return messages, has_weights
Why do you need to return messages? The row doesn't seem to be modified.
fixed
src/together/utils/files.py
Outdated
return messages, has_weights


def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[str, Any]:
Why do you need to return an example?
fixed
src/together/resources/finetune.py
Outdated
@@ -105,6 +109,12 @@ def createFinetuneRequest(
        lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
    )

    training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = (
Nit: since you're using the | notation to specify union types above, I would use it here as well and remove the redundant import
Ok
src/together/utils/files.py
Outdated
has_weights = False
# Check for weights in messages
if _has_weights(messages):
    has_weights = True
Isn't it just the following? :)
- has_weights = False
- # Check for weights in messages
- if _has_weights(messages):
-     has_weights = True
+ has_weights = _has_weights(messages)
src/together/utils/files.py
Outdated


def validate_messages(
    messages: List[Dict[str, str | bool]], idx: int = 0
It's hard to imagine a case where we would want to use the default line number, maybe it's best to remove the default value?
Fixed
src/together/utils/files.py
Outdated
example["input"]["messages"], _ = validate_messages(
    example["input"]["messages"], idx
)
We don't modify anything in messages, I would simply make validate_messages return nothing and raise an exception in case of an error
Fixed
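The suggestion above could look roughly like the following sketch: the validator returns None and signals problems only by raising. The exception class is a local stand-in, the contents of REQUIRED_COLUMNS_MESSAGE are assumed, and the checks are deliberately simplified compared with whatever the real function does.

```python
from typing import Any, Dict, List

# Assumed contents; the real constant may list different columns.
REQUIRED_COLUMNS_MESSAGE = ("role", "content")


class InvalidFileFormatError(Exception):
    """Local stand-in for the library's error type."""


def validate_messages(messages: List[Dict[str, Any]], idx: int) -> None:
    """Validate a conversation in place; return nothing, raise on error."""
    if not isinstance(messages, list):
        raise InvalidFileFormatError(f"Line {idx + 1}: `messages` must be a list")
    for message in messages:
        for column in REQUIRED_COLUMNS_MESSAGE:
            if column not in message:
                raise InvalidFileFormatError(
                    f"Line {idx + 1}: message is missing the `{column}` field"
                )


# The call site no longer needs to unpack or reassign anything.
example = {"input": {"messages": [{"role": "user", "content": "Hi"}]}}
validate_messages(example["input"]["messages"], idx=0)
```

Since nothing is returned, the call site cannot accidentally suggest that the data was transformed, and `idx` has no default, matching the earlier comment about line numbers.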
tests/unit/test_files_checks.py
Outdated

def test_check_jsonl_invalid_preference_openai_structural_issues(tmp_path: Path):
    # Test various structural issues in OpenAI preference format
    test_cases = [
Let's use pytest.mark.parametrize for iterating over multiple test cases
Done
tests/unit/test_files_checks.py
Outdated
@@ -80,45 +128,149 @@ def test_check_jsonl_valid_conversational_single_turn(tmp_path: Path):
def test_check_jsonl_valid_conversational_multiple_turns(tmp_path: Path):
    # Create a valid JSONL file with conversational format and multiple user-assistant turn pairs
    file = tmp_path / "valid_conversational_multiple_turns.jsonl"
    content = [
I'd prefer to keep the current file for this test and write a new one for , because
- Unit tests should test orthogonal capabilities, otherwise this gets misleading when an error is introduced (improper parsing of preference data should not affect tests for regular conversation datasets)
- Right now, this test actually looks identical to test_check_jsonl_valid_preference_openai, which is unlikely to be what you want :)
Created a separate file
src/together/resources/finetune.py
Outdated
AVAILABLE_TRAINING_METHODS = {
    TrainingMethodSFT().method,
    TrainingMethodDPO().method,
}
Since this is a constant, can you move it to the top of the file (outside of the function and the class definition)?
Fixed
lrScheduler = FinetuneLRScheduler(
    lr_scheduler_type="linear",
    lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
)

training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
Nit: maybe annotate the type as training_method_cls: TrainingMethod? It's a bit clearer and more extensible
There were some issues with pre-commit checks when I tried to do this, as I remember
Weird, do you remember what the error was, by any chance? Not blocking, but I'd love to know how to fix it in the future
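For reference, a minimal sketch of the base-class annotation being discussed, using dataclasses as stand-ins for the real model classes. The class hierarchy, the "dpo" value, and the pick_training_method helper are all assumptions for illustration; this does not reproduce or explain the pre-commit error mentioned above.

```python
from dataclasses import dataclass


@dataclass
class TrainingMethod:
    """Assumed common base class for all training methods."""

    method: str


@dataclass
class TrainingMethodSFT(TrainingMethod):
    method: str = "sft"


@dataclass
class TrainingMethodDPO(TrainingMethod):
    method: str = "dpo"


def pick_training_method(name: str) -> TrainingMethod:
    """Hypothetical helper: choose a method object by name."""
    # Annotating with the base class allows reassignment to any subclass,
    # and new methods can be added without touching this annotation.
    training_method_cls: TrainingMethod = TrainingMethodSFT()
    if name == "dpo":
        training_method_cls = TrainingMethodDPO()
    return training_method_cls
```

With the union annotation, every new training method would require extending the union at each use site; the base-class annotation avoids that.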
tests/unit/test_preference_openai.py
Outdated
assert report["has_min_samples"]


# Define test cases for missing fields
The comment seems redundant
removed
tests/unit/test_preference_openai.py
Outdated
from together.constants import MIN_SAMPLES
from together.utils.files import check_file

# Test data for preference OpenAI format
This one's also not very informative given the name of the variable
tests/unit/test_preference_openai.py
Outdated
assert not report["is_check_passed"], f"Test should fail when {description}"


# Define test cases for structural issues
Here as well
fixed
assert not report["is_check_passed"], f"Test should fail when {description}"


STRUCTURAL_ISSUE_TEST_CASES = [
Nit: the constant can be made private
assert report["has_min_samples"]


MISSING_FIELDS_TEST_CASES = [
Nit: the constant can be made private
Describe your changes
This PR adds support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO).