Method: Informers::PreTrainedModel#initialize

Defined in:
lib/informers/models.rb

#initialize(config, session) ⇒ PreTrainedModel

Returns a new instance of PreTrainedModel.



68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# File 'lib/informers/models.rb', line 68

# Builds a model wrapper and binds the inference methods that match the
# model type registered for this instance's class.
#
# The type is resolved in two steps: the class is looked up in
# MODEL_CLASS_TO_NAME_MAPPING to get a model name, and that name is looked up
# in MODEL_TYPE_MAPPING. The chosen implementations are stored as Method
# objects in @forward (always) and, for generative models, in @run_beam,
# @get_start_beams and @update_beam.
#
# @param config the model configuration (type not visible here — TODO confirm)
# @param session the session used to run the model (presumably an ONNX Runtime
#   inference session — TODO confirm against callers)
# @raise [Todo] when the class maps to an encoder-decoder model, which is not
#   implemented yet
def initialize(config, session)
  super()

  @config = config
  @session = session

  @output_names = nil

  # Explicit defaults so every instance carries the same instance variables.
  # Models that cannot generate (encoder-only or unmatched types) keep these
  # values; @can_generate is a real boolean rather than an implicit nil.
  @can_generate = false
  @run_beam = nil
  @get_start_beams = nil
  @update_beam = nil

  model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
  model_type = MODEL_TYPE_MAPPING[model_name]

  case model_type
  when MODEL_TYPES[:DecoderOnly]
    @can_generate = true

    @run_beam = method(:decoder_run_beam)
    @get_start_beams = method(:decoder_start_beams)
    @update_beam = method(:decoder_update_beam)
    @forward = method(:decoder_forward)
  when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
    # Vision-to-sequence models share the seq2seq generation path.
    @can_generate = true

    @run_beam = method(:seq2seq_run_beam)
    @get_start_beams = method(:seq2seq_start_beams)
    @update_beam = method(:seq2seq_update_beam)
    @forward = method(:seq2seq_forward)
  when MODEL_TYPES[:EncoderDecoder]
    raise Todo
  else
    # Any type not matched above gets only an encoder forward pass.
    @forward = method(:encoder_forward)
  end
end