Class: Langchain::LLM::GooglePalm

Inherits:
Base
  • Object
show all
Defined in:
lib/langchain/llm/google_palm.rb

Overview

Wrapper around the Google PaLM (Pathways Language Model) APIs: ai.google/build/machine-learning/

Gem requirements:

gem "google_palm_api", "~> 0.1.3"

Usage:

google_palm = Langchain::LLM::GooglePalm.new(api_key: ENV["GOOGLE_PALM_API_KEY"])

Constant Summary collapse

# Default generation settings; any of these can be overridden per-instance
# via the +default_options:+ constructor argument.
DEFAULTS =
{
  temperature: 0.0,
  dimensions: 768, # This is what the `embedding-gecko-001` model generates
  completion_model_name: "text-bison-001",
  chat_completion_model_name: "chat-bison-001",
  embeddings_model_name: "embedding-gecko-001"
}.freeze
# Validator used to enforce the model's token limit before sending a chat request.
LENGTH_VALIDATOR =
Langchain::Utils::TokenLength::GooglePalmValidator
# Maps Langchain role names to the role names the PaLM API expects.
# Frozen to prevent accidental mutation of a shared constant (consistent with DEFAULTS).
ROLE_MAPPING =
{
  "assistant" => "ai"
}.freeze

Instance Attribute Summary collapse

Attributes inherited from Base

#client

Instance Method Summary collapse

Methods inherited from Base

#default_dimensions

Methods included from DependencyHelper

#depends_on

Constructor Details

#initialize(api_key:, default_options: {}) ⇒ GooglePalm

Returns a new instance of GooglePalm.



28
29
30
31
32
33
# File 'lib/langchain/llm/google_palm.rb', line 28

# Build a new Google PaLM LLM wrapper.
#
# @param api_key [String] API key for the Google PaLM API
# @param default_options [Hash] per-instance overrides merged on top of DEFAULTS
def initialize(api_key:, default_options: {})
  # Fail fast if the `google_palm_api` gem is not installed.
  depends_on "google_palm_api"

  @defaults = DEFAULTS.merge(default_options)
  @client = ::GooglePalmApi::Client.new(api_key: api_key)
end

Instance Attribute Details

#defaultsObject (readonly)

Returns the value of attribute defaults.



26
27
28
# File 'lib/langchain/llm/google_palm.rb', line 26

# @return [Hash] the effective configuration: DEFAULTS merged with the
#   +default_options+ supplied to the constructor
def defaults
  @defaults
end

Instance Method Details

#chat(prompt: "", messages: [], context: "", examples: [], **options) ⇒ Langchain::LLM::GooglePalmResponse

Generate a chat completion for a given prompt

Parameters:

  • prompt (String) (defaults to: "")

    The prompt to generate a chat completion for

  • messages (Array<Hash>) (defaults to: [])

    The messages that have been sent in the conversation

  • context (String) (defaults to: "")

    An initial context to provide as a system message, i.e. "You are RubyGPT, a helpful chat bot for helping people learn Ruby"

  • examples (Array<Hash>) (defaults to: [])

    Examples of messages to provide to the model. Useful for Few-Shot Prompting

  • options (Hash)

    extra parameters passed to GooglePalmApi::Client#generate_chat_message

Returns:

Raises:

  • (ArgumentError)


88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# File 'lib/langchain/llm/google_palm.rb', line 88

