
3rd place solution to Freesound Audio Tagging 2019 Challenge

My approach is outlined below.

Models

I used two types of models, both based on convolutions. The first type uses 2d convolutions and works on top of mel-scale spectrograms; the second uses 1d convolutions on top of raw STFT representations with a relatively small window size such as 256, which is only about 5 ms per frame. Both types of models are relatively shallow and consist of 10-12 convolutional layers (or 5-6 resnet blocks) with a small number of filters. I use a form of deep supervision: global max pooling is applied after each block (typically starting from the first or second block), and the pooled outputs from all blocks are concatenated to form the final feature vector, which then goes to a 2-layer fully-connected classifier. I also tried using an RNN instead of max pooling for some models. It made results a bit worse, but the RNN seemed to make different mistakes, so it turned out to be a good member of the final ensemble.
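As a rough sketch of this architecture (the layer sizes and block structure here are illustrative only; the actual models are defined in networks/classifiers):

import torch
import torch.nn as nn

class DeeplySupervisedCNN(nn.Module):
    """Sketch: global max pooling after each conv block, pooled outputs
    concatenated into one feature vector for a 2-layer classifier."""

    def __init__(self, n_classes, n_blocks=5, base_depth=64):
        super().__init__()
        blocks, in_ch = [], 2  # spectrogram + frequency-encoding channel
        for i in range(n_blocks):
            out_ch = base_depth * (i + 1)
            blocks.append(nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ))
            in_ch = out_ch
        self.blocks = nn.ModuleList(blocks)
        feature_dim = sum(base_depth * (i + 1) for i in range(n_blocks))
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512), nn.ReLU(), nn.Linear(512, n_classes))

    def forward(self, x):  # x: (batch, 2, freq, time)
        pooled = []
        for block in self.blocks:
            x = block(x)
            # global max pooling over the frequency and time axes
            pooled.append(x.amax(dim=(2, 3)))
        return self.classifier(torch.cat(pooled, dim=1))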

Frequency encoding

2d convolutions are position-invariant, so the output of a convolution is the same regardless of where a feature is located. But spectrograms are not images: the Y-axis corresponds to signal frequency, so it helps to give the model this information explicitly. For this purpose, I used a linear frequency map going from -1 to 1 and concatenated it to the input spectrogram as a second channel. Without retraining all the models it is hard to estimate exactly how much gain this little modification gave, but it was no less than 0.005 in terms of local CV score.
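A minimal sketch of this frequency encoding, assuming a (batch, 1, freq, time) spectrogram layout:

import torch

def add_frequency_encoding(spec):
    """spec: (batch, 1, freq, time) mel spectrogram. Appends a channel
    that linearly encodes the frequency bin position in [-1, 1]."""
    batch, _, n_freq, n_time = spec.shape
    freq_map = torch.linspace(-1, 1, n_freq, device=spec.device)
    freq_map = freq_map.view(1, 1, n_freq, 1).expand(batch, 1, n_freq, n_time)
    return torch.cat([spec, freq_map], dim=1)  # (batch, 2, freq, time)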

This is not really a classification task

Most teams treated the problem as multilabel classification and used a form of binary loss such as binary cross-entropy or focal loss. This approach is definitely valid, but in my experiments it appeared to be a little suboptimal. The reason is that the metric (lwlrap) is not a pure classification metric: contrary to accuracy or F-score, it is based on ranks. So it wasn't really a surprise that when I used a loss function based on ranks rather than on binary outputs, I got a huge improvement. Namely, I used LSEP (https://arxiv.org/abs/1704.03135), which is just a soft version of a pairwise rank loss. It makes the model score positive classes higher than negative ones, whereas a binary loss increases positive scores and decreases negative scores independently. When I switched from BCE to LSEP, I immediately got an improvement of approximately 0.015, and, as a nice bonus, my models started to converge much faster.
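A minimal sketch of the LSEP loss as described in the paper, without any numerical-stability tricks the repo's implementation may apply:

import torch

def lsep_loss(scores, labels):
    """Soft pairwise rank loss (LSEP): log(1 + sum exp(s_neg - s_pos))
    over all (positive, negative) class pairs of each sample.
    scores: (batch, n_classes) logits; labels: (batch, n_classes) in {0, 1}."""
    pos = labels.bool()
    # diff[b, i, j] = scores[b, j] - scores[b, i]
    diff = scores.unsqueeze(1) - scores.unsqueeze(2)
    # keep pairs where class i is positive and class j is negative
    mask = pos.unsqueeze(2) & (~pos).unsqueeze(1)
    return torch.log1p((diff.exp() * mask.float()).sum(dim=(1, 2))).mean()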

Data augmentation

I used two augmentation strategies. The first is a modified MixUp. In contrast to the original approach, I used an OR rule for mixing labels: a mix of two sounds still lets you hear both, so the mixed clip should keep both label sets. At some point I tried the original approach with weighted targets, and my results got worse.
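A sketch of this OR-rule MixUp (the alpha parameter and batch layout are assumptions for the example):

import numpy as np
import torch

def mixup_or(waves, labels, alpha=1.0):
    """MixUp on a batch of waveforms with OR-combined labels: both sounds
    stay audible in the mix, so the result keeps both label sets.
    waves: (batch, samples); labels: (batch, n_classes) in {0, 1}."""
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(waves.size(0))
    mixed = lam * waves + (1 - lam) * waves[perm]
    targets = torch.clamp(labels + labels[perm], max=1.0)  # OR rule
    return mixed, targets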

The second strategy is augmentation based on audio effects such as reverb, pitch shifting, tempo change, and overdrive. I chose the parameters of these augmentations by carefully listening to augmented samples.
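For illustration, here is one way to build such an effect chain with torchaudio's sox bindings; the actual transforms live in ops/transforms, and the parameter ranges below are placeholders rather than the tuned values:

import random
import torchaudio

def random_effects(wave, sample_rate):
    """Apply a random chain of sox effects to a (channels, samples) tensor."""
    effects = [
        ["pitch", str(random.randint(-300, 300))],      # pitch shift in cents
        ["tempo", f"{random.uniform(0.8, 1.25):.2f}"],  # tempo factor
        ["reverb", str(random.randint(0, 50))],         # reverberance, percent
    ]
    if random.random() < 0.3:
        effects.append(["overdrive", str(random.randint(2, 10))])  # gain, dB
    out, _ = torchaudio.sox_effects.apply_effects_tensor(
        wave, sample_rate, effects)
    return out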

I found augmentations to be very important for getting good results; I estimate the total improvement from these two strategies at about 0.05. I also tried several other approaches, such as splitting the audio into several chunks and shuffling them, or replacing some parts of the original signal with silence, but they didn't make my models better.

Training

I used quite large audio segments for training: for most of my models, segments from 8 to 12 seconds. I didn't use TTA for inference and used full-length audio instead.
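A sketch of this segment sampling (the padding behavior for short clips is an assumption):

import random
import torch
import torch.nn.functional as F

def random_segment(wave, sample_rate, min_s=8, max_s=12):
    """Cut a random training segment of 8-12 seconds from a 1-d waveform;
    clips shorter than the segment are zero-padded on the right."""
    length = int(random.uniform(min_s, max_s) * sample_rate)
    if wave.size(-1) <= length:
        return F.pad(wave, (0, length - wave.size(-1)))
    start = random.randint(0, wave.size(-1) - length)
    return wave[..., start:start + length]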

Noisy data

I tried several unsupervised approaches, such as Contrastive Predictive Coding, but never managed to get good results from them.

I ended up applying a form of iterative pseudolabeling. I predicted new labels for the noisy subset using a model trained on curated data only, chose the best 1k noisy samples in terms of agreement between the predicted labels and the actual labels, and added these samples to the curated subset with their original labels. I then repeated the procedure, taking the top 2k samples this time, and continued until I reached the 5k best noisy samples. At that point, the model's predictions started to diverge significantly from the actual noisy labels, so I discarded the labels of the remaining noisy samples and simply used the model's predictions as labels. In total, I trained approximately 20 models using different subsets of the noisy train set with different pseudolabeling strategies.
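A sketch of the top-k selection step; the exact agreement score used by the repo's relabel_noisy_data.py may differ from the simple one assumed here:

import numpy as np

def top_k_by_agreement(probs, noisy_labels, k):
    """Keep the k noisy samples whose predicted probabilities agree best
    with their given labels. Agreement here (an assumption): mean predicted
    probability of annotated classes minus mean probability of the rest.
    probs, noisy_labels: (n_samples, n_classes) arrays."""
    pos = noisy_labels.astype(bool)
    pos_score = np.where(pos, probs, 0).sum(1) / np.maximum(pos.sum(1), 1)
    neg_score = np.where(~pos, probs, 0).sum(1) / np.maximum((~pos).sum(1), 1)
    agreement = pos_score - neg_score
    return np.argsort(-agreement)[:k]  # indices of the k best samples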

Inference

I got a great speed-up by computing both STFT and mel spectrograms on a GPU. I also grouped samples of similar length together to avoid excessive padding. These two techniques, combined with relatively small models, allowed any of my models (5 folds) to predict the first-stage test set in only about a minute.
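A sketch of both tricks with torchaudio (the feature parameters follow the mel_2048_1024_128 setting used in the commands below; the repo's own implementation may differ):

import torch
import torch.nn.functional as F
import torchaudio

# Mel spectrograms on GPU: move the transform to CUDA so the STFT
# runs on the device together with the model.
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=44100, n_fft=2048, hop_length=1024, n_mels=128).to("cuda")

def bucket_by_length(waves, batch_size):
    """Group variable-length 1-d clips by length so each batch needs
    minimal padding, then compute mel spectrograms on the GPU."""
    order = sorted(range(len(waves)), key=lambda i: waves[i].size(-1))
    for i in range(0, len(order), batch_size):
        idx = order[i:i + batch_size]
        max_len = max(waves[j].size(-1) for j in idx)
        batch = torch.stack([
            F.pad(waves[j], (0, max_len - waves[j].size(-1))) for j in idx])
        yield idx, mel(batch.to("cuda"))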

Final ensemble

For the final solution, I used a simple average of 11 models trained with slightly different architectures (1d/2d CNN, with/without RNN), slightly different subsets of the noisy set (see the "Noisy data" section), and slightly different hyperparameters.
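A minimal sketch of the averaging step (the file names are hypothetical; fname is the index column of the competition's submission format):

import pandas as pd

# Simple average of per-model submission files.
files = [f"predictions/model_{i}_submission.csv" for i in range(11)]
dfs = [pd.read_csv(f, index_col="fname") for f in files]
(sum(dfs) / len(dfs)).to_csv("submission.csv")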

Project structure

The main training scripts are train_2d_cnn.py and train_hierarcical_cnn.py. All classification models are defined in networks/classifiers, and all data augmentations in ops/transforms.

Setting up the environment

I recommend using an environment manager such as conda or virtualenv to avoid potential conflicts between package versions. To install all required packages, simply run pip install -r requirements.txt. This might take up to 15 minutes depending on your internet connection speed.

Preparing data

I place all the data into the data/ directory; please adjust the following commands to match your data location. Run

python create_class_map.py --train_df data/train_curated.csv --output_file data/classmap.json

This simply creates a JSON file with a deterministic classname -> label mapping that is used in all subsequent experiments.

Running a basic 2d model

python train_2d_cnn.py \
  --train_df data/train_curated.csv \
  --train_data_dir data/train_curated/ \
  --classmap data/classmap.json \
  --device=cuda \
  --optimizer=adam \
  --folds 0 1 2 3 4 \
  --n_folds=5 \
  --log_interval=10 \
  --batch_size=20 \
  --epochs=20 \
  --accumulation_steps=1 \
  --save_every=20 \
  --num_conv_blocks=5 \
  --conv_base_depth=50 \
  --growth_rate=1.5 \
  --weight_decay=0.0 \
  --start_deep_supervision_on=1 \
  --aggregation_type=max \
  --lr=0.003 \
  --scheduler=1cycle_0.0001_0.005 \
  --test_data_dir data/test \
  --sample_submission data/sample_submission.csv \
  --num_workers=6 \
  --output_dropout=0.0 \
  --p_mixup=0.0 \
  --switch_off_augmentations_on=15 \
  --features=mel_2048_1024_128 \
  --max_audio_length=15 \
  --p_aug=0.0 \
  --label=basic_2d_cnn

Running a 2d model with augmentations

python train_2d_cnn.py \
  --train_df data/train_curated.csv \
  --train_data_dir data/train_curated/ \
  --classmap data/classmap.json \
  --device=cuda \
  --optimizer=adam \
  --folds 0 1 2 3 4 \
  --n_folds=5 \
  --log_interval=10 \
  --batch_size=20 \
  --epochs=100 \
  --accumulation_steps=1 \
  --save_every=20 \
  --num_conv_blocks=5 \
  --conv_base_depth=100 \
  --growth_rate=1.5 \
  --weight_decay=0.0 \
  --start_deep_supervision_on=1 \
  --aggregation_type=max \
  --lr=0.003 \
  --scheduler=1cycle_0.0001_0.005 \
  --test_data_dir data/test \
  --sample_submission data/sample_submission.csv \
  --num_workers=16 \
  --output_dropout=0.5 \
  --p_mixup=0.5 \
  --switch_off_augmentations_on=90 \
  --features=mel_2048_1024_128 \
  --max_audio_length=15 \
  --p_aug=0.75 \
  --label=2d_cnn

Note that each such run creates a new experiment subdirectory in the experiments folder. Each experiment has the following structure:

experiments/some_experiment/
├── checkpoints
├── command
├── commit_hash
├── config.json
├── log
├── predictions
├── results.json
└── summaries

Using a clean model to select noisy samples

Create a new predictions directory:

mkdir predictions/

Then, running

python predict_2d_cnn.py \
  --experiment=path_to_an_experiment (see above) \
  --test_df=data/train_noisy.csv \
  --test_data_dir=data/train_noisy/ \
  --output_df=predictions/noisy_probabilities.csv \
  --classmap=data/classmap.json \
  --device=cuda

creates a new CSV file in the predictions folder with the class probabilities for the noisy dataset.

Running

python relabel_noisy_data.py \
  --noisy_df=data/train_noisy.csv \
  --noisy_predictions_df=predictions/noisy_probabilities.csv \
  --output_df=predictions/train_noisy_relabeled_1k.csv \
  --mode=scoring_1000

creates a new noisy dataframe where only the top 1k samples (in terms of agreement between model predictions and the actual labels) are kept.

Running a 2d model with noisy data

python train_2d_cnn.py \
  --train_df data/train_curated.csv \
  --train_data_dir data/train_curated/ \
  --noisy_train_df predictions/train_noisy_relabeled_1k.csv \
  --noisy_train_data_dir data/train_noisy/ \
  --classmap data/classmap.json \
  --device=cuda \
  --optimizer=adam \
  --folds 0 1 2 3 4 \
  --n_folds=5 \
  --log_interval=10 \
  --batch_size=20 \
  --epochs=150 \
  --accumulation_steps=1 \
  --save_every=20 \
  --num_conv_blocks=6 \
  --conv_base_depth=100 \
  --growth_rate=1.5 \
  --weight_decay=0.0 \
  --start_deep_supervision_on=1 \
  --aggregation_type=max \
  --lr=0.003 \
  --scheduler=1cycle_0.0001_0.005 \
  --test_data_dir data/test \
  --sample_submission data/sample_submission.csv \
  --num_workers=16 \
  --output_dropout=0.7 \
  --p_mixup=0.5 \
  --switch_off_augmentations_on=140 \
  --features=mel_2048_1024_128 \
  --max_audio_length=15 \
  --p_aug=0.75 \
  --label=2d_cnn_noisy

Note that the relabel_noisy_data.py script supports multiple relabeling strategies. I mostly followed the "scoring" strategy (selecting the top-k noisy samples based on agreement between the model and the actual labels), but after 5k noisy samples I switched to the "relabelall-replacenan" strategy, which is plain pseudolabeling (the model's outputs are used as the labels) with samples that receive no predictions discarded.
