From 0bca5a3f0779860af9d5fcd5cc6b03cd6de54d89 Mon Sep 17 00:00:00 2001
From: Joe Cummings
Date: Tue, 7 Mar 2023 14:53:44 -0500
Subject: [PATCH] Make sure to include padding mask in generation (#2096)

---
 test/integration_tests/test_generate.py | 34 +++++++++++++++----------
 torchtext/models/t5/model.py            |  2 ++
 torchtext/prototype/generate.py         | 21 +++++++--------
 3 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/test/integration_tests/test_generate.py b/test/integration_tests/test_generate.py
index 80acd9dc7e..df5f8b81bc 100644
--- a/test/integration_tests/test_generate.py
+++ b/test/integration_tests/test_generate.py
@@ -14,26 +14,25 @@ def setUp(self) -> None:
         self.model = t5_base.get_model()
         self.model.eval()
         # Examples taken from T5 Paper and Huggingface
-        self.inputs = self.transform(
-            [
-                "summarize: studies have shown that owning a dog is good for you",
-                "translate English to German: That is good.",
-                "cola sentence: The course is jumping well.",
-                "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
-                "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
-            ]
-        )
+        self.inputs = [
+            "summarize: studies have shown that owning a dog is good for you",
+            "translate English to German: That is good.",
+            "cola sentence: The course is jumping well.",
+            "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
+            "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
+        ]
+        self.transformed_inputs = self.transform(self.inputs)
         torch.manual_seed(0)
 
     def test_greedy_generate_with_t5(self) -> None:
         generation_model = GenerationUtils(self.model)
 
-        tokens = generation_model.generate(self.inputs, num_beams=1, max_length=30)
+        tokens = generation_model.generate(self.transformed_inputs, num_beams=1, max_length=30)
         generated_text = self.transform.decode(tokens.tolist())
 
         expected_generated_text = [
-            "a dog is good for you, according to studies . owning a dog is good for you, according to studies .",
-            "Das ist gut.",
+            "owning a dog is good for you, according to studies . a dog is a good companion for a variety of reasons",
+            "Das ist gut so.",
             "acceptable",
             "4.0",
             "mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage",
@@ -41,14 +40,21 @@ def test_greedy_generate_with_t5(self) -> None:
 
         self.assertEqual(generated_text, expected_generated_text)
 
+        inputs = self.transform([self.inputs[0]])
+
+        tokens_for_single_example = generation_model.generate(inputs, num_beams=1, max_length=30)
+        generated_text_for_single_example = self.transform.decode(tokens_for_single_example.tolist())
+
+        self.assertEqual(generated_text[0], generated_text_for_single_example[-1])
+
     def test_generate_errors_with_incorrect_beams(self) -> None:
         generation_model = GenerationUtils(self.model, is_encoder_decoder=True)
         with self.assertRaises(ValueError):
-            generation_model.generate(self.inputs, num_beams=0)
+            generation_model.generate(self.transformed_inputs, num_beams=0)
 
     @patch("logging.Logger.warning")
     def test_warns_when_no_max_len_provided(self, mock) -> None:
         generation_model = GenerationUtils(self.model)
-        generation_model.generate(self.inputs)
+        generation_model.generate(self.transformed_inputs)
         mock.assert_called_with(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
diff --git a/torchtext/models/t5/model.py b/torchtext/models/t5/model.py
index eebc9ce83e..6ba55089c5 100644
--- a/torchtext/models/t5/model.py
+++ b/torchtext/models/t5/model.py
@@ -198,6 +198,7 @@ def prepare_inputs_for_generation(
         self,
         input_ids: Tensor,
         encoder_outputs: ENCODER_OUTPUTS_TYPE,
+        encoder_padding_mask: Optional[Tensor] = None,
         past: Optional[List[PAST_KEY_VALUES_TYPE]] = None,
         return_past_key_values: bool = True,
     ) -> Dict[str, Union[Tensor, ENCODER_OUTPUTS_TYPE, Optional[List[PAST_KEY_VALUES_TYPE]], bool]]:
@@ -209,6 +210,7 @@ def prepare_inputs_for_generation(
             "decoder_tokens": input_ids,
             "encoder_outputs": encoder_outputs,
             "past_key_values": past,
+            "encoder_padding_mask": encoder_padding_mask,
             "return_past_key_values": return_past_key_values,
         }
diff --git a/torchtext/prototype/generate.py b/torchtext/prototype/generate.py
index dd74948c81..3300fda11f 100644
--- a/torchtext/prototype/generate.py
+++ b/torchtext/prototype/generate.py
@@ -48,7 +48,7 @@ def _prepare_decoder_ids_for_generation(
         return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
 
     def greedy_search(
-        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
+        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: int, **model_kwargs
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
 
@@ -62,10 +62,11 @@ def greedy_search(
         Returns:
             Batch of sequences decoded by greedy search.
         """
-        unfinished_sequences = torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long)
+        unfinished_sequences = torch.ones((input_ids.shape[0]), device=input_ids.device, dtype=torch.long)
 
         while True:
             model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
             if self.is_huggingface_model:
                 model_inputs["return_dict"] = True
                 model_inputs["output_hidden_states"] = True
@@ -77,18 +78,16 @@
 
             # Calculate probabilities and take the most likely next token
             probs = F.log_softmax(decoder_output[:, -1], dim=-1)
-            _, next_tokens = torch.topk(probs, 1)
+            next_tokens = torch.argmax(probs, dim=-1)
 
             # For any finished sequences, padding idx should be the last token
-            if eos_idx is not None:
-                if pad_idx is not None:
-                    next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
+            next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
 
             # Append the next tokens to the previous tokens
-            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
 
-            if eos_idx is not None:
-                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())
+            # Update unfinished sequences count
+            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx)).long()
 
             # Stop iterating once all sequences are finished or exceed the max_length
             if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_length:
@@ -128,8 +127,10 @@ def generate(
         if self.is_encoder_decoder:
             encoder = self.model.get_encoder()
-            model_kwargs["encoder_outputs"] = encoder(inputs)
+            encoder_model_kwargs = {"src_key_padding_mask": inputs.eq(pad_idx)}
+            model_kwargs["encoder_outputs"] = encoder(inputs, **encoder_model_kwargs)
             inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)
+            model_kwargs["encoder_padding_mask"] = encoder_model_kwargs.pop("src_key_padding_mask")
 
         if max_length is None:
            # Too hard to try to figure out the exact max_seq_length for each model
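
A minimal sketch of the two tensor-level behaviors this patch leans on, using only plain torch so it runs without a T5 checkpoint: deriving the encoder padding mask from the padded source ids (as generate now does before calling the encoder and on every decoding step via prepare_inputs_for_generation), and pinning already-finished sequences to the pad token during greedy decoding. The toy batch and the pad/eos ids below are illustrative stand-ins, not values taken from the patch; in torchtext they come from the T5 transform and vocabulary.

import torch

pad_idx, eos_idx = 0, 1  # illustrative ids; the real values come from the T5 vocabulary

# A padded batch of source token ids (two sequences of different lengths).
source_ids = torch.tensor([[5, 6, 7, pad_idx, pad_idx],
                           [8, 9, 10, 11, 12]])

# As in GenerationUtils.generate after this change: mark the padding positions,
# hand the mask to the encoder as src_key_padding_mask, and reuse it later as
# encoder_padding_mask for each decoder step.
src_key_padding_mask = source_ids.eq(pad_idx)  # True where the token is padding

# As in greedy_search: once a sequence has produced EOS it is "finished", so every
# later step appends pad_idx instead of the argmax token.
unfinished_sequences = torch.tensor([0, 1])  # sequence 0 finished on an earlier step
next_tokens = torch.tensor([42, eos_idx])    # raw argmax picks for the current step
next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
unfinished_sequences = unfinished_sequences.mul(next_tokens != eos_idx).long()

print(src_key_padding_mask)  # [[False, False, False, True, True], [False, False, False, False, False]]
print(next_tokens)           # tensor([0, 1]) -> pad for the finished row, eos for the still-active row
print(unfinished_sequences)  # tensor([0, 0]) -> both rows are now finished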