Class: Transformers::SentenceTransformer
- Inherits:
-
Object
- Object
- Transformers::SentenceTransformer
- Defined in:
- lib/transformers/sentence_transformer.rb
Instance Method Summary
- #encode(sentences) ⇒ Object
-
#initialize(model_id) ⇒ SentenceTransformer
constructor
A new instance of SentenceTransformer.
Constructor Details
#initialize(model_id) ⇒ SentenceTransformer
Returns a new instance of SentenceTransformer.
3 4 5 6 7 |
# File 'lib/transformers/sentence_transformer.rb', line 3
#
# Builds a sentence transformer for the given model identifier by loading
# the matching tokenizer and model via their `from_pretrained` loaders.
#
# @param model_id [String] identifier handed to both `from_pretrained` calls
#   (presumably a Hugging Face Hub model id - confirm against callers)
def initialize(model_id)
  # Kept so #encode can choose a pooling strategy based on the model id.
  @model_id = model_id

  # The tokenizer is loaded before the model, both from the same identifier.
  @tokenizer = Transformers::AutoTokenizer.from_pretrained(@model_id)
  @model = Transformers::AutoModel.from_pretrained(@model_id)
end
Instance Method Details
#encode(sentences) ⇒ Object
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
# File 'lib/transformers/sentence_transformer.rb', line 9
#
# Computes embeddings for one sentence or a list of sentences.
#
# A single String is wrapped in a one-element array for processing and the
# lone result is unwrapped again before returning; any other argument is
# passed to the tokenizer unchanged.
#
# @param sentences [String, Array<String>] the text(s) to embed
# @return [Array] the embedding for a String input, otherwise the array of
#   embeddings (one entry per input sentence)
def encode(sentences)
  single_input = sentences.is_a?(String)
  batch = single_input ? [sentences] : sentences

  encoded = @tokenizer.(batch, padding: true, truncation: true, return_tensors: "pt")

  # Only the first element of the model output is used
  # (presumably the last hidden state - TODO confirm).
  hidden = Torch.no_grad { @model.(**encoded) }[0]

  # TODO check modules.json
  mean_pooled_models = [
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
  ]

  embeddings =
    if mean_pooled_models.include?(@model_id)
      # These models get attention-mask-aware mean pooling followed by
      # L2 normalization along dim 1.
      pooled = mean_pooling(hidden, encoded[:attention_mask])
      Torch::NN::Functional.normalize(pooled, p: 2, dim: 1).to_a
    else
      # Every other model: take index 0 along the second dimension for each
      # row (presumably the first/CLS token - TODO confirm).
      hidden[0.., 0].to_a
    end

  return embeddings[0] if single_input

  embeddings
end