Add support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO) #262


Merged: 14 commits into main on Mar 11, 2025

Conversation

@VProv (Contributor) commented on Mar 3, 2025

Describe your changes

This PR adds support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO).

Verified: one commit was signed with the committer's verified signature (bhrutledge, Brian Rutledge); two commits were created on GitHub.com and signed with GitHub's verified signature (the key has since expired).
@VProv VProv requested a review from punkerpunker March 4, 2025 13:56
@VProv VProv requested review from azahed98 and mryab March 4, 2025 14:29
@mryab mryab removed the request for review from punkerpunker March 4, 2025 18:50
Comment on lines 154 to 156:

```python
filtered_messages.append(
    {column: message[column] for column in REQUIRED_COLUMNS_MESSAGE}
)
```

Member:
Hm, I'm not sure if filtering files when they are uploaded is the right solution: this will require users to reupload their data whenever we support a new field for messages (for example, function calling)

Contributor Author:
Agree, removed the filtering part from the function
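For context, the dropped filtering step can be sketched like this (`REQUIRED_COLUMNS_MESSAGE` is an assumed allow-list and `filter_messages` a hypothetical helper, not the PR's actual function):

```python
# Assumed allow-list; the real constant lives in the package.
REQUIRED_COLUMNS_MESSAGE = ["role", "content"]

def filter_messages(messages):
    """Keep only allow-listed columns of each message (the removed behavior)."""
    return [
        {column: message[column] for column in REQUIRED_COLUMNS_MESSAGE}
        for message in messages
    ]

# A message carrying a newer field (e.g. tool calls) silently loses it after
# filtering, which is why filtering at upload time forces re-uploads later.
messages = [{"role": "user", "content": "hi", "tool_calls": []}]
```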

```python
)

if not isinstance(example["preferred_output"], list):
    raise ValueError(
```

Member:
All of these should be InvalidFileFormatError

Contributor Author:
Fixed
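A minimal sketch of the pattern, assuming a stubbed `InvalidFileFormatError` (the real class lives in the package) and a hypothetical `check_preferred_output` helper:

```python
# Stub for illustration; the real InvalidFileFormatError is defined in the package.
class InvalidFileFormatError(ValueError):
    """Raised when a training file does not match the expected schema."""

def check_preferred_output(example, idx):
    # Data problems surface as a file-format error, so callers can
    # distinguish bad user data from programming errors.
    if not isinstance(example.get("preferred_output"), list):
        raise InvalidFileFormatError(
            f"Example {idx}: 'preferred_output' must be a list of messages"
        )
```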

@mryab mryab changed the title Add support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO). Add support for the Training Method for finetuning, and for Direct-Preference Optimization (DPO) Mar 5, 2025
VProv and others added 3 commits March 5, 2025 17:53
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
@VProv VProv requested a review from mryab March 5, 2025 18:49
VProv added 3 commits March 5, 2025 11:04
@mryab mryab requested review from artek0chumak and removed request for azahed98 March 10, 2025 11:21
```python
Training method type for SFT training
"""

method: str = "sft"
```

Contributor:

Suggested change:

```diff
-method: str = "sft"
+method: Literal["sft"] = "sft"
```

Contributor Author:
Added
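The `Literal` annotation pins the tag at type-check time; a sketch under assumed dataclass definitions (the real models may be pydantic and carry more fields):

```python
from dataclasses import dataclass
from typing import Literal

# Assumed shapes for illustration; the real classes may define extra fields.
@dataclass
class TrainingMethodSFT:
    # Literal["sft"] lets static checkers (and pydantic-style validators)
    # reject any other value, and makes `method` usable as a discriminator.
    method: Literal["sft"] = "sft"

@dataclass
class TrainingMethodDPO:
    method: Literal["dpo"] = "dpo"
```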


```python
has_weights = False
# Check for weights in messages
if _has_weights(messages):
```

Contributor:
Why did you make this a separate function? Why not inline it?

Contributor:

It can even be:

```python
has_weights = any("weight" in message for message in messages)
```

Contributor Author:
Fixed
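The reviewer's one-liner behaves like this on sample data (the optional per-message `"weight"` key is assumed from the surrounding discussion):

```python
def has_any_weights(messages):
    # True as soon as any message carries a per-message training weight.
    return any("weight" in message for message in messages)

weighted = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "ok", "weight": 0},
]
unweighted = [{"role": "user", "content": "hi"}]
```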

```python
    )
    previous_role = message["role"]

return messages, has_weights
```

Contributor:
Why do you need to return messages? The row doesn't seem to be modified.

Contributor Author:
fixed

```python
return messages, has_weights


def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[str, Any]:
```

Contributor:
Why do you need to return an example?

Contributor Author:
fixed

```diff
@@ -105,6 +109,12 @@ def createFinetuneRequest(
     lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
 )

 training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = (
```

Member:
Nit: since you're using the | notation to specify union types above, I would use it here as well and remove the redundant import

Contributor Author:
Ok

Comment on lines 130 to 133:

```python
has_weights = False
# Check for weights in messages
if _has_weights(messages):
    has_weights = True
```

Member:
Isn't it just the following? :)

Suggested change:

```diff
-has_weights = False
-# Check for weights in messages
-if _has_weights(messages):
-    has_weights = True
+has_weights = _has_weights(messages)
```



```python
def validate_messages(
    messages: List[Dict[str, str | bool]], idx: int = 0
```

Member:
It's hard to imagine a case where we would want to use the default line number, maybe it's best to remove the default value?

Contributor Author:
Fixed

Comment on lines 222 to 224:

```python
example["input"]["messages"], _ = validate_messages(
    example["input"]["messages"], idx
)
```

Member:
We don't modify anything in messages; I would simply make validate_messages return nothing and raise an exception in case of an error.

Contributor Author:
Fixed
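A sketch of that raise-don't-return validator shape (the specific checks shown are illustrative, not the PR's full rule set):

```python
def validate_messages(messages, idx):
    # Returns None; a malformed row surfaces as an exception instead of a
    # return value the caller must remember to consume.
    previous_role = None
    for message in messages:
        if "role" not in message or "content" not in message:
            raise ValueError(f"Example {idx}: message missing 'role' or 'content'")
        if message["role"] == previous_role:
            raise ValueError(f"Example {idx}: consecutive messages share a role")
        previous_role = message["role"]
```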


```python
def test_check_jsonl_invalid_preference_openai_structural_issues(tmp_path: Path):
    # Test various structural issues in OpenAI preference format
    test_cases = [
```

Member:
Let's use pytest.mark.parametrize for iterating over multiple test cases

Contributor Author:
Done

```diff
@@ -80,45 +128,149 @@ def test_check_jsonl_valid_conversational_single_turn(tmp_path: Path):
 def test_check_jsonl_valid_conversational_multiple_turns(tmp_path: Path):
     # Create a valid JSONL file with conversational format and multiple user-assistant turn pairs
     file = tmp_path / "valid_conversational_multiple_turns.jsonl"
     content = [
```

Member:
I'd prefer to keep the current file for this test and write a new one for , because

  1. Unit tests should test orthogonal capabilities, otherwise this gets misleading when an error is introduced (improper parsing of preference data should not affect tests for regular conversation datasets)
  2. Right now, it actually looks like this test is now identical to test_check_jsonl_valid_preference_openai, which is unlikely to be what you want :)

Contributor Author:
Created a separate file

@VProv VProv requested review from mryab and artek0chumak March 11, 2025 14:15
Comment on lines 108 to 111:

```python
AVAILABLE_TRAINING_METHODS = {
    TrainingMethodSFT().method,
    TrainingMethodDPO().method,
}
```

Member:
Since this is a constant, can you move it to the top of the file (outside of the function and the class definition)?

Contributor Author:
Fixed

```python
lrScheduler = FinetuneLRScheduler(
    lr_scheduler_type="linear",
    lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
)

training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
```

Member:
Nit: maybe annotate the type as training_method_cls: TrainingMethod? It's a bit clearer and more extensible

Contributor Author:
There were some issues with pre-commit checks when I tried to do this, as I remember

Member:
Weird, do you remember what was the error by any chance? Not blocking, but I'd love to know how to fix it in the future

```python
assert report["has_min_samples"]


# Define test cases for missing fields
```

Member:
The comment seems redundant

Contributor Author:
removed

```python
from together.constants import MIN_SAMPLES
from together.utils.files import check_file

# Test data for preference OpenAI format
```

Member:
This one's also not very informative given the name of the variable

```python
assert not report["is_check_passed"], f"Test should fail when {description}"


# Define test cases for structural issues
```

Member:
Here as well

Contributor Author:
fixed

VProv added 2 commits March 11, 2025 08:46
@VProv VProv requested a review from mryab March 11, 2025 15:58
```python
assert not report["is_check_passed"], f"Test should fail when {description}"


STRUCTURAL_ISSUE_TEST_CASES = [
```

Member:
Nit: the constant can be made private

```python
assert report["has_min_samples"]


MISSING_FIELDS_TEST_CASES = [
```

Member:
Nit: the constant can be made private

@mryab mryab merged commit a4fd112 into main Mar 11, 2025
10 of 11 checks passed
@mryab mryab deleted the Vprov/dpo_python branch March 11, 2025 18:02
3 participants