Method: Informers::PreTrainedModel#initialize

Defined in:
lib/informers/models.rb

#initialize(config, session) ⇒ PreTrainedModel

Returns a new instance of PreTrainedModel.



74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# File 'lib/informers/models.rb', line 74

# Stores the config and session, then binds the strategy methods that match
# the model type registered for the receiver's class.
#
# The type is resolved in two steps: the class is looked up in
# MODEL_CLASS_TO_NAME_MAPPING, and the resulting name in MODEL_TYPE_MAPPING.
#
# @param config  kept as-is in @config
# @param session kept as-is in @session
def initialize(config, session)
  super()

  @config = config
  @session = session
  @output_names = nil

  registered_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]

  case MODEL_TYPE_MAPPING[registered_name]
  when MODEL_TYPES[:DecoderOnly]
    @can_generate = true
    @run_beam, @get_start_beams, @update_beam, @forward = [
      method(:decoder_run_beam),
      method(:decoder_start_beams),
      method(:decoder_update_beam),
      method(:decoder_forward)
    ]
  when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
    @can_generate = true
    @run_beam, @get_start_beams, @update_beam, @forward = [
      method(:seq2seq_run_beam),
      method(:seq2seq_start_beams),
      method(:seq2seq_update_beam),
      method(:seq2seq_forward)
    ]
  else
    # EncoderDecoder and every remaining (or unregistered) type only get the
    # encoder forward; @can_generate and the beam hooks are left unset.
    @forward = method(:encoder_forward)
  end
end