diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/dynamodb.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/dynamodb.py
index 6df2faab52a8d..978bed42b2e4d 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/dynamodb.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/dynamodb.py
@@ -21,7 +21,7 @@
 
 from collections.abc import Iterable
 from functools import cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from botocore.exceptions import ClientError
 
@@ -120,3 +120,258 @@ def get_import_status(self, import_arn: str) -> tuple[str, str | None, str | Non
                     f"S3 import into Dynamodb job not found. Import ARN: {import_arn}"
                 ) from e
             raise
+
+    def get_item(
+        self,
+        table_name: str,
+        key: dict[str, Any],
+    ) -> dict[str, Any] | None:
+        """
+        Retrieve a single item from a DynamoDB table by primary key.
+
+        Uses the boto3 resource API so keys and attribute values are plain
+        Python types (str, int, Decimal, …) rather than the low-level typed
+        format (``{"S": "value"}``).
+
+        .. seealso::
+            - :external+boto3:py:meth:`DynamoDB.Table.get_item`
+
+        :param table_name: Name of the DynamoDB table.
+        :param key: Primary key of the item, e.g. ``{"pk": "value"}`` or
+            ``{"pk": "value", "sk": "range_value"}``.
+        :return: The item as a plain dict, or ``None`` if not found.
+        :raises botocore.exceptions.ClientError: If the request fails; the
+            error is logged and re-raised unchanged.
+        """
+        self.log.debug("Getting item with key %s from table %s", key, table_name)
+        try:
+            table = self.get_conn().Table(table_name)
+            response = table.get_item(Key=key)
+            item = response.get("Item")
+            if item is None:
+                self.log.info("Item with key %s not found in table %s", key, table_name)
+            return item
+        except ClientError as e:
+            self.log.error("Failed to get item from %s: %s", table_name, e)
+            raise
+
+    def put_item(
+        self,
+        table_name: str,
+        item: dict[str, Any],
+        condition_expression: Any | None = None,
+        expression_attribute_values: dict[str, Any] | None = None,
+        expression_attribute_names: dict[str, str] | None = None,
+    ) -> dict[str, Any]:
+        """
+        Insert or replace an item in a DynamoDB table.
+
+        .. seealso::
+            - :external+boto3:py:meth:`DynamoDB.Table.put_item`
+
+        :param table_name: Name of the DynamoDB table.
+        :param item: Item attributes as a plain dict,
+            e.g. ``{"pk": "value", "status": "pending"}``.
+        :param condition_expression: Optional condition — string or
+            ``boto3.dynamodb.conditions.ConditionBase``.
+        :param expression_attribute_values: Optional substitution values for
+            ``:placeholders`` used in a string ``condition_expression``.
+        :param expression_attribute_names: Optional name aliases for
+            ``#placeholders`` used in a string ``condition_expression``.
+        :return: The raw response from DynamoDB.
+        :raises botocore.exceptions.ClientError: If the request fails
+            (including a failed condition check); logged and re-raised.
+        """
+        self.log.debug("Putting item into table %s", table_name)
+        try:
+            table = self.get_conn().Table(table_name)
+            params: dict[str, Any] = {"Item": item}
+            if condition_expression:
+                params["ConditionExpression"] = condition_expression
+            # DynamoDB rejects empty maps, so only send these when populated.
+            if expression_attribute_values:
+                params["ExpressionAttributeValues"] = expression_attribute_values
+            if expression_attribute_names:
+                params["ExpressionAttributeNames"] = expression_attribute_names
+            response = table.put_item(**params)
+            self.log.info("Successfully put item into table %s", table_name)
+            return response
+        except ClientError as e:
+            self.log.error("Failed to put item into %s: %s", table_name, e)
+            raise
+
+    def update_item(
+        self,
+        table_name: str,
+        key: dict[str, Any],
+        update_expression: str,
+        expression_attribute_values: dict[str, Any] | None = None,
+        expression_attribute_names: dict[str, str] | None = None,
+        condition_expression: Any | None = None,
+    ) -> dict[str, Any] | None:
+        """
+        Update attributes of an existing item in a DynamoDB table.
+
+        Uses the boto3 resource API so values are plain Python types.
+        Note that DynamoDB creates the item when no item with ``key``
+        exists; pass a ``condition_expression`` to prevent that.
+
+        .. seealso::
+            - :external+boto3:py:meth:`DynamoDB.Table.update_item`
+
+        :param table_name: Name of the DynamoDB table.
+        :param key: Primary key of the item to update.
+        :param update_expression: Update expression, e.g.
+            ``"SET #s = :status, updated_at = :ts"``.
+        :param expression_attribute_values: Substitution values for the
+            expression, e.g. ``{":status": "done", ":ts": "2024-01-01"}``.
+            Omit for expressions without value placeholders (for example
+            ``"REMOVE #s"``).
+        :param expression_attribute_names: Optional name aliases for reserved
+            words, e.g. ``{"#s": "status"}``.
+        :param condition_expression: Optional condition — string or
+            ``boto3.dynamodb.conditions.ConditionBase``.
+        :return: The updated item attributes, or ``None`` if the update
+            returned no attributes.
+        :raises botocore.exceptions.ClientError: If the request fails
+            (including a failed condition check); logged and re-raised.
+        """
+        self.log.debug("Updating item with key %s in table %s", key, table_name)
+        try:
+            table = self.get_conn().Table(table_name)
+            params: dict[str, Any] = {
+                "Key": key,
+                "UpdateExpression": update_expression,
+                "ReturnValues": "ALL_NEW",
+            }
+            # DynamoDB rejects an empty ExpressionAttributeValues map, so the
+            # key is only sent when there is something to substitute.
+            if expression_attribute_values:
+                params["ExpressionAttributeValues"] = expression_attribute_values
+            if expression_attribute_names:
+                params["ExpressionAttributeNames"] = expression_attribute_names
+            if condition_expression:
+                params["ConditionExpression"] = condition_expression
+            response = table.update_item(**params)
+            self.log.info("Successfully updated item in table %s", table_name)
+            return response.get("Attributes")
+        except ClientError as e:
+            self.log.error("Failed to update item in %s: %s", table_name, e)
+            raise
+
+    def delete_item(
+        self,
+        table_name: str,
+        key: dict[str, Any],
+        condition_expression: Any | None = None,
+        expression_attribute_values: dict[str, Any] | None = None,
+        expression_attribute_names: dict[str, str] | None = None,
+    ) -> dict[str, Any] | None:
+        """
+        Delete an item from a DynamoDB table.
+
+        .. seealso::
+            - :external+boto3:py:meth:`DynamoDB.Table.delete_item`
+
+        :param table_name: Name of the DynamoDB table.
+        :param key: Primary key of the item to delete.
+        :param condition_expression: Optional condition — string or
+            ``boto3.dynamodb.conditions.ConditionBase``.
+        :param expression_attribute_values: Optional substitution values for
+            ``:placeholders`` used in a string ``condition_expression``.
+        :param expression_attribute_names: Optional name aliases for
+            ``#placeholders`` used in a string ``condition_expression``.
+        :return: The deleted item's attributes, or ``None`` if the item did
+            not exist.
+        :raises botocore.exceptions.ClientError: If the request fails
+            (including a failed condition check); logged and re-raised.
+        """
+        self.log.debug("Deleting item with key %s from table %s", key, table_name)
+        try:
+            table = self.get_conn().Table(table_name)
+            params: dict[str, Any] = {"Key": key, "ReturnValues": "ALL_OLD"}
+            if condition_expression:
+                params["ConditionExpression"] = condition_expression
+            # DynamoDB rejects empty maps, so only send these when populated.
+            if expression_attribute_values:
+                params["ExpressionAttributeValues"] = expression_attribute_values
+            if expression_attribute_names:
+                params["ExpressionAttributeNames"] = expression_attribute_names
+            response = table.delete_item(**params)
+            self.log.info("Successfully deleted item from table %s", table_name)
+            return response.get("Attributes")
+        except ClientError as e:
+            self.log.error("Failed to delete item from %s: %s", table_name, e)
+            raise
+
+    def query(
+        self,
+        table_name: str,
+        key_condition_expression: Any,
+        expression_attribute_values: dict[str, Any] | None = None,
+        expression_attribute_names: dict[str, str] | None = None,
+        filter_expression: Any | None = None,
+        index_name: str | None = None,
+        limit: int | None = None,
+    ) -> list[dict[str, Any]]:
+        """
+        Query items from a DynamoDB table or secondary index.
+
+        Uses the boto3 resource API so values are plain Python types.
+        ``key_condition_expression`` accepts either a string expression or a
+        boto3 :py:class:`boto3.dynamodb.conditions.ConditionBase` object
+        (e.g. ``Key("pk").eq("value")``).
+
+        When ``limit`` is omitted, all result pages are fetched by following
+        ``LastEvaluatedKey`` so large result sets are not silently truncated.
+        When ``limit`` is set, a single request is issued.
+
+        .. seealso::
+            - :external+boto3:py:meth:`DynamoDB.Table.query`
+            - :external+boto3:py:class:`boto3.dynamodb.conditions.Key`
+
+        :param table_name: Name of the DynamoDB table.
+        :param key_condition_expression: Key condition — string or
+            ``boto3.dynamodb.conditions.ConditionBase``.
+        :param expression_attribute_values: Substitution values (required when
+            using string expressions, omit when using condition objects).
+        :param expression_attribute_names: Optional name aliases for reserved
+            words.
+        :param filter_expression: Optional filter applied after the query —
+            string or ``boto3.dynamodb.conditions.ConditionBase``.
+        :param index_name: Name of a secondary index to query.
+        :param limit: Maximum number of items to evaluate (see DynamoDB
+            ``Limit`` semantics — this is not a guaranteed page size).
+        :return: List of matching items as plain dicts.
+        :raises botocore.exceptions.ClientError: If a request fails; the
+            error is logged and re-raised unchanged.
+        """
+        self.log.debug("Querying table %s", table_name)
+        params: dict[str, Any] = {"KeyConditionExpression": key_condition_expression}
+        if expression_attribute_values:
+            params["ExpressionAttributeValues"] = expression_attribute_values
+        if expression_attribute_names:
+            params["ExpressionAttributeNames"] = expression_attribute_names
+        if filter_expression is not None:
+            params["FilterExpression"] = filter_expression
+        if index_name:
+            params["IndexName"] = index_name
+        if limit:
+            params["Limit"] = limit
+        items: list[dict[str, Any]] = []
+        try:
+            table = self.get_conn().Table(table_name)
+            while True:
+                response = table.query(**params)
+                items.extend(response.get("Items", []))
+                last_key = response.get("LastEvaluatedKey")
+                # An explicit ``limit`` asks for one bounded request, so the
+                # loop only continues when the caller wants every match.
+                if limit or not last_key:
+                    break
+                params["ExclusiveStartKey"] = last_key
+        except ClientError as e:
+            self.log.error("Failed to query table %s: %s", table_name, e)
+            raise
+        self.log.info("Query on table %s returned %d items", table_name, len(items))
+        return items
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_dynamodb.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_dynamodb.py
index f39f426278a2e..4e476bd1285ec 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_dynamodb.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_dynamodb.py
@@ -147,3 +147,140 @@ def test_hook_has_import_waiters(self):
         hook = DynamoDBHook(aws_conn_id="aws_default")
         waiter = hook.get_waiter("import_table")
         assert waiter is not None
