Optimized batch processing for Bert sentence and word embeddings #6462
This is an optimization of batch processing in the BertSentenceEmbeddings and BertEmbeddings models which improves performance on single-machine setups where there is a single sentence, or only a few sentences, per row.
Description
The HasBatchedAnnotate trait lets models process several rows at once. The idea is to batch annotations across rows, which can significantly speed up processing, especially on GPU hardware. However, the current implementation of the batchAnnotate method in BertSentenceEmbeddings and BertEmbeddings doesn't actually allow this: it processes rows one by one and can therefore batch only the annotations within a single row. This PR solves the problem by adding all annotations passed to batchAnnotate, regardless of which row they belong to, to a single collection, and then passing this collection to TensorflowBert to be processed with the given batch size.
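For reference, this is roughly the shape of the trait (simplified here; the real trait is generic over the model type):

```scala
import com.johnsnowlabs.nlp.Annotation

// Simplified shape of HasBatchedAnnotate: each element of the outer Seq holds the
// annotations of one row, and implementations return one Seq of results per row.
trait HasBatchedAnnotate {
  def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]]
}
```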
The change leads to a significant performance gain on a single-machine setup when there is one annotation, or only a few annotations, per row. Below are the results of the BertEmbeddings benchmark test on the conll2004 dataset:
The test is run on a single machine with a GPU. GPU performance is significantly improved when sentences are exploded (i.e. a single sentence per row).
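For illustration, one way to produce this exploded setup (a hypothetical snippet; `detected` is assumed to be a DataFrame with a SentenceDetector output column named "sentence"):

```scala
import org.apache.spark.sql.functions.{col, explode}

// Pull the raw sentence strings out of the "sentence" annotation column; after the
// explode, each row carries exactly one sentence, which can be fed back through the
// document assembler and embeddings pipeline as a new "text" column.
val oneSentencePerRow = detected.select(explode(col("sentence.result")).as("text"))
```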
Additional tests revealed that the changes have little to no impact on performance in a cluster setup with multiple GPUs. Apparently such a Spark setup is capable of keeping the GPUs utilized even when each call to batchAnnotate feeds the input tensors one by one.
I've re-implemented the batchAnnotate method of BertSentenceEmbeddings and BertEmbeddings. Instead of processing the annotations separately for each row, I put all the annotations into a single collection and preserve the allocation of annotations to rows by zipping each annotation with its row index. All annotations are then tokenized and passed to the TF graph to compute their (sentence or word) embeddings. At the end, the embeddings are distributed back across the rows.
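A minimal sketch of the reworked flow (not the exact PR code; computeEmbeddings is a hypothetical stand-in for the tokenization and TensorflowBert calls):

```scala
import com.johnsnowlabs.nlp.Annotation

// Stand-in for tokenizing the flattened annotations and running them through
// TensorflowBert with the configured batch size; each resulting embedding
// annotation stays tagged with the row index it came from.
def computeEmbeddings(indexed: Seq[(Int, Annotation)]): Seq[(Int, Annotation)] = ???

def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
  // 1. Flatten all rows into a single collection, zipping each annotation with
  //    its row index so the per-row allocation can be restored afterwards.
  val indexed: Seq[(Int, Annotation)] =
    batchedAnnotations.zipWithIndex.flatMap { case (row, rowIdx) =>
      row.map(annotation => (rowIdx, annotation))
    }

  // 2. Tokenize and embed everything together, batched by the configured batch size.
  val embedded = computeEmbeddings(indexed)

  // 3. Distribute the embeddings back to their original rows; rows that had no
  //    annotations come back empty.
  val byRow = embedded.groupBy(_._1)
  batchedAnnotations.indices.map { rowIdx =>
    byRow.getOrElse(rowIdx, Seq.empty).map(_._2)
  }
}
```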
Motivation and Context
It lets BertSentenceEmbeddings and BertEmbeddings process sentences in batches when there are only a few annotations per row. In such cases, the change significantly improves performance on single-machine setups.
How Has This Been Tested?
I've made sure the results of the improved batch processing are identical to those produced by the current implementation. This test uses different JAR files, so it is not included in the current commit.
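Since that cross-JAR comparison can't live in the repo, here is an illustrative, self-contained check of a related invariant, namely that embeddings shouldn't depend on the batch size they were computed with (the model, data, and tolerance are my choices for the sketch, not the PR's test):

```scala
import com.johnsnowlabs.nlp.DocumentAssembler
import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.embeddings.BertEmbeddings
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val data = Seq("John lives in New York.", "Peter works in Berlin.").toDF("text")

// Run the same pipeline with a given batch size and collect one vector per token.
def embed(batchSize: Int): Array[Seq[Float]] = {
  val assembler = new DocumentAssembler().setInputCol("text").setOutputCol("document")
  val tokenizer = new Tokenizer().setInputCols("document").setOutputCol("token")
  val bert = BertEmbeddings.pretrained()
    .setInputCols("document", "token")
    .setOutputCol("embeddings")
    .setBatchSize(batchSize)
  new Pipeline().setStages(Array(assembler, tokenizer, bert))
    .fit(data).transform(data)
    .selectExpr("explode(embeddings.embeddings) AS vector")
    .as[Seq[Float]].collect()
}

val oneByOne = embed(1)  // effectively one sentence per TF call
val batched  = embed(32) // several sentences per TF call
val maxDiff = oneByOne.zip(batched).map { case (a, b) =>
  a.zip(b).map { case (x, y) => math.abs(x - y) }.max
}.max
assert(maxDiff < 1e-5f, s"embeddings diverged: max abs diff = $maxDiff")
```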
I've also tested performance in order to demonstrate that the changes indeed improve it (see the table above).
Tests were run on a local Spark NLP installation.
The changes don't have any impact on functionality and shouldn't affect or break any existing code.
Screenshots (if appropriate):
Types of changes
It is actually neither of the above; we need one more option: 'code optimization with no functional implications but a significant impact on performance'.
Checklist: