@@ -887,8 +887,26 @@ function createPositionIds(model_inputs, past_key_values = null, start_index = 0
}

function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
+    const past_length = model_inputs.past_key_values
+        ? Object.values(model_inputs.past_key_values)[0].dims.at(-2)
+        : 0;
+
+    if (!model_inputs.attention_mask) {
+        // If the attention mask is not provided, we attempt to infer based on provided inputs
+        let dims;
+        for (const key of ['input_ids', 'inputs_embeds', 'position_ids']) {
+            if (model_inputs[key]) {
+                dims = model_inputs[key].dims;
+                break;
+            }
+        }
+        if (!dims) {
+            throw new Error("attention_mask is not provided, and unable to infer its shape from model inputs.");
+        }
+        model_inputs.attention_mask = ones([dims[0], past_length + dims[1]]);
+    }
+
    if (model_inputs.past_key_values) {
-        const past_length = Object.values(model_inputs.past_key_values)[0].dims.at(-2);
        const { input_ids, attention_mask } = model_inputs;

        // Keep only the unprocessed tokens:
@@ -909,24 +927,7 @@ function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, ge
        }
        // 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
        else {
-            if (
-                // NOTE: Only used by VLMs (!= so that null matches undefined)
-                self.config.image_token_index != null &&
-                // Equivalent to `self.config.image_token_index in input_ids` (== so that int matches bigint)
-                input_ids.data.some(x => x == self.config.image_token_index)
-            ) {
-                // TODO: Support multiple image tokens
-                const num_image_tokens = self.config.num_image_tokens;
-                if (!num_image_tokens) {
-                    throw new Error('`num_image_tokens` is missing in the model configuration.');
-                }
-
-                const num_new_tokens = input_ids.dims[1] - (past_length - num_image_tokens);
-                model_inputs.input_ids = input_ids.slice(null, [-num_new_tokens, null]);

-                // TODO: The attention mask should be formed from the attention mask passed in model_inputs
-                model_inputs.attention_mask = ones([1, past_length + num_new_tokens]);
-            }
        }
    }

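Note on the two hunks above: `past_length` is now computed up front (0 when there is no KV cache), and a missing `attention_mask` is inferred from whichever of `input_ids`, `inputs_embeds`, or `position_ids` is present, covering both cached and new positions. This replaces the removed VLM-specific branch, which could only rebuild the mask for a single image-token run at batch size 1. A minimal shape sketch (numbers are illustrative; `ones` is the tensor helper this file already imports):

    // Resuming generation with a KV cache of length 10 and 3 new tokens, batch size 2:
    const dims = [2, 3];        // dims of input_ids (or inputs_embeds) => [batch, new_tokens]
    const past_length = 10;     // inferred from past_key_values[...].dims.at(-2)
    const attention_mask = ones([dims[0], past_length + dims[1]]);  // dims [2, 13]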
@@ -2016,17 +2017,7 @@ export class PreTrainedModel extends Callable {

    async encode_image({ pixel_values }) {
        // image_inputs === { pixel_values }
-        const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values })).image_features;
-        // @ts-expect-error TS2339
-        if (!this.config.num_image_tokens) {
-            console.warn(
-                'The number of image tokens was not set in the model configuration. ' +
-                `Setting it to the number of features detected by the vision encoder (${features.dims[1]}).`
-            )
-            // @ts-expect-error TS2339
-            this.config.num_image_tokens = features.dims[1];
-        }
-        return features;
+        return (await sessionRun(this.sessions['vision_encoder'], { pixel_values })).image_features;
    }

    async encode_text({ input_ids }) {
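With the `num_image_tokens` fallback removed, `encode_image` is a thin wrapper around the vision-encoder session: callers take the number of image tokens from the returned features rather than from a value cached on the config. A hedged usage sketch (shapes are illustrative):

    const image_features = await model.encode_image({ pixel_values });
    // e.g. image_features.dims === [num_images, tokens_per_image, hidden_size]
    const tokens_per_image = image_features.dims.at(-2);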
@@ -3640,65 +3631,16 @@ export class LlavaPreTrainedModel extends PreTrainedModel {
 * The LLAVA model which consists of a vision backbone and a language model.
 */
export class LlavaForConditionalGeneration extends LlavaPreTrainedModel {
+    _merge_input_ids_with_image_features(kwargs) {
+        const vision_hidden_size = kwargs.image_features.dims.at(-1);
+        const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);

-    _merge_input_ids_with_image_features({
-        inputs_embeds,
-        image_features,
-        input_ids,
-        attention_mask,
-    }) {
-
-        // @ts-expect-error TS2339
-        const image_token_index = this.config.image_token_index;
-
-        const idsList = input_ids.tolist();
-
-        // NOTE: we use .findIndex instead of .indexOf to perform weak comparison (==) between BigInt and Number
-        const indexOfImage = idsList.map(x => x.findIndex(x => x == image_token_index));
-
-        const noImages = indexOfImage.every(x => x === -1);
-        const allImages = indexOfImage.every(x => x !== -1);
-        if (!noImages && !allImages) {
-            // Check for padding reasons
-            throw new Error('Every input should contain either 0 or 1 image token.');
-        }
-
-        if (noImages) {
-            return {
-                inputs_embeds,
-                attention_mask,
-            }
-        }
-
-        const stacked = [];
-        const stacked_attention_mask = [];
-        for (let i = 0; i < indexOfImage.length; ++i) {
-            const index = indexOfImage[i];
-
-            const e = inputs_embeds[i];
-            const im = image_features[i];
-            const am = attention_mask[i];
-            stacked.push(
-                cat([
-                    e.slice([0, index]),
-                    im,
-                    e.slice([index + 1, e.dims[0]]),
-                ], 0)
-            );
-
-            stacked_attention_mask.push(
-                cat([
-                    am.slice([0, index]),
-                    ones([im.dims[0]]),
-                    am.slice([index + 1, am.dims[0]])
-                ], 0)
-            )
-        }
-
-        return {
-            inputs_embeds: stack(stacked, 0),
-            attention_mask: stack(stacked_attention_mask, 0),
-        }
+        return default_merge_input_ids_with_image_features({
+            // @ts-ignore
+            image_token_id: this.config.image_token_index,
+            ...kwargs,
+            image_features: reshaped_image_hidden_states,
+        })
    }
}
//////////////////////////////////////////////////
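The hand-rolled `cat`/`stack` splice above, which required every sequence in the batch to contain exactly 0 or 1 image token, is replaced by the shared `default_merge_input_ids_with_image_features` helper, which scatters one feature row into each image-token position and so handles any number of image tokens. The `view(-1, vision_hidden_size)` call flattens the per-image feature grid so those rows line up; a shape sketch with illustrative numbers:

    // image_features.dims === [2, 576, 4096]   (2 images, 576 patch tokens each)
    const vision_hidden_size = image_features.dims.at(-1);      // 4096
    const flat = image_features.view(-1, vision_hidden_size);   // dims [1152, 4096]
    // Row i of `flat` replaces the embedding of the i-th image token in inputs_embeds.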
@@ -3839,6 +3781,20 @@ export class PaliGemmaForConditionalGeneration extends PaliGemmaPreTrainedModel
    }
}

+export class LlavaQwen2ForCausalLM extends LlavaPreTrainedModel {
+    _merge_input_ids_with_image_features(kwargs) {
+        const vision_hidden_size = kwargs.image_features.dims.at(-1);
+        const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);
+
+        return default_merge_input_ids_with_image_features({
+            // @ts-ignore
+            image_token_id: this.config.image_token_index,
+            ...kwargs,
+            image_features: reshaped_image_hidden_states,
+        })
+    }
+}
+
//////////////////////////////////////////////////
// Idefics3 Models
export class Idefics3PreTrainedModel extends PreTrainedModel {
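`LlavaQwen2ForCausalLM` extends `LlavaPreTrainedModel` and duplicates the merge logic of `LlavaForConditionalGeneration`; only the registered name differs, so checkpoints whose config declares `model_type: "llava_qwen2"` resolve to it. A hypothetical loading sketch (the model id is a placeholder, and the class is assumed to be re-exported like the other model classes):

    import { AutoProcessor, LlavaQwen2ForCausalLM } from '@huggingface/transformers';

    const model_id = 'your-org/llava-qwen2-onnx';  // placeholder id for illustration
    const processor = await AutoProcessor.from_pretrained(model_id);
    const model = await LlavaQwen2ForCausalLM.from_pretrained(model_id);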
@@ -7842,6 +7798,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
    ['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
    ['smolvlm', ['SmolVLMForConditionalGeneration', SmolVLMForConditionalGeneration]],
    ['paligemma', ['PaliGemmaForConditionalGeneration', PaliGemmaForConditionalGeneration]],
+    ['llava_qwen2', ['LlavaQwen2ForCausalLM', LlavaQwen2ForCausalLM]],
]);

const MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
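Adding the pair to `MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES` is what lets the auto classes resolve a checkpoint's declared model type to the new class. A sketch of the lookup this enables (the config excerpt is illustrative):

    // config.json: { "model_type": "llava_qwen2", ... }
    const [name, cls] = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.get('llava_qwen2');
    // name === 'LlavaQwen2ForCausalLM', cls === LlavaQwen2ForCausalLM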