From 96b7f9f9aa217c93238922388a7f2b177ae43f56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20K=C4=99pie=C5=84?= Date: Mon, 14 Mar 2022 08:59:32 +0100 Subject: [PATCH] Refactor "statschannel" test's helper modules The "statschannel" system test contains two Python helper modules: - generic.py: test functions directly invoked by both tests-json.py and test-xml.py, - helper.py: helper functions invoked by test functions in generic.py. The above logic for splitting helper functions into Python modules prevents selective test skipping from working due to unconditional import statements being present in both helper modules. For example, if dnspython is not available on the test host, tests-json.py imports generic.py, which in turn imports helper.py, which in turn attempts to import various dnspython modules, triggering ImportError exceptions during test initialization. Various decorators used for some tests (like @pytest.mark.dnspython) suggest that such a scenario should be handled gracefully, but that is not the case - modifying the test collection in conftest.py does not prevent pytest from failing due to import errors. Fix by moving helper functions around to achieve a different split: - generic.py: helper functions only relying on the Python standard library, - generic_dnspython.py: helper functions requiring dnspython. Only two tests in tests-{json,xml}.py need dnspython to work (test_traffic_json(), test_traffic_xml()). Since all dnspython-dependent code is now present in generic_dnspython.py, employ pytest.importorskip() in those two tests to ensure they can be selectively skipped when dnspython is not available. Adjust other code to account for the revised Python helper module layout. Remove all occurrences of the @pytest.mark.dnspython decorator (and all associated code) from the "statschannel" system test to prevent confusion. --- bin/tests/system/statschannel/conftest.py | 12 -- bin/tests/system/statschannel/generic.py | 111 ++++++++++-------- .../{helper.py => generic_dnspython.py} | 101 +++++++--------- bin/tests/system/statschannel/tests-json.py | 15 ++- bin/tests/system/statschannel/tests-xml.py | 15 ++- 5 files changed, 119 insertions(+), 135 deletions(-) rename bin/tests/system/statschannel/{helper.py => generic_dnspython.py} (61%) diff --git a/bin/tests/system/statschannel/conftest.py b/bin/tests/system/statschannel/conftest.py index 59a903ca2b..798bee7300 100644 --- a/bin/tests/system/statschannel/conftest.py +++ b/bin/tests/system/statschannel/conftest.py @@ -23,9 +23,6 @@ def pytest_configure(config): config.addinivalue_line( "markers", "xml: mark tests that need xml.etree to function" ) - config.addinivalue_line( - "markers", "dnspython: mark tests that need dnspython to function" - ) def pytest_collection_modifyitems(config, items): @@ -72,15 +69,6 @@ def pytest_collection_modifyitems(config, items): for item in items: if "xml" in item.keywords: item.add_marker(no_xmlstats) - # Test for dnspython module - skip_dnspython = pytest.mark.skip( - reason="need dnspython module to run") - try: - import dns.query # noqa: F401 - except ModuleNotFoundError: - for item in items: - if "dnspython" in item.keywords: - item.add_marker(skip_dnspython) @pytest.fixture diff --git a/bin/tests/system/statschannel/generic.py b/bin/tests/system/statschannel/generic.py index 1c5fda8092..e4cf82cec5 100644 --- a/bin/tests/system/statschannel/generic.py +++ b/bin/tests/system/statschannel/generic.py @@ -9,7 +9,64 @@ # See the COPYRIGHT file distributed with this work for additional # information regarding copyright ownership. -import helper +from datetime import datetime, timedelta +import os + + +# ISO datetime format without msec +fmt = '%Y-%m-%dT%H:%M:%SZ' + +# The constants were taken from BIND 9 source code (lib/dns/zone.c) +max_refresh = timedelta(seconds=2419200) # 4 weeks +max_expires = timedelta(seconds=14515200) # 24 weeks +now = datetime.utcnow().replace(microsecond=0) +dayzero = datetime.utcfromtimestamp(0).replace(microsecond=0) + + +# Generic helper functions +def check_expires(expires, min_time, max_time): + assert expires >= min_time + assert expires <= max_time + + +def check_refresh(refresh, min_time, max_time): + assert refresh >= min_time + assert refresh <= max_time + + +def check_loaded(loaded, expected): + # Sanity check the zone timers values + assert loaded == expected + assert loaded < now + + +def check_zone_timers(loaded, expires, refresh, loaded_exp): + # Sanity checks the zone timers values + if expires is not None: + check_expires(expires, now, now + max_expires) + if refresh is not None: + check_refresh(refresh, now, now + max_refresh) + check_loaded(loaded, loaded_exp) + + +# +# The output is gibberish, but at least make sure it does not crash. +# +def check_manykeys(name, zone=None): + # pylint: disable=unused-argument + assert name == "manykeys" + + +def zone_mtime(zonedir, name): + + try: + si = os.stat(os.path.join(zonedir, "{}.db".format(name))) + except FileNotFoundError: + return dayzero + + mtime = datetime.utcfromtimestamp(si.st_mtime).replace(microsecond=0) + + return mtime def test_zone_timers_primary(fetch_zones, load_timers, **kwargs): @@ -22,8 +79,8 @@ def test_zone_timers_primary(fetch_zones, load_timers, **kwargs): for zone in zones: (name, loaded, expires, refresh) = load_timers(zone, True) - mtime = helper.zone_mtime(zonedir, name) - helper.check_zone_timers(loaded, expires, refresh, mtime) + mtime = zone_mtime(zonedir, name) + check_zone_timers(loaded, expires, refresh, mtime) def test_zone_timers_secondary(fetch_zones, load_timers, **kwargs): @@ -36,8 +93,8 @@ def test_zone_timers_secondary(fetch_zones, load_timers, **kwargs): for zone in zones: (name, loaded, expires, refresh) = load_timers(zone, False) - mtime = helper.zone_mtime(zonedir, name) - helper.check_zone_timers(loaded, expires, refresh, mtime) + mtime = zone_mtime(zonedir, name) + check_zone_timers(loaded, expires, refresh, mtime) def test_zone_with_many_keys(fetch_zones, load_zone, **kwargs): @@ -50,46 +107,4 @@ def test_zone_with_many_keys(fetch_zones, load_zone, **kwargs): for zone in zones: name = load_zone(zone) if name == 'manykeys': - helper.check_manykeys(name) - - -def test_traffic(fetch_traffic, **kwargs): - - statsip = kwargs['statsip'] - statsport = kwargs['statsport'] - port = kwargs['port'] - - data = fetch_traffic(statsip, statsport) - exp = helper.create_expected(data) - - msg = helper.create_msg("short.example.", "TXT") - helper.update_expected(exp, "dns-udp-requests-sizes-received-ipv4", msg) - ans = helper.udp_query(statsip, port, msg) - helper.update_expected(exp, "dns-udp-responses-sizes-sent-ipv4", ans) - data = fetch_traffic(statsip, statsport) - - helper.check_traffic(data, exp) - - msg = helper.create_msg("long.example.", "TXT") - helper.update_expected(exp, "dns-udp-requests-sizes-received-ipv4", msg) - ans = helper.udp_query(statsip, port, msg) - helper.update_expected(exp, "dns-udp-responses-sizes-sent-ipv4", ans) - data = fetch_traffic(statsip, statsport) - - helper.check_traffic(data, exp) - - msg = helper.create_msg("short.example.", "TXT") - helper.update_expected(exp, "dns-tcp-requests-sizes-received-ipv4", msg) - ans = helper.tcp_query(statsip, port, msg) - helper.update_expected(exp, "dns-tcp-responses-sizes-sent-ipv4", ans) - data = fetch_traffic(statsip, statsport) - - helper.check_traffic(data, exp) - - msg = helper.create_msg("long.example.", "TXT") - helper.update_expected(exp, "dns-tcp-requests-sizes-received-ipv4", msg) - ans = helper.tcp_query(statsip, port, msg) - helper.update_expected(exp, "dns-tcp-responses-sizes-sent-ipv4", ans) - data = fetch_traffic(statsip, statsport) - - helper.check_traffic(data, exp) + check_manykeys(name) diff --git a/bin/tests/system/statschannel/helper.py b/bin/tests/system/statschannel/generic_dnspython.py similarity index 61% rename from bin/tests/system/statschannel/helper.py rename to bin/tests/system/statschannel/generic_dnspython.py index 0a44333e14..88dabbca08 100644 --- a/bin/tests/system/statschannel/helper.py +++ b/bin/tests/system/statschannel/generic_dnspython.py @@ -9,75 +9,16 @@ # See the COPYRIGHT file distributed with this work for additional # information regarding copyright ownership. -import os -import os.path - from collections import defaultdict -from datetime import datetime, timedelta import dns.message import dns.query import dns.rcode -# ISO datetime format without msec -fmt = '%Y-%m-%dT%H:%M:%SZ' - -# The constants were taken from BIND 9 source code (lib/dns/zone.c) -max_refresh = timedelta(seconds=2419200) # 4 weeks -max_expires = timedelta(seconds=14515200) # 24 weeks -now = datetime.utcnow().replace(microsecond=0) -dayzero = datetime.utcfromtimestamp(0).replace(microsecond=0) - TIMEOUT = 10 -# Generic helper functions -def check_expires(expires, min_time, max_time): - assert expires >= min_time - assert expires <= max_time - - -def check_refresh(refresh, min_time, max_time): - assert refresh >= min_time - assert refresh <= max_time - - -def check_loaded(loaded, expected): - # Sanity check the zone timers values - assert loaded == expected - assert loaded < now - - -def check_zone_timers(loaded, expires, refresh, loaded_exp): - # Sanity checks the zone timers values - if expires is not None: - check_expires(expires, now, now + max_expires) - if refresh is not None: - check_refresh(refresh, now, now + max_refresh) - check_loaded(loaded, loaded_exp) - - -# -# The output is gibberish, but at least make sure it does not crash. -# -def check_manykeys(name, zone=None): - # pylint: disable=unused-argument - assert name == "manykeys" - - -def zone_mtime(zonedir, name): - - try: - si = os.stat(os.path.join(zonedir, "{}.db".format(name))) - except FileNotFoundError: - return dayzero - - mtime = datetime.utcfromtimestamp(si.st_mtime).replace(microsecond=0) - - return mtime - - def create_msg(qname, qtype): msg = dns.message.make_query(qname, qtype, want_dnssec=True, use_edns=0, payload=4096) @@ -144,3 +85,45 @@ def check_traffic(data, expected): assert len(expected) == len(ordered_expected) assert ordered_data == ordered_expected + + +def test_traffic(fetch_traffic, **kwargs): + + statsip = kwargs['statsip'] + statsport = kwargs['statsport'] + port = kwargs['port'] + + data = fetch_traffic(statsip, statsport) + exp = create_expected(data) + + msg = create_msg("short.example.", "TXT") + update_expected(exp, "dns-udp-requests-sizes-received-ipv4", msg) + ans = udp_query(statsip, port, msg) + update_expected(exp, "dns-udp-responses-sizes-sent-ipv4", ans) + data = fetch_traffic(statsip, statsport) + + check_traffic(data, exp) + + msg = create_msg("long.example.", "TXT") + update_expected(exp, "dns-udp-requests-sizes-received-ipv4", msg) + ans = udp_query(statsip, port, msg) + update_expected(exp, "dns-udp-responses-sizes-sent-ipv4", ans) + data = fetch_traffic(statsip, statsport) + + check_traffic(data, exp) + + msg = create_msg("short.example.", "TXT") + update_expected(exp, "dns-tcp-requests-sizes-received-ipv4", msg) + ans = tcp_query(statsip, port, msg) + update_expected(exp, "dns-tcp-responses-sizes-sent-ipv4", ans) + data = fetch_traffic(statsip, statsport) + + check_traffic(data, exp) + + msg = create_msg("long.example.", "TXT") + update_expected(exp, "dns-tcp-requests-sizes-received-ipv4", msg) + ans = tcp_query(statsip, port, msg) + update_expected(exp, "dns-tcp-responses-sizes-sent-ipv4", ans) + data = fetch_traffic(statsip, statsport) + + check_traffic(data, exp) diff --git a/bin/tests/system/statschannel/tests-json.py b/bin/tests/system/statschannel/tests-json.py index 6af335e1d2..1f4d6d6e65 100755 --- a/bin/tests/system/statschannel/tests-json.py +++ b/bin/tests/system/statschannel/tests-json.py @@ -19,7 +19,6 @@ import pytest import requests import generic -from helper import fmt # JSON helper functions @@ -50,7 +49,7 @@ def load_timers_json(zone, primary=True): # Check if the primary zone timer exists assert 'loaded' in zone - loaded = datetime.strptime(zone['loaded'], fmt) + loaded = datetime.strptime(zone['loaded'], generic.fmt) if primary: # Check if the secondary zone timers does not exist @@ -61,8 +60,8 @@ def load_timers_json(zone, primary=True): else: assert 'expires' in zone assert 'refresh' in zone - expires = datetime.strptime(zone['expires'], fmt) - refresh = datetime.strptime(zone['refresh'], fmt) + expires = datetime.strptime(zone['expires'], generic.fmt) + refresh = datetime.strptime(zone['refresh'], generic.fmt) return (name, loaded, expires, refresh) @@ -104,10 +103,10 @@ def test_zone_with_many_keys_json(statsport): @pytest.mark.json @pytest.mark.requests -@pytest.mark.dnspython @pytest.mark.skipif(os.getenv("HAVEJSONSTATS", "unset") != "1", reason="JSON not configured") def test_traffic_json(named_port, statsport): - generic.test_traffic(fetch_traffic_json, - statsip="10.53.0.2", statsport=statsport, - port=named_port) + generic_dnspython = pytest.importorskip('generic_dnspython') + generic_dnspython.test_traffic(fetch_traffic_json, + statsip="10.53.0.2", statsport=statsport, + port=named_port) diff --git a/bin/tests/system/statschannel/tests-xml.py b/bin/tests/system/statschannel/tests-xml.py index 0dd3b6b075..efd66a5b68 100755 --- a/bin/tests/system/statschannel/tests-xml.py +++ b/bin/tests/system/statschannel/tests-xml.py @@ -20,7 +20,6 @@ import pytest import requests import generic -from helper import fmt # XML helper functions @@ -79,7 +78,7 @@ def load_timers_xml(zone, primary=True): loaded_el = zone.find('loaded') assert loaded_el is not None - loaded = datetime.strptime(loaded_el.text, fmt) + loaded = datetime.strptime(loaded_el.text, generic.fmt) expires_el = zone.find('expires') refresh_el = zone.find('refresh') @@ -91,8 +90,8 @@ def load_timers_xml(zone, primary=True): else: assert expires_el is not None assert refresh_el is not None - expires = datetime.strptime(expires_el.text, fmt) - refresh = datetime.strptime(refresh_el.text, fmt) + expires = datetime.strptime(expires_el.text, generic.fmt) + refresh = datetime.strptime(refresh_el.text, generic.fmt) return (name, loaded, expires, refresh) @@ -134,10 +133,10 @@ def test_zone_with_many_keys_xml(statsport): @pytest.mark.xml @pytest.mark.requests -@pytest.mark.dnspython @pytest.mark.skipif(os.getenv("HAVEXMLSTATS", "unset") != "1", reason="XML not configured") def test_traffic_xml(named_port, statsport): - generic.test_traffic(fetch_traffic_xml, - statsip="10.53.0.2", statsport=statsport, - port=named_port) + generic_dnspython = pytest.importorskip('generic_dnspython') + generic_dnspython.test_traffic(fetch_traffic_xml, + statsip="10.53.0.2", statsport=statsport, + port=named_port)