Skip to content
Open
64 changes: 57 additions & 7 deletions samtranslator/plugins/globals/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
from typing import Any, Union

from samtranslator.model.exceptions import ExceptionWithMessage, InvalidResourceAttributeTypeException
from samtranslator.plugins.globals.merge_strategy import REPLACE, MergeOp, MergeRule
from samtranslator.public.intrinsics import is_intrinsics
from samtranslator.public.sdk.resource import SamResourceType
from samtranslator.swagger.swagger import SwaggerEditor

# Per-property merge schema. Paths not listed here default to CONCATENATE (today's behavior).
CUSTOM_STRATEGIES: dict[str, MergeRule] = {"Function.Architectures": REPLACE}


class Globals:
"""
Expand Down Expand Up @@ -307,7 +311,12 @@ def _parse(self, globals_dict): # type: ignore[no-untyped-def]
)

# Store all Global properties in a map with key being the AWS::Serverless::* resource type
_globals[resource_type] = GlobalProperties(properties)
resource_schema = {
k.removeprefix(f"{section_name}."): v
for k, v in CUSTOM_STRATEGIES.items()
if k.startswith(f"{section_name}.")
}
_globals[resource_type] = GlobalProperties(properties, schema=resource_schema)

return _globals

Expand Down Expand Up @@ -435,24 +444,26 @@ class GlobalProperties:

"""

def __init__(self, global_properties) -> None: # type: ignore[no-untyped-def]
def __init__(self, global_properties, schema=None) -> None: # type: ignore[no-untyped-def]
self.global_properties = global_properties
self.schema = schema or {}

def merge(self, local_properties): # type: ignore[no-untyped-def]
"""
Merge Global & local level properties according to the above rules

:return local_properties: Dictionary of local properties
"""
return self._do_merge(self.global_properties, local_properties) # type: ignore[no-untyped-call]
return self._do_merge(self.global_properties, local_properties, path="") # type: ignore[no-untyped-call]

def _do_merge(self, global_value, local_value): # type: ignore[no-untyped-def]
def _do_merge(self, global_value, local_value, path=""): # type: ignore[no-untyped-def]
"""
Actually perform the merge operation for the given inputs. This method is used as part of the recursion.
Therefore input values can be of any type. So is the output.

:param global_value: Global value to be merged
:param local_value: Local value to be merged
:param path: Dot-delimited path for schema lookup
:return: Merged result
"""

Expand All @@ -467,9 +478,14 @@ def _do_merge(self, global_value, local_value): # type: ignore[no-untyped-def]
return self._prefer_local(global_value, local_value) # type: ignore[no-untyped-call]

if self.TOKEN.DICT == token_global == token_local:
return self._merge_dict(global_value, local_value) # type: ignore[no-untyped-call]
return self._merge_dict(global_value, local_value, path) # type: ignore[no-untyped-call]

if self.TOKEN.LIST == token_global == token_local:
rule = self.schema.get(path)
if rule and rule.op == MergeOp.REPLACE:
return local_value
if rule and rule.op == MergeOp.MERGE_BY_KEY:
return self._merge_by_key(global_value, local_value, rule.key)
return self._merge_lists(global_value, local_value) # type: ignore[no-untyped-call]

raise TypeError(f"Unsupported type of objects. GlobalType={token_global}, LocalType={token_local}")
Expand All @@ -485,22 +501,56 @@ def _merge_lists(self, global_list, local_list): # type: ignore[no-untyped-def]

return global_list + local_list

def _merge_dict(self, global_dict, local_dict): # type: ignore[no-untyped-def]
def _merge_by_key(self, global_list: list[Any], local_list: list[Any], key: str | None) -> list[Any]:
"""
Merges two lists of dicts by a shared key field. Local entries override global entries
with the same key value. Non-dict items and items without the key are preserved.

:param global_list: Global list of dicts
:param local_list: Local list of dicts
:param key: The dict key to match on
:return: Merged list
"""

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BUG] _merge_by_key produces inconsistent results when local_list contains duplicate key values, depending on whether global_list happens to contain the same key.

local_by_key = {item[key]: item for item in local_list if isinstance(item, dict) and key in item}