# Generate a chat completion for a given prompt and/or message history.
#
# @param prompt [String] the prompt to generate a chat completion for
# @param messages [Array<Hash>] messages already exchanged in the conversation
# @param context [String] initial context to provide as a system message
# @param examples [Array<Hash>] example messages, useful for few-shot prompting
# @param options [Hash] extra parameters forwarded to GooglePalmApi::Client#generate_chat_message
# @return [Langchain::LLM::GooglePalmResponse] the wrapped API response
# @raise [ArgumentError] when both +prompt+ and +messages+ are empty
def chat(prompt: "", messages: [], context: "", examples: [], **options)
  raise ArgumentError, ":prompt or :messages argument is expected" if prompt.empty? && messages.empty?

  params = {
    temperature: @defaults[:temperature],
    model: @defaults[:chat_completion_model_name],
    context: context,
    messages: compose_chat_messages(prompt: prompt, messages: messages),
    examples: compose_examples(examples)
  }

  # chat-bison-001 is the only model that currently supports countMessageTokens functions
  LENGTH_VALIDATOR.validate_max_tokens!(params[:messages], "chat-bison-001", llm: self)

  # Translate generic option names into the PaLM client's parameter names,
  # removing them from +options+ so the merge below does not re-add them.
  params[:stop] = options.delete(:stop_sequences) if options[:stop_sequences]
  params[:max_output_tokens] = options.delete(:max_tokens) if options[:max_tokens]
  params.merge!(options)

  response = client.generate_chat_message(**params)
  raise "GooglePalm API returned an error: #{response}" if response.dig("error")

  Langchain::LLM::GooglePalmResponse.new response,
    model: params[:model]
  # TODO: Pass in prompt_tokens: prompt_tokens
end

#complete(prompt:, **params) ⇒ Langchain::LLM::GooglePalmResponse

Generate a completion for a given prompt

Parameters:

  • prompt (String)

    The prompt to generate a completion for

  • params

    extra parameters passed to GooglePalmApi::Client#generate_text

Returns:



55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# File 'lib/langchain/llm/google_palm.rb', line 55

# Generate a completion for a given prompt.
#
# @param prompt [String] the prompt to generate a completion for
# @param params [Hash] extra parameters forwarded to GooglePalmApi::Client#generate_text
# @return [Langchain::LLM::GooglePalmResponse] the wrapped API response
def complete(prompt:, **params)
  request = {
    prompt: prompt,
    temperature: @defaults[:temperature],
    model: @defaults[:completion_model_name]
  }

  # Translate generic option names into the PaLM client's parameter names,
  # removing them from +params+ so the merge below does not re-add them.
  request[:stop_sequences] = params.delete(:stop_sequences) if params[:stop_sequences]
  request[:max_output_tokens] = params.delete(:max_tokens) if params[:max_tokens]
  request.merge!(params)

  response = client.generate_text(**request)

  Langchain::LLM::GooglePalmResponse.new response,
    model: request[:model]
end

#embed(text:) ⇒ Langchain::LLM::GooglePalmResponse

Generate an embedding for a given text

Parameters:

  • text (String)

    The text to generate an embedding for

Returns:



41
42
43
44
45
46
# File 'lib/langchain/llm/google_palm.rb', line 41

# Generate an embedding for a given text.
#
# @param text [String] the text to generate an embedding for
# @return [Langchain::LLM::GooglePalmResponse] the wrapped embedding response
def embed(text:)
  raw_response = client.embed(text: text)

  Langchain::LLM::GooglePalmResponse.new(
    raw_response,
    model: @defaults[:embeddings_model_name]
  )
end

#summarize(text:) ⇒ Langchain::LLM::GooglePalmResponse

Generate a summarization for a given text

Parameters:

  • text (String)

    The text to generate a summarization for

Returns:

  • (Langchain::LLM::GooglePalmResponse)

    The summarization response (the method returns the result of #complete)



126
127
128
129
130
131
132
133
134
135
136
137
138
# File 'lib/langchain/llm/google_palm.rb', line 126

# Generate a summarization for a given text.
#
# @param text [String] the text to summarize
# @return [Langchain::LLM::GooglePalmResponse] the completion response
#   produced by #complete for the summarization prompt
def summarize(text:)
  # Load the shared summarization prompt template and fill in the text.
  template = Langchain::Prompt.load_from_path(
    file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
  )

  complete(
    prompt: template.format(text: text),
    temperature: @defaults[:temperature],
    # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
    max_tokens: 256
  )
end