Bases: BaseRerankDriver
Amazon Bedrock Rerank Driver for executing cross-encoder text scoring calculations serverlessly.
Attributes:
| Name |
Type |
Description |
model |
str
|
The specific model identifier ARN (e.g., 'cohere.rerank-v3-5:0').
|
top_n |
int | None
|
Optional maximum threshold number of sorted results to return.
|
session |
Session
|
An authenticated boto3.Session connection profile container.
|
Source code in griptape/drivers/rerank/amazon_bedrock_rerank_driver.py
| @define(kw_only=True)
class AmazonBedrockRerankDriver(BaseRerankDriver):
"""Amazon Bedrock Rerank Driver for executing cross-encoder text scoring calculations serverlessly.
Attributes:
model: The specific model identifier ARN (e.g., 'cohere.rerank-v3-5:0').
top_n: Optional maximum threshold number of sorted results to return.
session: An authenticated boto3.Session connection profile container.
"""
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()))
model: str = field(default="cohere.rerank-v3-5:0", metadata={"serializable": True})
_client: AgentsforBedrockRuntimeClient | None = field(
default=None, kw_only=True, alias="client", metadata={"serializable": False}
)
top_n: int | None = field(default=None, metadata={"serializable": True})
@lazy_property()
def client(self) -> AgentsforBedrockRuntimeClient:
return self.session.client("bedrock-agent-runtime")
@property
def model_arn(self) -> str:
return f"arn:{self.client.meta.partition}:bedrock:{self.client.meta.region_name}::foundation-model/{self.model}"
@staticmethod
def _post_process(response: list[RerankResultTypeDef], artifacts: list[TextArtifact]) -> list[TextArtifact]:
return [artifacts[res["index"]] for res in response]
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
truthy_artifacts = [artifact for artifact in artifacts if artifact]
if not truthy_artifacts:
return []
sources: list[RerankSourceTypeDef] = [
{
"type": "INLINE",
"inlineDocumentSource": {"type": "TEXT", "textDocument": {"text": artifact.to_text()}},
}
for artifact in truthy_artifacts
]
bedrock_reranking_configuration: BedrockRerankingConfigurationTypeDef = {
"modelConfiguration": {"modelArn": self.model_arn}
}
if self.top_n is not None:
bedrock_reranking_configuration["numberOfResults"] = self.top_n
reranking_config: RerankingConfigurationTypeDef = {
"type": "BEDROCK_RERANKING_MODEL",
"bedrockRerankingConfiguration": bedrock_reranking_configuration,
}
rerank_params: RerankRequestTypeDef = {
"queries": [{"type": "TEXT", "textQuery": {"text": query}}],
"sources": sources,
"rerankingConfiguration": reranking_config,
}
response = self.client.rerank(**rerank_params)
all_results = response["results"]
while next_token := response.get("nextToken"):
rerank_params["nextToken"] = next_token
response = self.client.rerank(**rerank_params)
all_results.extend(response["results"])
return self._post_process(all_results, truthy_artifacts)
|
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute
instance-attribute
model = field(default='cohere.rerank-v3-5:0', metadata={'serializable': True})
class-attribute
instance-attribute
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()))
class-attribute
instance-attribute
top_n = field(default=None, metadata={'serializable': True})
class-attribute
instance-attribute
_post_process(response, artifacts)
staticmethod
Source code in griptape/drivers/rerank/amazon_bedrock_rerank_driver.py
| @staticmethod
def _post_process(response: list[RerankResultTypeDef], artifacts: list[TextArtifact]) -> list[TextArtifact]:
return [artifacts[res["index"]] for res in response]
|
client()
Source code in griptape/drivers/rerank/amazon_bedrock_rerank_driver.py
| @lazy_property()
def client(self) -> AgentsforBedrockRuntimeClient:
return self.session.client("bedrock-agent-runtime")
|
run(query, artifacts)
Source code in griptape/drivers/rerank/amazon_bedrock_rerank_driver.py
| def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
truthy_artifacts = [artifact for artifact in artifacts if artifact]
if not truthy_artifacts:
return []
sources: list[RerankSourceTypeDef] = [
{
"type": "INLINE",
"inlineDocumentSource": {"type": "TEXT", "textDocument": {"text": artifact.to_text()}},
}
for artifact in truthy_artifacts
]
bedrock_reranking_configuration: BedrockRerankingConfigurationTypeDef = {
"modelConfiguration": {"modelArn": self.model_arn}
}
if self.top_n is not None:
bedrock_reranking_configuration["numberOfResults"] = self.top_n
reranking_config: RerankingConfigurationTypeDef = {
"type": "BEDROCK_RERANKING_MODEL",
"bedrockRerankingConfiguration": bedrock_reranking_configuration,
}
rerank_params: RerankRequestTypeDef = {
"queries": [{"type": "TEXT", "textQuery": {"text": query}}],
"sources": sources,
"rerankingConfiguration": reranking_config,
}
response = self.client.rerank(**rerank_params)
all_results = response["results"]
while next_token := response.get("nextToken"):
rerank_params["nextToken"] = next_token
response = self.client.rerank(**rerank_params)
all_results.extend(response["results"])
return self._post_process(all_results, truthy_artifacts)
|