The dict comprehension causes the last local entry to win for any given key. Then in Pass 2, the loop iterates local_list directly and appends the first non-seen entry. The two passes disagree:

  • Global has no override for the duplicate key → Pass 2 wins → first local entry kept.
  • Global has an override for the duplicate key → Pass 1 wins via local_by_key → last local entry kept.

Trace with key="Key":

# Case A: no global override
global_list = []
local_list  = [{"Key": "env", "Value": "a"}, {"Key": "env", "Value": "b"}]
# result -> [{"Key": "env", "Value": "a"}]  (first wins)

# Case B: global has the same key
global_list = [{"Key": "env", "Value": "g"}]
local_list  = [{"Key": "env", "Value": "a"}, {"Key": "env", "Value": "b"}]
# result -> [{"Key": "env", "Value": "b"}]  (last wins, via local_by_key)

The existing test list_with_merge_by_key_deduplicates_local_duplicates covers Case A and asserts first-wins; Case B is untested and silently flips the precedence. Recommend picking one rule and using it in both passes — e.g. build local_by_key by skipping already-present keys (if item[key] not in local_by_key) so first-wins is consistent across both branches, or by iterating local_by_key.values() in Pass 2 so last-wins is consistent.

This is dormant for now (no CUSTOM_STRATEGIES entry uses MERGE_BY_KEY), but the test suite locks in the inconsistent behavior, which will be harder to change once Tags merge-by-key is enabled.

local_by_key = {item[key]: item for item in local_list if isinstance(item, dict) and key in item}
seen_keys: set[Any] = set()
result = []

# Pass 1: walk globals, replace matched keys with local override
for item in global_list:
if isinstance(item, dict) and key in item and item[key] in local_by_key:
Comment thread
vicheey marked this conversation as resolved.
result.append(local_by_key[item[key]])
seen_keys.add(item[key])
else:
result.append(item)

# Pass 2: append local items not already seen (new keys + non-dict overflow)
for item in local_list:
if isinstance(item, dict) and key in item:
if item[key] not in seen_keys:
result.append(item)
else:
result.append(item)

return result

def _merge_dict(self, global_dict, local_dict, path_prefix=""): # type: ignore[no-untyped-def]
"""
Merges the two dictionaries together

:param global_dict: Global dictionary to be merged
:param local_dict: Local dictionary to be merged
:param path_prefix: Current dot-delimited path prefix for schema lookup
:return: New merged dictionary with values shallow copied
"""

# Local has higher priority than global. So iterate over local dict and merge into global if keys are overridden
global_dict = global_dict.copy()

for key in local_dict:
child_path = f"{path_prefix}.{key}".lstrip(".")
if key in global_dict:
# Both local & global contains the same key. Let's do a merge.
global_dict[key] = self._do_merge(global_dict[key], local_dict[key]) # type: ignore[no-untyped-call]
global_dict[key] = self._do_merge(global_dict[key], local_dict[key], child_path) # type: ignore[no-untyped-call]

