Spaces:
Runtime error
Runtime error
| # mypy: ignore-errors | |
| import json | |
| from typing import Any | |
| import boto3 | |
| from llama_index.core.base.embeddings.base import BaseEmbedding | |
| from pydantic import Field, PrivateAttr | |
class SagemakerEmbedding(BaseEmbedding):
    """Sagemaker Embedding Endpoint.

    To use, you must supply the endpoint name from your deployed
    Sagemaker embedding model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """

    # Name of the deployed endpoint; passed as ``EndpointName`` on every call.
    endpoint_name: str = Field(description="")

    # The client is built while the class body is evaluated and is given no
    # explicit region or profile argument.
    # NOTE(review): presumably boto3's default region/credential resolution
    # applies here, and a failure would surface when this class is defined
    # rather than when an instance is created -- confirm before relying on it.
    _boto_client: Any = boto3.client(
        "sagemaker-runtime",
    )  # TODO make it an optional field

    # Guards the "falling back to sync" notice so it is printed only once.
    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this embedding class.

        Declared as a classmethod so it can be called on the class itself
        as well as on instances; the method takes ``cls`` and reads no
        instance state.
        """
        return "SagemakerEmbedding"

    def _async_not_implemented_warn_once(self) -> None:
        """Print a one-time notice that async calls run synchronously."""
        if self._async_not_implemented_warned:
            return
        print("Async embedding not available, falling back to sync method.")
        self._async_not_implemented_warned = True

    def _embed(self, sentences: list[str]) -> list[list[float]]:
        """Invoke the endpoint for ``sentences`` and return its vectors.

        The request body is the JSON document ``{"inputs": sentences}`` sent
        with content type ``application/json``. The response body is read,
        decoded as UTF-8, parsed as JSON, and the value stored under its
        ``"vectors"`` key is returned as-is.

        Args:
            sentences: Texts to send to the endpoint in a single request.

        Returns:
            Whatever the endpoint placed under ``"vectors"``.
            NOTE(review): assumed to be one vector per input, in input
            order -- this depends on the deployed model; confirm.
        """
        payload = json.dumps({"inputs": sentences})
        resp = self._boto_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=payload,
            ContentType="application/json",
        )
        # resp["Body"] is a stream-like object; read it fully before decoding.
        raw_body = resp["Body"].read()
        parsed = json.loads(raw_body.decode("utf-8"))
        return parsed["vectors"]

    def _get_query_embedding(self, query: str) -> list[float]:
        """Get query embedding."""
        # Wrap the single query in a list and unwrap the single result.
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> list[float]:
        """Get query embedding by delegating to the synchronous method."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> list[float]:
        """Get text embedding by delegating to the synchronous method."""
        # Warn the user that sync is being used
        self._async_not_implemented_warn_once()
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> list[float]:
        """Get text embedding."""
        # Wrap the single text in a list and unwrap the single result.
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Get text embeddings."""
        # All texts go to the endpoint in one request.
        return self._embed(texts)