Weird Loss Curve #831

Open
Zihang-Xu-2002 opened this issue May 17, 2024 · 1 comment

Comments

@Zihang-Xu-2002

I trained llama3 on my own conversation dataset with the following command:

./scripts/run_finetune.sh \
  --model_name_or_path meta-llama/Meta-Llama-3-8B \
  --dataset_path data/alpaca_selected/train \
  --conversation_template llama3 \
  --output_model_path output_models/finetuned_llama3_8b_selected

The initial learning rate is 2e-5 and the per-device batch size is 4.
I found that there are sharp drops in the loss at the beginning of every epoch, but within each epoch there is no obvious decrease.
[Screenshot: llama3 training loss curve with sharp drops at each epoch boundary]

Before this, I trained llama2 with:

./scripts/run_finetune.sh \
  --model_name_or_path meta-llama/Llama-2-7b-hf \
  --dataset_path data/alpaca_raw/train \
  --conversation_template llama2 \
  --output_model_path output_models/finetuned_llama2_7b_raw

The initial learning rate is 8e-6 and the per-device batch size is 4. The loss curve looks like this:
[Screenshot: llama2 training loss curve]

I am not sure whether gradient accumulation is causing this. I changed "gradient_accumulation_steps" in configs/ds_config_zero3.json to 1, but there was no change.
[Screenshot: loss curve after setting gradient_accumulation_steps to 1]
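For reference, the setting I changed lives in the DeepSpeed config. A minimal sketch of the relevant fields is below; the key names are standard DeepSpeed options, but the surrounding keys and other values in the actual configs/ds_config_zero3.json shipped with LMFlow will differ:

{
  "zero_optimization": {
    "stage": 3
  },
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 1
}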

Could you help me with this issue? Thank you for your time and attention.

@Zihang-Xu-2002 changed the title from "Wired Loss Curve" to "Weird Loss Curve" on May 17, 2024
@research4pan
Contributor

research4pan commented May 19, 2024

Thanks for your interest in LMFlow! We've observed similar loss curves in some of our experiments. After careful examination, we attributed this to overfitting of the instruction-following dataset on LLaMA models. Within each epoch, the flat loss curve may come from the large variance of the dataset; decreasing the learning rate or increasing the batch size should help, though the overall tendency should remain the same.
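As a sketch, a rerun with a lower learning rate and a larger per-device batch size might look like the command below. This assumes run_finetune.sh forwards standard HuggingFace TrainingArguments flags such as --learning_rate and --per_device_train_batch_size; please check the script's actual options, and treat the values as illustrative only:

./scripts/run_finetune.sh \
  --model_name_or_path meta-llama/Meta-Llama-3-8B \
  --dataset_path data/alpaca_selected/train \
  --conversation_template llama3 \
  --output_model_path output_models/finetuned_llama3_8b_selected \
  --learning_rate 1e-5 \
  --per_device_train_batch_size 8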

You may also check your evaluation/test results; if they look normal, then it may not be a serious issue 😄