else:
# Key is not in globals, just in local. Copy it over
Expand Down
32 changes: 32 additions & 0 deletions samtranslator/plugins/globals/merge_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Per-property merge strategy types for the Globals merge engine."""

from dataclasses import dataclass
from enum import Enum


class MergeOp(Enum):
CONCATENATE = "concatenate"
REPLACE = "replace"
MERGE_BY_KEY = "merge_by_key"


@dataclass(frozen=True)
class MergeRule:
op: MergeOp
key: str | None = None

def __post_init__(self) -> None:
if self.op == MergeOp.MERGE_BY_KEY and not self.key:
raise ValueError("MERGE_BY_KEY requires a 'key' field")
if self.op != MergeOp.MERGE_BY_KEY and self.key is not None:
raise ValueError(f"'key' is only valid with MERGE_BY_KEY, not {self.op.value}")


# Explicit default; not needed in CUSTOM_STRATEGIES (unlisted paths already concatenate).
CONCATENATE = MergeRule(MergeOp.CONCATENATE)
REPLACE = MergeRule(MergeOp.REPLACE)


def merge_by_key(key: str) -> MergeRule:
"""Factory for MERGE_BY_KEY rules. Merges list-of-dicts by the named key field."""
return MergeRule(MergeOp.MERGE_BY_KEY, key=key)
61 changes: 60 additions & 1 deletion tests/plugins/globals/test_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from parameterized import parameterized
from samtranslator.model.exceptions import InvalidResourceAttributeTypeException
from samtranslator.plugins.globals.globals import GlobalProperties, Globals, InvalidGlobalsSectionException
from samtranslator.plugins.globals.merge_strategy import REPLACE, merge_by_key


class GlobalPropertiesTestCases:
Expand Down Expand Up @@ -171,6 +172,42 @@ class GlobalPropertiesTestCases:

mixed_type_inputs_must_be_handled = {"global": {"a": "b"}, "local": [1, 2, 3], "expected_output": [1, 2, 3]}

# Merge strategy: REPLACE — local list fully replaces global list (flat and nested paths).
# Add new test cases here when new rules are added to CUSTOM_STRATEGIES.
list_with_replace_strategy_must_use_local = {
"global": {"Architectures": ["x86_64"], "VpcConfig": {"SecurityGroupIds": ["sg-global"]}},
"local": {"Architectures": ["arm64"], "VpcConfig": {"SecurityGroupIds": ["sg-local"]}},
"expected_output": {"Architectures": ["arm64"], "VpcConfig": {"SecurityGroupIds": ["sg-local"]}},
"schema": {"Architectures": REPLACE, "VpcConfig.SecurityGroupIds": REPLACE},
}

# Merge strategy: MERGE_BY_KEY — deduplicates by key field, local overrides; non-dict items preserved.
list_with_merge_by_key_strategy = {
"global": {"Tags": [{"Key": "env", "Value": "dev"}, {"Key": "team", "Value": "lambda"}, "plain-string"]},
"local": {"Tags": [{"Key": "env", "Value": "prod"}, {"Key": "app", "Value": "my"}]},
"expected_output": {
"Tags": [
{"Key": "env", "Value": "prod"},
{"Key": "team", "Value": "lambda"},
"plain-string",
{"Key": "app", "Value": "my"},
]
},
"schema": {"Tags": merge_by_key("Key")},
}

# Multiple strategies applied to different properties in one merge.
multiple_strategies_applied_per_property = {
"global": {"Architectures": ["x86_64"], "Tags": [{"Key": "env", "Value": "dev"}], "Layers": ["arn:layer1"]},
"local": {"Architectures": ["arm64"], "Tags": [{"Key": "env", "Value": "prod"}], "Layers": ["arn:layer2"]},
"expected_output": {
"Architectures": ["arm64"],
"Tags": [{"Key": "env", "Value": "prod"}],
"Layers": ["arn:layer1", "arn:layer2"],
},
"schema": {"Architectures": REPLACE, "Tags": merge_by_key("Key")},
}


class TestGlobalPropertiesMerge(TestCase):
# Get all attributes of the test case object which is not a built-in method like __str__
Expand All @@ -180,7 +217,8 @@ def test_global_properties_merge(self, testcase):
if not configuration:
raise Exception("Invalid configuration for test case " + testcase)

global_properties = GlobalProperties(configuration["global"])
schema = configuration.get("schema", {})
global_properties = GlobalProperties(configuration["global"], schema=schema)
actual = global_properties.merge(configuration["local"])

self.assertEqual(actual, configuration["expected_output"])
Expand Down Expand Up @@ -532,3 +570,24 @@ def test_openapi_postprocess(self):
global_obj = Globals(self.template)
global_obj.fix_openapi_definitions(test["input"])
self.assertEqual(test["input"], test["expected"], test["name"])


class TestMergeSchemaWiring(TestCase):
"""Tests that require the full Globals(template) pipeline (schema slicing by resource type).

