Class: Informers::VisionEncoderDecoderModel

Inherits:
PreTrainedModel
Defined in:
lib/informers/models.rb

Constant Summary

MAIN_INPUT_NAME = :pixel_values

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from PreTrainedModel

#call, construct_session, from_pretrained, #generate

Constructor Details

#initialize(config, session, decoder_merged_session, generation_config) ⇒ VisionEncoderDecoderModel

Returns a new instance of VisionEncoderDecoderModel.



Source: lib/informers/models.rb, lines 1117–1158
# File 'lib/informers/models.rb', line 1117

# Builds a vision encoder-decoder model from a composite config that holds
# separate "encoder" and "decoder" sub-configs.
#
# A decoder instance is constructed locally (it is not stored) purely so that
# its past-key-value geometry (layer count, head count, per-head dimension)
# can be copied onto this model. Which set of attributes is copied depends on
# whether the decoder exposes encoder-decoder style attributes
# (+num_decoder_layers+) or decoder-only ones (+num_layers+).
#
# @param config [#[]] composite model config; must provide "encoder" and "decoder" entries
# @param session inference session forwarded to PreTrainedModel (presumably the encoder session — confirm)
# @param decoder_merged_session inference session handed to the decoder model
# @param generation_config generation settings, stored here and forwarded to the decoder
# @raise [Error] if the "encoder" or "decoder" config section is missing, or if the
#   decoder model type has no entry in MODEL_WITH_LM_HEAD_MAPPING_NAMES
def initialize(config, session, decoder_merged_session, generation_config)
  super(config, session)
  @decoder_merged_session = decoder_merged_session
  @generation_config = generation_config

  # Extract configs. Fail with a descriptive error rather than a NoMethodError on nil
  # when a section is absent.
  encoder_config = @config["encoder"]
  decoder_config = @config["decoder"]
  raise Error, "Unable to construct `VisionEncoderDecoder` due to missing \"encoder\" config" if encoder_config.nil?
  raise Error, "Unable to construct `VisionEncoderDecoder` due to missing \"decoder\" config" if decoder_config.nil?

  # Validate encoder. An unknown encoder type is tolerated (only a warning) because
  # this class does not construct an encoder model object here.
  encoder_model_type = encoder_config["model_type"]
  encoder_model = MODEL_MAPPING_NAMES_ENCODER_ONLY[encoder_model_type] || MODEL_MAPPING_NAMES_ENCODER_DECODER[encoder_model_type]
  unless encoder_model
    warn "Model type for encoder '#{encoder_model_type}' not found, assuming encoder-only architecture. Please report this."
  end

  # Validate decoder. An LM-head mapping is required because the decoder is instantiated below.
  decoder_model_type = decoder_config["model_type"]
  decoder_model = MODEL_WITH_LM_HEAD_MAPPING_NAMES[decoder_model_type]
  unless decoder_model
    raise Error, "Unable to construct `VisionEncoderDecoder` due to unsupported decoder: \"#{decoder_model_type}\""
  end

  # The mapping entry's second element is the model class to instantiate.
  decoder_model_class = decoder_model[1]
  decoder = decoder_model_class.new(decoder_config, decoder_merged_session, generation_config)

  # Duck-type the decoder: encoder-decoder decoders expose separate decoder/encoder
  # cache dimensions, while decoder-only models expose a single set.
  @add_encoder_pkv = decoder.respond_to?(:num_decoder_layers)
  if @add_encoder_pkv
    # Decoder is part of an encoder-decoder model
    @num_decoder_layers = decoder.num_decoder_layers
    @num_decoder_heads = decoder.num_decoder_heads
    @decoder_dim_kv = decoder.decoder_dim_kv

    @num_encoder_layers = decoder.num_encoder_layers
    @num_encoder_heads = decoder.num_encoder_heads
    @encoder_dim_kv = decoder.encoder_dim_kv
  else
    # Decoder is a decoder-only model
    @num_layers = decoder.num_layers
    @num_heads = decoder.num_heads
    @dim_kv = decoder.dim_kv
  end
end