Class: Langchain::Utils::TokenLength::GooglePalmValidator

Inherits:
BaseValidator
  • Object
    • BaseValidator
      • Langchain::Utils::TokenLength::GooglePalmValidator
Defined in:
lib/langchain/utils/token_length/google_palm_validator.rb

Overview

This class validates the length of text passed to the Google PaLM API. It checks the token length before the API call is made.

Constant Summary

TOKEN_LIMITS =
{
  # Source:
  # This data can be retrieved by calling the `list_models()` method: https://github.com/andreibondarev/google_palm_api#usage

  # chat-bison-001 is the only model that currently supports the countMessageTokens method
  "chat-bison-001" => {
    "input_token_limit" => 4000, # 4096 is the actual limit, but countMessageTokens does not return anything higher than 4000
    "output_token_limit" => 1024
  }
  # "text-bison-001" => {
  #   "input_token_limit" => 8196,
  #   "output_token_limit" => 1024
  # },
  # "embedding-gecko-001" => {
  #   "input_token_limit" => 1024
  # }
}.freeze
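The nested structure of TOKEN_LIMITS can be exercised with a short standalone sketch. It copies the single active entry so it runs without the library loaded, and uses Ruby's Hash#dig, which returns nil instead of raising when a key is missing (as it is for the commented-out text-bison-001 and embedding-gecko-001 entries).

```ruby
# Standalone copy of the only active entry in TOKEN_LIMITS (illustration only).
TOKEN_LIMITS = {
  "chat-bison-001" => {
    "input_token_limit" => 4000,
    "output_token_limit" => 1024
  }
}.freeze

# Hash#dig walks the nested keys and returns nil if any key along the way is missing.
input_limit  = TOKEN_LIMITS.dig("chat-bison-001", "input_token_limit")
output_limit = TOKEN_LIMITS.dig("chat-bison-001", "output_token_limit")
missing      = TOKEN_LIMITS.dig("text-bison-001", "input_token_limit")

puts input_limit  # prints 4000
puts output_limit # prints 1024
p missing         # prints nil, because text-bison-001 is commented out
```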

Class Method Summary

Methods inherited from BaseValidator

limit_exceeded_exception, validate_max_tokens!

Class Method Details

.token_length(text, model_name = "chat-bison-001", options = {}) ⇒ Integer

Calculates the token length of the given text for the given model name.

Parameters:

  • text (String)

    The text to calculate the token length for

  • model_name (String) (defaults to: "chat-bison-001")

    The model name to validate against

  • options (Hash) (defaults to: {})

    Additional options for the request; :llm must be set (see Options Hash below)

Options Hash (options):

  • :llm (Langchain::LLM::GooglePalm)

    The Langchain::LLM::GooglePalm instance whose client sends the countMessageTokens request

Returns:

  • (Integer)

    The token length of the text

Raises:

  • (Langchain::LLM::ApiError)

    If the API response contains an error

# File 'lib/langchain/utils/token_length/google_palm_validator.rb', line 38

def self.token_length(text, model_name = "chat-bison-001", options = {})
  # Ask the PaLM API, through the LLM's client, to count the tokens in the prompt
  response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)

  # Raise if the API returned an error payload
  raise Langchain::LLM::ApiError.new(response["error"]["message"]) unless response["error"].nil?

  response.dig("tokenCount")
end
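A minimal, self-contained sketch of the request/response handling in .token_length. StubClient, StubLLM and the local ApiError class are hypothetical stand-ins (not part of the library), and the response hashes only assume the shape the method reads: a "tokenCount" key on success, or an "error" hash with a "message" key on failure.

```ruby
# Hypothetical stand-in for Langchain::LLM::ApiError.
class ApiError < StandardError; end

# Hypothetical client that returns a canned response instead of calling the API.
class StubClient
  def initialize(response)
    @response = response
  end

  # Accepts the same keyword arguments the validator passes: model: and prompt:.
  def count_message_tokens(model:, prompt:)
    @response
  end
end

# Hypothetical LLM wrapper exposing a #client, like options[:llm] in the source.
StubLLM = Struct.new(:client)

# Same logic as GooglePalmValidator.token_length, raising the local ApiError.
def token_length(text, model_name = "chat-bison-001", options = {})
  response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
  raise ApiError, response["error"]["message"] unless response["error"].nil?

  response.dig("tokenCount")
end

ok_llm = StubLLM.new(StubClient.new({ "tokenCount" => 7 }))
puts token_length("Hello, PaLM!", "chat-bison-001", llm: ok_llm) # prints 7

bad_llm = StubLLM.new(StubClient.new({ "error" => { "message" => "model not found" } }))
begin
  token_length("Hello", "unknown-model", llm: bad_llm)
rescue ApiError => e
  puts "ApiError: #{e.message}" # prints "ApiError: model not found"
end
```

Because the method reads options[:llm] unconditionally, calling it without an :llm option fails with a NoMethodError on nil rather than a descriptive error.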

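.token_length_from_messages(messages, model_name, options = {}) ⇒ Object

Calculates the total token length of a list of messages by serializing each message to JSON and summing the per-message counts returned by .token_length (one API request per message).
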
# File 'lib/langchain/utils/token_length/google_palm_validator.rb', line 46

def self.token_length_from_messages(messages, model_name, options = {})
  messages.sum { |message| token_length(message.to_json, model_name, options) }
end
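The summation in .token_length_from_messages can be sketched without network access by swapping in a hypothetical fake_token_length that counts whitespace-separated words; the real method delegates each message to .token_length and therefore makes one countMessageTokens request per message. The message hashes (author/content keys) are illustrative assumptions, not a format the validator requires.

```ruby
require "json"

# Hypothetical token counter standing in for the API call: it counts
# whitespace-separated words. The real validator asks the PaLM API instead.
def fake_token_length(text, _model_name, _options = {})
  text.split.size
end

# Same summation as token_length_from_messages: serialize each message to
# JSON, count it on its own, then add the counts together.
def token_length_from_messages(messages, model_name, options = {})
  messages.sum { |message| fake_token_length(message.to_json, model_name, options) }
end

messages = [
  { author: "user", content: "What is the capital of France?" },
  { author: "bot", content: "Paris." }
]

total = token_length_from_messages(messages, "chat-bison-001")
puts total # prints 7
```

Because each message is serialized with to_json before counting, the JSON keys and punctuation are part of the text that gets counted, not just the message content.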

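.token_limit(model_name) ⇒ Object

Returns the input token limit for the given model from TOKEN_LIMITS, or nil if the model is not listed.
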
# File 'lib/langchain/utils/token_length/google_palm_validator.rb', line 50

def self.token_limit(model_name)
  TOKEN_LIMITS.dig(model_name, "input_token_limit")
end
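The overview says the validator checks token length before the API call. A hypothetical pre-flight check built on .token_limit might look like the sketch below; remaining_input_tokens! is an illustrative name and its behavior is an assumption, not the inherited validate_max_tokens! implementation, whose source is not shown here.

```ruby
# Standalone copy of the active TOKEN_LIMITS entry.
TOKEN_LIMITS = {
  "chat-bison-001" => { "input_token_limit" => 4000, "output_token_limit" => 1024 }
}.freeze

# Same lookup as GooglePalmValidator.token_limit.
def token_limit(model_name)
  TOKEN_LIMITS.dig(model_name, "input_token_limit")
end

# Hypothetical check: returns the input tokens left, or raises if the prompt
# is too long or the model has no known limit.
def remaining_input_tokens!(token_count, model_name)
  limit = token_limit(model_name)
  raise ArgumentError, "unknown model: #{model_name}" if limit.nil?

  remaining = limit - token_count
  if remaining.negative?
    raise ArgumentError, "prompt of #{token_count} tokens exceeds the #{limit}-token input limit"
  end

  remaining
end

puts remaining_input_tokens!(3500, "chat-bison-001") # prints 500
begin
  remaining_input_tokens!(4200, "chat-bison-001")
rescue ArgumentError => e
  puts e.message # prints "prompt of 4200 tokens exceeds the 4000-token input limit"
end
p token_limit("text-bison-001") # prints nil
```

Treating a nil limit as an error is a design choice of this sketch: since token_limit returns nil for any model not in TOKEN_LIMITS, subtracting from it directly would otherwise raise a NoMethodError.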