Add new test cases to GlobalPropertiesTestCases for merge behavior.
Only add here for wiring-specific tests (IgnoreGlobals interaction, resource-type routing).
"""

def test_ignore_globals_skips_schema(self):
"""IgnoreGlobals for a registered property means schema is never consulted."""
template = {"Globals": {"Function": {"Architectures": ["arm64"], "Runtime": "python3.12"}}}
g = Globals(template)
result = g.merge(
"AWS::Serverless::Function",
{"Architectures": ["x86_64"]},
logical_id="MyFunc",
ignore_globals=["Architectures"],
)
self.assertEqual(result["Architectures"], ["x86_64"])
self.assertEqual(result["Runtime"], "python3.12")
90 changes: 90 additions & 0 deletions tests/plugins/globals/test_merge_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Unit tests for merge_strategy.py types."""

import unittest

from parameterized import parameterized
from samtranslator.plugins.globals.merge_strategy import (
CONCATENATE,
REPLACE,
MergeOp,
MergeRule,
merge_by_key,
)


class TestMergeOp(unittest.TestCase):
@parameterized.expand(
[
("concatenate", MergeOp.CONCATENATE, "concatenate"),
("replace", MergeOp.REPLACE, "replace"),
("merge_by_key", MergeOp.MERGE_BY_KEY, "merge_by_key"),
]
)
def test_enum_values(self, _name, member, expected):
self.assertEqual(member.value, expected)


class TestMergeRule(unittest.TestCase):
@parameterized.expand(
[
("replace", MergeOp.REPLACE, None),
("concatenate", MergeOp.CONCATENATE, None),
("merge_by_key", MergeOp.MERGE_BY_KEY, "Key"),
]
)
def test_valid_creation(self, _name, op, key):
rule = MergeRule(op, key=key) if key else MergeRule(op)
self.assertEqual(rule.op, op)
self.assertEqual(rule.key, key)

@parameterized.expand(
[
("merge_by_key_no_key", MergeOp.MERGE_BY_KEY, None, "MERGE_BY_KEY requires a 'key' field"),
("replace_with_key", MergeOp.REPLACE, "Bad", "only valid with MERGE_BY_KEY"),
("concatenate_with_key", MergeOp.CONCATENATE, "Bad", "only valid with MERGE_BY_KEY"),
]
)
def test_invalid_creation_raises(self, _name, op, key, expected_msg):
with self.assertRaises(ValueError) as ctx:
MergeRule(op, key=key)
self.assertIn(expected_msg, str(ctx.exception))

def test_frozen_immutable(self):
rule = MergeRule(MergeOp.REPLACE)
with self.assertRaises(AttributeError):
rule.op = MergeOp.CONCATENATE


class TestConvenienceConstructors(unittest.TestCase):
@parameterized.expand(
[
("CONCATENATE", CONCATENATE, MergeOp.CONCATENATE, None),
("REPLACE", REPLACE, MergeOp.REPLACE, None),
("MERGE_BY_KEY", merge_by_key("Key"), MergeOp.MERGE_BY_KEY, "Key"),
]
)
def test_constructor(self, _name, rule, expected_op, expected_key):
self.assertEqual(rule.op, expected_op)
self.assertEqual(rule.key, expected_key)


class TestSchemaKeyFormat(unittest.TestCase):
"""Dot-notation schema keys support nested property paths."""

@parameterized.expand(
[
("top_level", "Architectures"),
("one_level_nested", "VpcConfig.SecurityGroupIds"),
("two_levels_nested", "VpcConfig.SubnetConfig.SubnetIds"),
]
)
def test_valid_dot_notation_keys(self, _name, key):
"""Dot-separated paths are the schema key format — all valid."""
schema = {key: REPLACE}
# Should not raise — dots are path separators, not errors
self.assertIn(key, schema)
self.assertEqual(schema[key], REPLACE)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Merge strategy translator-level tests.
# Add new test cases here when new rules are added to CUSTOM_STRATEGIES.
Globals:
Function:
Runtime: python3.12
Handler: app.handler
Architectures:
- x86_64

Resources:
FunctionInheritsGlobalArch:
Type: AWS::Serverless::Function
Properties:
CodeUri: s3://bucket/code.zip

FunctionOverridesArch:
Type: AWS::Serverless::Function
Properties:
CodeUri: s3://bucket/code.zip
Architectures:
- arm64
Loading