Class: Transformers::SentenceTransformer

Inherits:
Object
Defined in:
lib/transformers/sentence_transformer.rb

Instance Method Summary

Constructor Details

#initialize(model_id) ⇒ SentenceTransformer

Returns a new instance of SentenceTransformer.



3
4
5
6
7
# File 'lib/transformers/sentence_transformer.rb', line 3

# Builds a sentence transformer for the given pretrained model.
#
# Remembers the identifier (it is consulted later when choosing a pooling
# strategy) and loads the tokenizer first, then the model, both through
# their `from_pretrained` factories.
#
# @param model_id [String] identifier handed to `from_pretrained`
#   (presumably a Hugging Face hub id — confirm against callers)
def initialize(model_id)
  @model_id = model_id
  @tokenizer, @model =
    [Transformers::AutoTokenizer, Transformers::AutoModel].map do |loader|
      loader.from_pretrained(model_id)
    end
end

Instance Method Details

#encode(sentences) ⇒ Object



9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# File 'lib/transformers/sentence_transformer.rb', line 9

# Computes embeddings for one sentence or a batch of sentences.
#
# The input is tokenized with padding and truncation, run through the model
# inside `Torch.no_grad`, and the model's first output is reduced to one
# vector per sentence. For the two hard-coded sentence-transformers models
# the reduction is `mean_pooling` over the attention mask followed by L2
# normalization along dim 1; for every other model it is index 0 along the
# second axis of the output (presumably the first-token / [CLS] hidden
# state — confirm against the model's pooling config).
#
# @param sentences [String, Array<String>] a single sentence or a batch
# @return [Array] a single embedding when a String was given, otherwise one
#   embedding per input sentence
def encode(sentences)
  single_input = sentences.is_a?(String)
  batch = single_input ? [sentences] : sentences

  encoded = @tokenizer.(batch, padding: true, truncation: true, return_tensors: "pt")
  hidden = Torch.no_grad { @model.(**encoded) }[0]

  # TODO: check modules.json instead of hard-coding which models are mean pooled
  mean_pooled_models = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
  ]

  embeddings =
    if mean_pooled_models.include?(@model_id)
      pooled = mean_pooling(hidden, encoded[:attention_mask])
      Torch::NN::Functional.normalize(pooled, p: 2, dim: 1).to_a
    else
      hidden[0.., 0].to_a
    end

  # Unwrap the batch of one when the caller passed a bare String.
  single_input ? embeddings.first : embeddings
end