Class: Langchain::LLM::GoogleVertexAi
- Defined in:
- lib/langchain/llm/google_vertex_ai.rb
Overview
Wrapper around the Google Vertex AI APIs: cloud.google.com/vertex-ai?hl=en
Gem requirements:
gem "google-apis-aiplatform_v1", "~> 0.7"
Usage:
vertex_ai = Langchain::LLM::GoogleVertexAi.new(project_id: ENV["GOOGLE_VERTEX_AI_PROJECT_ID"])
Constant Summary collapse
- DEFAULTS =
{
  temperature: 0.1, # 0.1 is the default in the API, quite low ("grounded")
  max_output_tokens: 1000,
  top_p: 0.8,
  top_k: 40,
  dimensions: 768,
  completion_model_name: "text-bison", # Optional: text-bison@001
  embeddings_model_name: "textembedding-gecko"
}.freeze
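These defaults are merged with the `default_options` hash passed to the constructor, with caller-supplied values taking precedence via `Hash#merge`. A minimal sketch of that behavior (the override value is illustrative):

```ruby
# DEFAULTS as declared by the class (subset shown for brevity).
DEFAULTS = {
  temperature: 0.1,
  max_output_tokens: 1000,
  top_p: 0.8,
  top_k: 40
}.freeze

# Hypothetical caller override passed as default_options:
default_options = {temperature: 0.5}

# Hash#merge: keys in default_options win over DEFAULTS.
defaults = DEFAULTS.merge(default_options)

defaults[:temperature]       # => 0.5 (overridden by the caller)
defaults[:max_output_tokens] # => 1000 (kept from DEFAULTS)
```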
Instance Attribute Summary collapse
-
#client ⇒ Object
readonly
Google Cloud has a project id and a specific region of deployment.
-
#project_id ⇒ Object
readonly
Google Cloud has a project id and a specific region of deployment.
-
#region ⇒ Object
readonly
Google Cloud has a project id and a specific region of deployment.
Instance Method Summary collapse
-
#complete(prompt:, **params) ⇒ Langchain::LLM::GoogleVertexAiResponse
Generate a completion for a given prompt.
-
#embed(text:) ⇒ Langchain::LLM::GoogleVertexAiResponse
Generate an embedding for a given text.
-
#initialize(project_id:, default_options: {}) ⇒ GoogleVertexAi
constructor
A new instance of GoogleVertexAi.
-
#summarize(text:) ⇒ String
Generate a summarization for a given text.
Methods inherited from Base
Methods included from DependencyHelper
Constructor Details
#initialize(project_id:, default_options: {}) ⇒ GoogleVertexAi
Returns a new instance of GoogleVertexAi.
# File 'lib/langchain/llm/google_vertex_ai.rb', line 31

def initialize(project_id:, default_options: {})
  depends_on "google-apis-aiplatform_v1"

  @project_id = project_id
  @region = default_options.fetch(:region, "us-central1")

  @client = Google::Apis::AiplatformV1::AiplatformService.new

  # TODO: Adapt for other regions; pass it in via the constructor.
  # For the moment only us-central1 is available, so no big deal.
  @client.root_url = "https://#{@region}-aiplatform.googleapis.com/"
  @client.authorization = Google::Auth.get_application_default

  @defaults = DEFAULTS.merge(default_options)
end
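The constructor builds the regional API endpoint by interpolating the region into the service's root URL. A small sketch of that construction, assuming the `us-central1` default the constructor falls back to:

```ruby
# The region comes from default_options, falling back to "us-central1".
region = "us-central1"

# The client is pointed at the region-specific Vertex AI endpoint.
root_url = "https://#{region}-aiplatform.googleapis.com/"
# => "https://us-central1-aiplatform.googleapis.com/"
```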
Instance Attribute Details
#client ⇒ Object (readonly)
Google Cloud has a project id and a specific region of deployment. For GenAI-related things, a safe choice is us-central1.
# File 'lib/langchain/llm/google_vertex_ai.rb', line 29

def client
  @client
end
#project_id ⇒ Object (readonly)
Google Cloud has a project id and a specific region of deployment. For GenAI-related things, a safe choice is us-central1.
# File 'lib/langchain/llm/google_vertex_ai.rb', line 29

def project_id
  @project_id
end
#region ⇒ Object (readonly)
Google Cloud has a project id and a specific region of deployment. For GenAI-related things, a safe choice is us-central1.
# File 'lib/langchain/llm/google_vertex_ai.rb', line 29

def region
  @region
end
Instance Method Details
#complete(prompt:, **params) ⇒ Langchain::LLM::GoogleVertexAiResponse
Generate a completion for a given prompt
# File 'lib/langchain/llm/google_vertex_ai.rb', line 73

def complete(prompt:, **params)
  default_params = {
    prompt: prompt,
    temperature: @defaults[:temperature],
    top_k: @defaults[:top_k],
    top_p: @defaults[:top_p],
    max_output_tokens: @defaults[:max_output_tokens],
    model: @defaults[:completion_model_name]
  }

  if params[:stop_sequences]
    default_params[:stop_sequences] = params.delete(:stop_sequences)
  end

  if params[:max_output_tokens]
    default_params[:max_output_tokens] = params.delete(:max_output_tokens)
  end

  # to be tested
  temperature = params.delete(:temperature) || @defaults[:temperature]
  max_output_tokens = default_params.fetch(:max_output_tokens, @defaults[:max_output_tokens])

  default_params.merge!(params)

  # response = client.generate_text(**default_params)
  request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new \
    instances: [{
      prompt: prompt # key used to be :content, changed to :prompt
    }],
    parameters: {
      temperature: temperature,
      maxOutputTokens: max_output_tokens,
      topP: 0.8,
      topK: 40
    }

  response = client.predict_project_location_publisher_model \
    "projects/#{project_id}/locations/us-central1/publishers/google/models/#{@defaults[:completion_model_name]}",
    request

  Langchain::LLM::GoogleVertexAiResponse.new(response, model: default_params[:model])
end
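The keyword-argument handling in `#complete` moves `stop_sequences` and `max_output_tokens` out of `**params` into `default_params` before merging in whatever remains. A minimal sketch of that extraction logic in isolation (the defaults and override values here are illustrative, not API output):

```ruby
# Illustrative defaults and caller-supplied params:
defaults = {max_output_tokens: 1000, temperature: 0.1}
params = {stop_sequences: ["\n\n"], max_output_tokens: 256, temperature: 0.2}

default_params = {max_output_tokens: defaults[:max_output_tokens]}

# Recognized keys are deleted from params and moved into default_params,
# so they are not merged in twice below.
if params[:stop_sequences]
  default_params[:stop_sequences] = params.delete(:stop_sequences)
end
if params[:max_output_tokens]
  default_params[:max_output_tokens] = params.delete(:max_output_tokens)
end
temperature = params.delete(:temperature) || defaults[:temperature]

# Any leftover params are merged last, overriding defaults.
default_params.merge!(params)

default_params[:max_output_tokens] # => 256 (caller override won)
temperature                        # => 0.2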
#embed(text:) ⇒ Langchain::LLM::GoogleVertexAiResponse
Generate an embedding for a given text
# File 'lib/langchain/llm/google_vertex_ai.rb', line 53

def embed(text:)
  content = [{content: text}]
  request = Google::Apis::AiplatformV1::GoogleCloudAiplatformV1PredictRequest.new(instances: content)

  api_path = "projects/#{@project_id}/locations/us-central1/publishers/google/models/#{@defaults[:embeddings_model_name]}"

  # puts("api_path: #{api_path}")

  response = client.predict_project_location_publisher_model(api_path, request)

  Langchain::LLM::GoogleVertexAiResponse.new(response.to_h, model: @defaults[:embeddings_model_name])
end
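The `api_path` that `#embed` sends the predict request to is a publisher-model resource name built from the project id and the default embeddings model. A sketch of that path construction (the project id here is hypothetical):

```ruby
project_id = "my-gcp-project" # hypothetical project id
model = "textembedding-gecko" # the embeddings_model_name default

# Resource name of the publisher model under the project and location.
api_path = "projects/#{project_id}/locations/us-central1/publishers/google/models/#{model}"
# => "projects/my-gcp-project/locations/us-central1/publishers/google/models/textembedding-gecko"
```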
#summarize(text:) ⇒ String
Generate a summarization for a given text
TODO(ricc): add params for Temp, topP, topK, MaxTokens and have it default to these 4 values.
# File 'lib/langchain/llm/google_vertex_ai.rb', line 123

def summarize(text:)
  prompt_template = Langchain::Prompt.load_from_path(
    file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
  )
  prompt = prompt_template.format(text: text)

  complete(
    prompt: prompt,
    # For best temperature, topP, topK, maxOutputTokens for summarization, see
    # https://cloud.google.com/vertex-ai/docs/samples/aiplatform-sdk-summarization
    temperature: 0.2,
    top_p: 0.95,
    top_k: 40,
    # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
    max_output_tokens: 256
  )
end