diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index b6816b1cc3..8d02799370 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -14,7 +14,6 @@ from gt4py.eve import utils from gt4py.next import common -from gt4py.next.iterator import ir as itir from gt4py.next.otf import code_specs, definitions from gt4py.next.otf.binding import interface @@ -36,18 +35,13 @@ def compilation_hash(program_def: definitions.CompilableProgramDef) -> int: def fingerprint_compilable_program(program_def: definitions.CompilableProgramDef) -> str: """ - Generates a unique hash string for a stencil source program representing - the program, sorted offset_provider, and column_axis. + Generates a unique hash string for a compilable program representing + the program IR and all compile-time arguments. """ - program: itir.Program = program_def.data - offset_provider: common.OffsetProvider = program_def.args.offset_provider - column_axis: Optional[common.Dimension] = program_def.args.column_axis - program_hash = utils.content_hash( ( - program.fingerprint(), - sorted(offset_provider.items(), key=lambda el: el[0]), - column_axis, + program_def.data.fingerprint(), + program_def.args, ) ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_ir.py b/tests/next_tests/unit_tests/iterator_tests/test_ir.py index 1041ea6a4f..958c5cea3a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_ir.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_ir.py @@ -9,6 +9,8 @@ import pytest from gt4py.next.iterator import ir +from gt4py.next.otf import arguments, definitions, stages +from gt4py.next.type_system import type_specifications as ts from gt4py import eve @@ -50,3 +52,39 @@ def node_maker(fun: str, filename: str): node3 = node_maker("f3", "loc1") assert node1.fingerprint() == node2.fingerprint() assert node1.fingerprint() != node3.fingerprint() + + +def test_different_precisions(): + program = ir.Program( + id="test_program", + function_definitions=[], + params=[ir.Sym(id="arg")], + declarations=[], + body=[], + ) + + compilable_single = definitions.CompilableProgramDef( + data=program, + args=arguments.CompileTimeArgs( + args=(ts.ScalarType(kind=ts.ScalarKind.FLOAT32),), + kwargs={}, + offset_provider={}, + column_axis=None, + argument_descriptor_contexts={}, + ), + ) + compilable_double = definitions.CompilableProgramDef( + data=program, + args=arguments.CompileTimeArgs( + args=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),), + kwargs={}, + offset_provider={}, + column_axis=None, + argument_descriptor_contexts={}, + ), + ) + + hash_single = stages.fingerprint_compilable_program(compilable_single) + hash_double = stages.fingerprint_compilable_program(compilable_double) + + assert hash_single != hash_double