
Best Practices for Text Classification with Distillation (Part 1/4) - How to achieve BERT results by

Moshe_Wasserblat

Published May 17th, 2021

Moshe Wasserblat is currently the research manager for Natural Language Processing (NLP) at Intel Labs.


In recent years, increasingly large Transformer-based models such as BERT have demonstrated remarkable state-of-the-art (SoTA) performance in many Natural Language Processing (NLP) tasks and have become the de-facto standard. However, these models are highly inefficient and require massive computational resources and large amounts of data for training and deployment. As a result, the scalability and deployment of NLP-based systems across the industry are severely hindered.

I have always been fascinated by the robustness and efficiency of NLP deployed in production. So, I decided to write a blog series that will provide some practical tips and code samples for deploying and adapting large SoTA transformer models.

In the first four blogs, I will focus on model distillation for text classification. Model distillation is a powerful compression technique, and in many use cases it yields a significant speedup and memory reduction. It is generally considered more suitable for advanced users because it is relatively hard to implement, its performance is unpredictable, and its inner workings are obscure. I will try to show just the opposite by providing a few real use-case examples, simple code snippets, and intuitive explanations that demonstrate the effectiveness of model distillation.

Let’s begin.

Suppose you are a data scientist at a top enterprise company and your task is to classify social media tweets and deliver SoTA models into production. You have also been warned that your model must be very efficient, and that you will pay a high cost for every extra million parameters. You will probably start by collecting sufficient data (several hundred to a few thousand labeled samples) and then compare several ML/DL models and transformer models, looking for maximum accuracy at minimum cost (i.e., minimal model size).

To demonstrate the accuracy achieved by different model types, I chose an emotion classification dataset called Emotion that consists of Twitter posts, each labeled with one of six basic emotion categories: sadness, joy, love, anger, fear, and surprise. The data consists of 16K training samples and 2K test samples and is available on the Hugging Face Datasets Hub. A code example for the following steps is available here.
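As a quick reference, here is a minimal sketch of loading the dataset with the Hugging Face datasets library; it assumes the publicly available "emotion" dataset on the Hub, and the exact split names and sizes may vary with the dataset version:

```python
# A minimal sketch of loading the Emotion dataset from the Hugging Face Hub.
from datasets import load_dataset

emotion = load_dataset("emotion")           # splits: train / validation / test
print(emotion)                              # ~16K training and 2K test samples
print(emotion["train"].features["label"])   # the six emotion classes
print(emotion["train"][0])                  # {'text': ..., 'label': ...}
```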

1st step: Set a baseline using a logistic regression model (TF-IDF based)

After performing the first step, we get:

 Accuracy = 86.1%

 That is our baseline result.
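Here is a minimal sketch of this baseline, assuming scikit-learn and the Hugging Face datasets library; the vectorizer settings are illustrative rather than the exact configuration behind the reported number:

```python
# Step 1 sketch: TF-IDF features + logistic regression on the Emotion dataset.
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

emotion = load_dataset("emotion")
train_texts, train_labels = emotion["train"]["text"], emotion["train"]["label"]
test_texts, test_labels = emotion["test"]["text"], emotion["test"]["label"]

# Sparse TF-IDF features over a limited word vocabulary
vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2))
X_train = vectorizer.fit_transform(train_texts)
X_test = vectorizer.transform(test_texts)

clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, train_labels)
print("Test accuracy:", accuracy_score(test_labels, clf.predict(X_test)))
```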

2nd step: Set a deep learning baseline

Next, we try a simple multilayer perceptron (MLP) model. The architecture is basic: an embedding layer with an input dimension of 5000 (the maximum word vocabulary size) and an output dimension of 16, followed by an average pooling layer and a softmax classifier (~80K parameters in total).

We get:

 Accuracy = 86%

The accuracy is similar to that of the logistic regression, but the model is denser and more efficient.
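The following is a minimal PyTorch sketch of such a model, following the description above (5000-word vocabulary, 16-dimensional embeddings, average pooling, softmax classifier); the tokenization and training loop are omitted, and all names are illustrative:

```python
# Step 2 sketch: a tiny embedding + average-pooling + softmax classifier (~80K parameters).
import torch
import torch.nn as nn

class EmotionMLP(nn.Module):
    def __init__(self, vocab_size=5000, embed_dim=16, num_classes=6):
        super().__init__()
        # EmbeddingBag performs the embedding lookup and average pooling in one step
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, mode="mean")
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, token_ids, offsets=None):
        pooled = self.embedding(token_ids, offsets)   # (batch, embed_dim)
        return self.classifier(pooled)                # logits; softmax is applied in the loss

model = EmotionMLP()
print(sum(p.numel() for p in model.parameters()))     # ~80K parameters
```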

 3rd step: Transformer models

As NLP veterans and practiced users of Hugging Face and their excellent “Transformers” library, we try a few popular SoTA transformer models.

 We get:

[Results table: accuracy and parameter counts of the transformer models]

Accuracy shot sky-high. But 110M parameters is way above our computational budget, and IT will hit the ceiling when they hear about the 220M model.
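Here is a minimal sketch of fine-tuning one of these models with the Transformers Trainer API; the checkpoint name and hyperparameters are illustrative and can be swapped for other models:

```python
# Step 3 sketch: fine-tune a transformer on the Emotion dataset with Hugging Face Transformers.
import numpy as np
from datasets import load_dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

checkpoint = "bert-base-uncased"   # e.g., swap for "distilbert-base-uncased" to try a distilled model
emotion = load_dataset("emotion")
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=64)

encoded = emotion.map(tokenize, batched=True)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=6)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    return {"accuracy": (np.argmax(logits, axis=-1) == labels).mean()}

args = TrainingArguments(output_dir="bert-emotion", num_train_epochs=2,
                         per_device_train_batch_size=32, learning_rate=2e-5)
trainer = Trainer(model=model, args=args,
                  train_dataset=encoded["train"],
                  eval_dataset=encoded["test"],
                  compute_metrics=compute_metrics)
trainer.train()
print(trainer.evaluate())   # reports test-set accuracy
```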

 4th step: DistilBERT

So, let's instead try one of the popular distilled models released by Hugging Face, such as DistilBERT or DistilRoBERTa. DistilBERT is roughly half the size of the BERT-base model and about twice as fast.

 We get:

[Results table: DistilBERT accuracy and parameter count]

DistilBERT’s accuracy dropped by less than one percent compared with the BERT-base model, and our model is much smaller.

Is that it? Are we done? Should we go to production with this model and pay the computing cost for its 66M parameters? Or can we do better? Can we use fewer than 1M, or even 100K, parameters while keeping the accuracy loss minimal (<1%)?

Surprisingly, for many practical text classification tasks the answer is "Yes!" (though results vary with the quality of your dataset and the task at hand). We will harness task-specific knowledge distillation together with additional data, obtained either through data augmentation techniques or by sampling from your in-domain unlabeled dataset.

As you know, BERT is a language model pre-trained on the masked language modeling task. In our case, we are interested only in emotion classification, and previous research shows that fine-tuned BERT models are over-parameterized for domain-specific tasks (Kovaleva et al., 2019).

Knowledge Distillation (KD) from a large model to a much simpler architecture (Tang et al., 2019; Wasserblat et al., 2020) has shown promising results for reducing model size and computational load while preserving much of the original model's accuracy. A typical KD setup includes two stages. In the first stage, a large, cumbersome, and accurate teacher neural network is fine-tuned for a specific downstream task. In the second stage, shown in the following figure, a smaller and simpler student model is trained to mimic the behavior of the teacher model. This student model is more practical for deployment in environments with limited resources.

[Figure: the knowledge distillation setup, in which the student model is trained to mimic the teacher model's predictions]

Code Disclaimer: To keep the system as simple as possible, we use only a single loss, computed for each training batch as the Kullback-Leibler (KL) divergence between the target predictions produced by the student and teacher models. We did not notice any performance loss compared with using the Mean Squared Error (MSE) between soft targets (logits) or with the temperature scaling employed in the original distillation paper (Hinton et al., 2015). Distillation setups are usually cumbersome because they use different loss types for labeled and unlabeled data. In contrast, our implementation is simpler and friendlier since it uses a single loss (the KL loss) for both types of data.
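To make this concrete, here is a minimal PyTorch sketch of the single-loss distillation step described above; the model and variable names are illustrative, the teacher is assumed to be already fine-tuned, and the inputs are assumed to be pre-tokenized in whatever format each model expects:

```python
# Distillation sketch: train the student to match the teacher's soft predictions with a KL loss.
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, student_inputs, teacher_inputs, optimizer):
    # Soft targets from the fine-tuned teacher (no gradients needed)
    teacher.eval()
    with torch.no_grad():
        teacher_logits = teacher(**teacher_inputs).logits

    student_logits = student(**student_inputs)

    # KL divergence between the predicted class distributions, with the teacher as the target.
    # No temperature and no hard-label term: ground-truth labels are not used, so the same
    # loss applies to labeled, unlabeled, and augmented data alike.
    loss = F.kl_div(F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```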

 5th step: Distill RoBERTa to a simpler student model

In our case, we distill RoBERTa’s knowledge into our simple MLP model.

Here are the results:

[Results table: accuracy and parameter count of the distilled MLP student]


Surprisingly, the result is not bad at all, with accuracy even higher than DistilBERT's and on par with BERT's!

The following figure summarizes the results achieved so far on Emotion: relative accuracy (model accuracy / BERT-base accuracy, as a percentage) vs. model size.

[Figure: relative accuracy (model accuracy / BERT-base accuracy) vs. model size for the models evaluated on Emotion]

We distilled RoBERTa's knowledge into our tiny model with almost no loss. We gained the high performance of a transformer model at minimal cost (and made IT very happy).

By now you probably have many questions, such as:

●     Does this “trick” hold for any text-classification sub-task?

●     If not, when does it work?

●     How would I choose the best student for my dataset?

●     What is the intuition behind this behavior?

I’ll explore these issues and answer these questions in the following posts.

 

Other Google Colab References

[1] T5-base finetuned Emotion by Manu Romero

[2] DistilRoBERTa finetuned Emotion by Elvis Saravia

Acknowledgments

Special thanks to Oren Pereg for his comments and improvements during the review process.



About the Author
Mr. Moshe Wasserblat is currently the Natural Language Processing (NLP) and Deep Learning (DL) research group manager at Intel's AI Product group. In his former role, he was with NICE Systems for more than 17 years, where he founded NICE's Speech Analytics Research Team. His interests are in the fields of Speech Processing and Natural Language Processing (NLP). He was the co-founder and coordinator of the EXCITEMENT FP7 ICT program and served as organizer and manager of several initiatives, including many Israeli Chief Scientist programs. He has filed more than 60 patents in the field of Language Technology and also has several publications in international conferences and journals. His areas of expertise include: Speech Recognition, Conversational Natural Language Processing, Emotion Detection, Speaker Separation, Speaker Recognition, Information Extraction, Data Mining, and Machine Learning.