+
+    @mock_aws
+    def test_get_item_returns_item(self):
+        hook = DynamoDBHook(aws_conn_id="aws_default", region_name="us-east-1")
+        hook.get_conn().create_table(
+            TableName="test_get",
+            KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
+            AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
+            ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
+        )
+        hook.get_conn().Table("test_get").put_item(Item={"pk": "abc", "status": "pending"})
+
+        item = hook.get_item("test_get", {"pk": "abc"})
+        assert item == {"pk": "abc", "status": "pending"}
+
+    @mock_aws
+    def test_get_item_returns_none_when_missing(self):
+        hook = DynamoDBHook(aws_conn_id="aws_default", region_name="us-east-1")
+        hook.get_conn().create_table(
+            TableName="test_get_missing",
+            KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
+            AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
+            ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
+        )
+        assert hook.get_item("test_get_missing", {"pk": "does-not-exist"}) is None
+
+    @mock_aws
+    def test_put_item(self):
+        hook = DynamoDBHook(aws_conn_id="aws_default", region_name="us-east-1")
+        hook.get_conn().create_table(
+            TableName="test_put",
+            KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
+            AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
+            ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
+        )
+        hook.put_item("test_put", {"pk": "xyz", "value": "hello"})
+        item = hook.get_item("test_put", {"pk": "xyz"})
+        assert item == {"pk": "xyz", "value": "hello"}
+
+    @mock_aws
+    def test_update_item(self):
+        hook = DynamoDBHook(aws_conn_id="aws_default", region_name="us-east-1")
+        hook.get_conn().create_table(
+            TableName="test_update",
+            KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
+            AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
+            ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
+        )
+        hook.put_item("test_update", {"pk": "u1", "status": "pending"})
+        updated = hook.update_item(
+            table_name="test_update",
+            key={"pk": "u1"},
+            update_expression="SET #s = :s",
+            expression_attribute_values={":s": "approved"},
+            expression_attribute_names={"#s": "status"},
+        )
+        assert updated["status"] == "approved"
+
+    @mock_aws
+    def test_update_item_without_attribute_values(self):
+        hook = DynamoDBHook(aws_conn_id="aws_default", region_name="us-east-1")
+        hook.get_conn().create_table(
+            TableName="test_update_remove",
+            KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
+            AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
+            ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
+        )
+        hook.put_item("test_update_remove", {"pk": "u1", "status": "pending"})
+        updated = hook.update_item(
+            table_name="test_update_remove",
+            key={"pk": "u1"},
+            update_expression="REMOVE #s",
+            expression_attribute_names={"#s": "status"},
+        )
+        assert updated == {"pk": "u1"}
+
+    @mock_aws
+    def test_delete_item(self):
+        hook = DynamoDBHook(aws_conn_id="aws_default", region_name="us-east-1")
+        hook.get_conn().create_table(
+            TableName="test_delete",
+            KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}],
+            AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}],
+            ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
+        )
+        hook.put_item("test_delete", {"pk": "d1", "data": "some-value"})
+        deleted = hook.delete_item("test_delete", {"pk": "d1"})
+        assert deleted == {"pk": "d1", "data": "some-value"}
+        assert hook.get_item("test_delete", {"pk": "d1"}) is None
+
+    @mock_aws
+    def test_query(self):
+        from boto3.dynamodb.conditions import Key
+
+        hook = DynamoDBHook(aws_conn_id="aws_default", region_name="us-east-1")
+        hook.get_conn().create_table(
+            TableName="test_query",
+            KeySchema=[
+                {"AttributeName": "pk", "KeyType": "HASH"},
+                {"AttributeName": "sk", "KeyType": "RANGE"},
+            ],
+            AttributeDefinitions=[
+                {"AttributeName": "pk", "AttributeType": "S"},
+                {"AttributeName": "sk", "AttributeType": "S"},
+            ],
+            ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
+        )
+        table = hook.get_conn().Table("test_query")
+        table.put_item(Item={"pk": "p1", "sk": "2024-01", "val": "a"})
+        table.put_item(Item={"pk": "p1", "sk": "2024-02", "val": "b"})
+        table.put_item(Item={"pk": "p2", "sk": "2024-01", "val": "c"})
+
+        items = hook.query(table_name="test_query", key_condition_expression=Key("pk").eq("p1"))
+        assert len(items) == 2
+        assert all(i["pk"] == "p1" for i in items)
+
+    def test_query_follows_pagination(self):
+        from unittest import mock
+
+        hook = DynamoDBHook(aws_conn_id="aws_default", region_name="us-east-1")
+        table = mock.MagicMock()
+        table.query.side_effect = [
+            {"Items": [{"pk": "p1", "sk": "1"}], "LastEvaluatedKey": {"pk": "p1", "sk": "1"}},
+            {"Items": [{"pk": "p1", "sk": "2"}]},
+        ]
+        with mock.patch.object(DynamoDBHook, "get_conn") as mock_get_conn:
+            mock_get_conn.return_value.Table.return_value = table
+            items = hook.query(
+                table_name="t",
+                key_condition_expression="pk = :pk",
+                expression_attribute_values={":pk": "p1"},
+            )
+
+        assert items == [{"pk": "p1", "sk": "1"}, {"pk": "p1", "sk": "2"}]
+        assert table.query.call_count == 2
+        assert "ExclusiveStartKey" not in table.query.call_args_list[0].kwargs
+        assert table.query.call_args_list[1].kwargs["ExclusiveStartKey"] == {"pk": "p1", "sk": "1"}