Class: Langchain::LLM::GoogleVertexAi

Inherits:
Base
  • Object
show all
Defined in:
lib/langchain/llm/google_vertex_ai.rb

Overview

Wrapper around the Google Vertex AI APIs: https://cloud.google.com/vertex-ai

Gem requirements:

gem "google-apis-aiplatform_v1", "~> 0.7"

Usage:

# Build a Vertex AI LLM for the Google Cloud project named in the environment.
google_vertex_ai = Langchain::LLM::GoogleVertexAi.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"])

Constant Summary collapse

# Default request settings. The constructor merges its +default_options+
# over these values, so any key here can be overridden per instance.
DEFAULTS = {
  # Sampling controls
  temperature: 0.1, # 0.1 is the default in the API, quite low ("grounded")
  max_output_tokens: 1000,
  top_p: 0.8,
  top_k: 40,
  # Embedding size
  dimensions: 768,
  # Model names used in the publisher model path
  completion_model_name: "text-bison", # Optional: text-bison@001
  embeddings_model_name: "textembedding-gecko"
}.freeze

Instance Attribute Summary collapse

Instance Method Summary collapse

Methods inherited from Base

#chat, #default_dimensions

Methods included from DependencyHelper

#depends_on

Constructor Details

#initialize(project_id:, default_options: {}) ⇒ GoogleVertexAi

Returns a new instance of GoogleVertexAi.



31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# File 'lib/langchain/llm/google_vertex_ai.rb', line 31

# Create a Vertex AI LLM wrapper.
#
# @param project_id [String] Google Cloud project id used in every request path
# @param default_options [Hash] overrides merged over DEFAULTS; the :region key
#   (default "us-central1") also selects the regional API host
def initialize(project_id:, default_options: {})
  depends_on "google-apis-aiplatform_v1"

  @project_id = project_id
  @region = default_options.fetch(:region, "us-central1")

  # Vertex AI is reached through a per-region host (https://<region>-aiplatform.googleapis.com/);
  # authorization is obtained from Google::Auth.get_application_default.
  @client = Google::Apis::AiplatformV1::AiplatformService.new.tap do |service|
    service.root_url = "https://#{@region}-aiplatform.googleapis.com/"
    service.authorization = Google::Auth.get_application_default
  end

  @defaults = DEFAULTS.merge(default_options)
end

Instance Attribute Details

#clientObject (readonly)

The Vertex AI service client (Google::Apis::AiplatformV1::AiplatformService) used for every API call. The constructor points it at the regional endpoint and sets its authorization.



29
30
31
# File 'lib/langchain/llm/google_vertex_ai.rb', line 29

# The configured AiplatformService instance (read-only).
def client = @client

#project_idObject (readonly)

The Google Cloud project ID under which Vertex AI models are called.



29
30
31
# File 'lib/langchain/llm/google_vertex_ai.rb', line 29

# The Google Cloud project id given to the constructor (read-only).
def project_id = @project_id
  @project_id
end

#regionObject (readonly)

The Google Cloud region of deployment. It defaults to us-central1, which is a safe choice for GenAI models, and it determines the regional API endpoint.



29
30
31
# File 'lib/langchain/llm/google_vertex_ai.rb', line 29

# The Google Cloud region (e.g. "us-central1") the client talks to (read-only).
def region = @region

Instance Method Details

#complete(prompt:, **params) ⇒ Langchain::LLM::GoogleVertexAiResponse

Generate a completion for a given prompt

Parameters:

  • prompt (String)

    The prompt to generate a completion for

  • params

    extra request options (e.g. :temperature, :max_output_tokens) that override the instance defaults

Returns:



73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# File 'lib/langchain/llm/google_vertex_ai.rb', line 73

# Generate a completion for a given prompt.
#
# Per-call options override the instance defaults (DEFAULTS merged with the
# constructor's default_options); a nil override falls back to the default.
#
# @param prompt [String] The prompt to generate a completion for
# @param params [Hash] Optional overrides: :temperature, :top_p, :top_k,
#   :max_output_tokens, :stop_sequences, :model (model name used in the request path)
# @return [Langchain::LLM::GoogleVertexAiResponse] the wrapped predict response
def complete(prompt:, **params)
  options = {
    temperature: @defaults[:temperature],
    top_k: @defaults[:top_k],
    top_p: @defaults[:top_p],
    max_output_tokens: @defaults[:max_output_tokens],
    model: @defaults[:completion_model_name]
  }.merge(params.compact)

  # The predict endpoint takes camelCase parameter names.
  parameters = {
    temperature: options[:temperature],
    maxOutputTokens: options[:max_output_tokens],
    topP: options[:top_p],
    topK: options[:top_k]
  }
  parameters[:stopSequences] = options[:stop_sequences] if options[:stop_sequences]

  request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new(
    instances: [{prompt: prompt}], # key used to be :content, changed to :prompt
    parameters: parameters
  )

  # The location must match the regional host configured on the client in #initialize.
  api_path = "projects/#{@project_id}/locations/#{@region}/publishers/google/models/#{options[:model]}"
  response = client.predict_project_location_publisher_model(api_path, request)

  Langchain::LLM::GoogleVertexAiResponse.new(response, model: options[:model])
end

#embed(text:) ⇒ Langchain::LLM::GoogleVertexAiResponse

Generate an embedding for a given text

Parameters:

  • text (String)

    The text to generate an embedding for

Returns:



53
54
55
56
57
58
59
60
61
62
63
64
# File 'lib/langchain/llm/google_vertex_ai.rb', line 53

# Generate an embedding for a given text.
#
# @param text [String] The text to generate an embedding for
# @return [Langchain::LLM::GoogleVertexAiResponse] wraps the predict response (converted to a Hash)
def embed(text:)
  model = @defaults[:embeddings_model_name]
  request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new(instances: [{content: text}])

  # The location must match the regional host configured on the client in #initialize.
  api_path = "projects/#{@project_id}/locations/#{@region}/publishers/google/models/#{model}"
  response = client.predict_project_location_publisher_model(api_path, request)

  # Unlike #complete, the response is handed over as a Hash.
  # NOTE(review): presumably the response class digs into it; confirm.
  Langchain::LLM::GoogleVertexAiResponse.new(response.to_h, model: model)
end

#summarize(text:) ⇒ Langchain::LLM::GoogleVertexAiResponse

Generate a summarization for a given text

TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values.

Parameters:

  • text (String)

    The text to generate a summarization for

Returns:

  • (Langchain::LLM::GoogleVertexAiResponse)

    The completion response (as returned by #complete) containing the summarization



123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# File 'lib/langchain/llm/google_vertex_ai.rb', line 123

# Summarize +text+ with the bundled summarization prompt template.
#
# The sampling values follow Google's summarization sample:
# https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization
#
# @param text [String] The text to generate a summarization for
# @return whatever #complete returns for the rendered prompt
def summarize(text:)
  template_path = Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
  template = Langchain::Prompt.load_from_path(file_path: template_path)

  complete(
    prompt: template.format(text: text),
    temperature: 0.2,
    top_p: 0.95,
    top_k: 40,
    # Most models have a context length of 2048 tokens (the newest support 4096).
    max_output_tokens: 256
  )
end