methi1999/rnn_kws


Description

This repository contains a PyTorch-based recurrent neural network trained on the TIMIT dataset. It outputs frame-level phones and is trained with CTC loss. The following architectures are implemented: LSTM, GRU and TCNN, along with their bidirectional variants. A custom implementation of the recurrent networks is also provided, which makes it easy to modify the core recurrence equations.
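As a rough illustration of the setup described above, the sketch below wires a bidirectional LSTM that emits per-frame phone scores into PyTorch's `nn.CTCLoss`. The layer sizes, phone count and class/module names are illustrative assumptions, not the repository's actual hyperparameters or code.

```python
import torch
import torch.nn as nn

class FramePhoneRNN(nn.Module):
    """Minimal sketch: a bidirectional LSTM mapping audio features to
    per-frame phone scores. Sizes here are placeholders."""
    def __init__(self, num_feats=39, hidden=256, layers=2, num_phones=61):
        super().__init__()
        self.rnn = nn.LSTM(num_feats, hidden, num_layers=layers,
                           bidirectional=True, batch_first=True)
        # one extra output class for the CTC blank symbol
        self.fc = nn.Linear(2 * hidden, num_phones + 1)

    def forward(self, x):
        out, _ = self.rnn(x)   # (batch, time, 2 * hidden)
        return self.fc(out)    # (batch, time, num_phones + 1)

model = FramePhoneRNN()
feats = torch.randn(4, 100, 39)           # a batch of 4 dummy utterances
log_probs = model(feats).log_softmax(-1)  # frame-level phone log-probs

# CTCLoss expects (time, batch, classes); blank id matches the extra class
ctc = nn.CTCLoss(blank=61)
targets = torch.randint(0, 61, (4, 20))
loss = ctc(log_probs.transpose(0, 1), targets,
           torch.full((4,), 100), torch.full((4,), 20))
```

Training then reduces to backpropagating this loss; CTC removes the need for frame-level alignments between features and phone labels.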

Directory Structure

Folders:

  • architectures/: various RNN architectures.
  • models/: the model checkpoints dumped during training and loaded for inference.
  • pickle/: the insertion, deletion and substitution probabilities.
  • plots/: the train/test loss vs epoch figures and the PER on TIMIT Test as a function of the epochs.

Files:

  • config.yaml: contains all hyperparameters used for training/inference.
  • beam_search.py: implementation of beam search.
  • metadata.py: reads the TIMIT dataset and returns a list of feature vectors and ground truth phones for each recording.
  • dataloader.py: pads the sequences and returns batches which are used for training/testing.
  • dl_model.py: contains the actual train/infer functions; the entry point for training the model.
  • utils.py: some common functions which are used throughout the project.
  • hypo_search.py: contains functions which traverse the lattice generated by the RNN and pick out the best subsequence given a target sequence. Currently only the top-5 lattices are checked.
  • extract_q_values.py: computes, for each phone, the Q-value specified in the paper, which is used as a threshold during inference.
  • infer.py: specify keywords here and carry out a grid search over the hyperparameters.
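To illustrate the CTC output convention that beam_search.py and hypo_search.py operate on, the sketch below performs the simplest possible decode of the RNN's lattice: take the best phone per frame, collapse repeats, and drop blanks. This greedy version is an assumption for illustration only; the repository's beam search is the stronger variant.

```python
def ctc_collapse(frame_ids, blank=0):
    """Collapse a per-frame best path into a phone sequence:
    repeated ids are merged, blank symbols are removed."""
    out, prev = [], None
    for i in frame_ids:
        if i != blank and i != prev:
            out.append(i)
        prev = i
    return out

# frames: a a _ a b b _ c  ->  a a b c
print(ctc_collapse([1, 1, 0, 1, 2, 2, 0, 3]))  # [1, 1, 2, 3]
```

Note that a blank between two identical phones keeps them distinct, which is how CTC represents repeated phones.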

Usage

For training:

  1. Set the path to the parent directory of the TIMIT dataset in the config.yaml file (config['dir']['dataset']).
  2. Customise the hyperparameters in config.yaml.
  3. Run the function train in dl_model.py.
  4. Models are periodically dumped in the following folder: models/<name_of_model>_<number_of_layers>_<number_of_hidden_units>_<number_of_audio_features>, e.g. GRU_5_384_79 for a 5-layer GRU with 384 hidden units and 79 audio features.
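The folder-naming convention in step 4 can be sketched as a small helper; `model_dir` is a hypothetical function written here for illustration, not one defined in the repository.

```python
def model_dir(name, layers, hidden, feats):
    # models/<name_of_model>_<number_of_layers>_<number_of_hidden_units>_<number_of_audio_features>
    return f"models/{name}_{layers}_{hidden}_{feats}"

print(model_dir("GRU", 5, 384, 79))  # models/GRU_5_384_79
```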

For inference:

  1. Ensure that the models/<model_name> folder (named as specified above) contains the pre-trained model.
  2. Run the function infer in dl_model.py, passing as argument a list of paths to the .wav files to be run through the model. The function returns a list of tensors, each of shape time_steps x number_of_phones.

Check the commented code at the very end of dl_model.py for an illustration of how the results are generated.
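Given the shape described in step 2, one simple way to read the results is a per-frame argmax over the phone dimension. The sketch below uses random tensors as stand-ins for infer's actual output; the shapes are illustrative assumptions.

```python
import torch

# Stand-ins for infer()'s output: each tensor is (time_steps, number_of_phones)
outputs = [torch.randn(120, 62), torch.randn(80, 62)]

# Most likely phone id for each frame of each recording
frame_phones = [scores.argmax(dim=1) for scores in outputs]
```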

About

Keyword spotting using RNNs + Edit distance
