diff --git a/Config.py b/Config.py new file mode 100644 index 000000000..402fb4953 --- /dev/null +++ b/Config.py @@ -0,0 +1,568 @@ +from datetime import datetime +from dateutil import dateutil_parser +import collections +import json +import logging +import pprint + + +"""Idea here being to start with something that is decomposed so it's easier to +make do json in *and* out, differences between configs and config extension. +""" + +#TODO scope logging and handlers better, control verbosity by command line flags +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler()) + + +def parse_bool_from_json(value, attr_name): + if value in ('true', '1', 1, 'yes'): + bool_value = True + elif value in ('false', '0', 0, 'no'): + bool_value = False + elif value in (True, False): + bool_value = value + else: + raise ConfigError('Config value %s is an invalid boolean value.' % attr_name) + return bool_value + + +def parse_timestamp(value, attr_name): + if isinstance(value, datetime): + return value + try: + ts = int(value) + return datetime.fromtimestamp(ts) + except (TypeError, ValueError): + pass + try: + return dateutil_parser.parse(value) + except (TypeError, ValueError): + raise ConfigError('Config value %s is an invalid date or timestamp.' % attr_name) + + +def verify_member_of(value, member_list, attr_name): + if value not in member_list: + raise ConfigError('Config value "%s" must be one of (%s)' % ( + attr_name, ', '.join(member_list)) + ) + return value + + +def verify_string(value, attr_name, max_length=200): + if not isinstance(value, (str, unicode)): + raise ConfigError('Config value %s must be a string.' % attr_name) + if len(value) > max_length: + raise ConfigError('Config value %s is too long.' % attr_name) + return value + + +def to_dict(config_dict): + """Cleans up BaseConfig children to be serialized.""" + d = {} + for key, val in config_dict.iteritems(): + if isinstance(val, BaseConfig): + d[key] = to_dict(val._data) + elif isinstance(val, datetime): + d[key] = val.strftime('%Y-%m-%dT%H:%M:%S%z') + elif isinstance(val, dict): + d[key] = to_dict(val) + else: + d[key] = val + return d + + +class BaseConfig(object): + """Top level config class for common methods. + + Requirements for using class: + - list all properties with getters *and* setters in class + variable 'config_properties' + - __init__ of child classes must be callable with *only* + keyword arguments to allow method calls to update to create + a new config + ... more ... + """ + + config_properties = [] + + def __init__(self): + # container for validated properties with JSON names + self._data = {} + + def __repr__(self): + s = '< %s %s >' % (self.__class__.__name__, + pprint.pformat(self._data)) + return s + + def update(self, newer_config, merge=False, **kwargs): + """Create a fresh config combining the new and old configs. + + It does this by iterating over the 'config_properties' class + attribute which contains names of property attributes for the config. + + Two methods of combining configs are possible, an 'update' and + a 'merge', the latter set by the keyword argument 'merge=True'. + + An update overrides older values with new values -- even if those + new values are None. Update will remove values that are present in + the old config if they are not present in the new config. + + A merge by comparison will allow old values to persist if they are + not specified in the new config. This can be used for end-user + customizations to override specific settings without having to re-create + large portions of a config to override it. + + Arguments: + newer_config: A config object to combine with the current config. + merge: Allows old values not overridden to survive into the fresh config. + + Returns: + A config object of the same sort as called upon. + """ + # removed 'merge' kw arg - and it was passed to constructor + # make a note to not do that, consume it on the param list + fresh_config = self.__class__(**kwargs) + logger.debug('from parent update kwargs %s' % kwargs) + logger.debug('from parent update merge %s' % merge) + if not isinstance(newer_config, self.__class__): + raise ConfigError('Attempting to update a %s with a %s' % ( + self.__class__, + newer_config.__class__)) + for prop_name in self.config_properties: + # get the specified property off of the current class + prop = self.__class__.__dict__.get(prop_name) + assert prop + new_value = prop.fget(newer_config) + old_value = prop.fget(self) + if new_value is not None: + prop.fset(fresh_config, new_value) + elif merge and old_value is not None: + prop.fset(fresh_config, old_value) + return fresh_config + + def merge(self, newer_config, **kwargs): + """Combines configs and keeps old values if they are not overridden. + + See docstring for 'update' method for more details. + + Arguments: + newer_config: A config object to combine with the current config. + merge: Allows old values not overridden to survive into the fresh config. + + Returns: + A config object of the same sort as called upon. + """ + kwargs['merge'] = True + logger.debug('from parent merge: %s' % kwargs) + return self.update(newer_config, **kwargs) + + def to_json(self): + d = to_dict(self._data) + return json.dumps(d) + + def write_to_json_file(self, json_filename, f_open=open): + data = self.to_json() + try: + with f_open(json_filename, 'w') as f: + f.write(data) + except IOError: + raise + + def load_from_json_file(self, json_filename, f_open=open): + try: + with f_open(json_filename, 'r') as f: + json_str = f.read() + json_dict = json.loads(json_str) + except IOError: + raise + except ValueError: + raise ConfigError('No valid JSON found in file: %s' % json_filename) + self.from_json_dict(json_dict) + + def from_json_dict(self, json_dict): + raise NotImplmented('BaseConfig should not be populated.') + + +class Config(BaseConfig): + """Config container for StartTLS Everywhere configuration. + + Intended as a simple container that unifies where validatation occurs, + and is capable of comparing configs to warn of things like changing + certificate fingerprints from one scan to the next. + + There is a one to one mapping of the object attributes to the JSON + object keys, albeit with dashes replaced with underscores. + """ + + def __init__(self): + super(self.__class__, self).__init__() + self._data['tls-policies'] = {} + self._data['acceptable-mxs'] = {} + + def __add__(self, other_config): + """Allow addition but not really of *full* configs, need to flesh that out.""" + #TODO add this + raise NotImplemented + + def update(self, other_config): + """Update properties of config from a 'newer' config and force verification.""" + #TODO add this + new_config = Config() + raise NotImplemented + + def from_json_dict(self, json_dict): + """Assign JSON data to Config properties and declare sub-objects. + + Let's property verification methods do the heavy lifting and mostly + maps between the JSON config names and attributes. Keeps track of + unused variables and warns about them. + """ + for key, val in json_dict.iteritems(): + if key == 'author': + self.author = val + elif key == 'comment': + self.comment = val + elif key == 'expires': + self.expires = val + elif key == 'timestamp': + self.timestamp = val + elif key == 'tls-policies': + self.make_tls_policy_dict(val) + elif key == 'acceptable-mxs': + self.make_acceptable_mxs_dict(val) + else: + logger.warn('Unknown attribute "%s", skipping' % key) + + @property + def author(self): + return self._data.get('author') + + @author.setter + def author(self, value): + self._data['author'] = verify_string(value, 'author') + + @property + def comment(self): + return self._data.get('comment') + + @comment.setter + def comment(self, value): + self._data['comment'] = verify_string(value, 'comment') + + @property + def expires(self): + return self._data.get('expires') + + @expires.setter + def expires(self, value): + self._data['expires'] = parse_timestamp(value, 'expires') + + @property + def timestamp(self): + return self._data.get('timestamp') + + @timestamp.setter + def timestamp(self, value): + self._data['timestamp'] = parse_timestamp(value, 'timestamp') + + @property + def tls_policies(self): + return self._data.get('tls-policies') + + @property + def acceptable_mxs(self): + return self._data.get('acceptable-mxs') + + def make_tls_policy_dict(self, policy_dict): + tls_policy_dict = self.tls_policies + for domain_suffix, settings in policy_dict.iteritems(): + new_domain_policy = TLSPolicy(domain_suffix) + try: + new_domain_policy.from_json_dict(settings) + except ConfigError as e: + raise + tls_policy_dict[domain_suffix] = new_domain_policy + + def get_tls_policy(self, mx_domain): + return self.tls_policies.get(mx_domain) + + def make_acceptable_mxs_dict(self, mxs_dict): + acceptable_mxs_dict = self._data['acceptable-mxs'] + for domain, settings in mxs_dict.iteritems(): + new_domain_policy = AcceptableMX(domain) + try: + new_domain_policy.from_json_dict(settings) + except ConfigError as e: + raise + acceptable_mxs_dict[domain] = new_domain_policy + + def get_address_domains(self, mx_hostname, mx_to_domain_map): + """Do a fuzzy DNS host match on provided map to get lists of policies. + + Args: + mx_hostname (string): The hostname from an MX record. + mx_to_domain_map: Mapping from MX hosts to AcceptableMX + policies, the same AcceptableMX policy may occur more + than once. e.g. {'mx_host3': set(AcceptableMX, ...)} + The map can be generated by Config.get_mx_to_domain_policy_map. + + Returns: + The set containing all AcceptableMX policies that list the + provided MX host as viable. + """ + labels = mx_hostname.split(".") + for n in range(1, len(labels)): + parent = "." + ".".join(labels[n:]) + if parent in mx_to_domain_map: + return mx_to_domain_map[parent] + return None + + def get_mx_to_domain_policy_map(self): + """Create mapping of MX hostnames to sets of AcceptableMX policies. + + Generate a dictionary that is typically used in log analysis + (e.g. if your MTA logs interact with beta.innotech.com you use + this mapping to tell you it used the innotech.com AcceptableMX + policy or policies). There are of course complications. + """ + # create reverse mapping dictionary as well for auditing + # and reviewing logs + mx_to_domain_policy = collections.defaultdict(set) + + for mx_host, domain_policy in self.get_all_mx_items(): + existing_mx_policies = mx_to_domain_policy.get(mx_host) + if existing_mx_policies: + existing_domains = [ e.domain for e in existing_mx_policies ] + if domain_policy.domain not in existing_domains: + #TODO plenty of room to enforce a security policy here + # this is also the case of google apps personal domains + msg = ('Attempting to add domain policy (%s) for MX host but MX' + ' host already has a domain policy (%s), appending...') + logger.debug(msg % (domain_policy.domain, + ', '.join(existing_domains))) + mx_to_domain_policy[mx_host].add(domain_policy) + return mx_to_domain_policy + + def get_all_mx_items(self): + """Iterate over (mx_host, mx_policy) - be sure to dedup!""" + all_mx_items = [] + for policy in self.acceptable_mxs.values(): + accepted_mxs = policy.accept_mx_domains + all_mx_items.extend([(mx_host, policy) + for mx_host in accepted_mxs]) + return all_mx_items + + def get_all_mx_hosts(self): + all_mx_hosts = [] + [ all_mx_hosts.extend(domain_policy.acceptable_mxs) + for domain_policy in self.acceptable_mxs.values() ] + return all_mx_hosts + + def is_valid(self): + #TODO implement checks to make sure domains don't overlap + #TODO add debug logging for troubleshooting sake + for mx_config in self.acceptable_mxs.values(): + if not mx_config.is_valid(): + return False + for domain_suffix in mx_config.accept_mx_domains: + # check to make sure every accepted MX has a TLS policy + if not domain_suffix in self.tls_policies: + return False + all_mx_hosts = self.get_all_mx_hosts() + for domain_suffix, tls_config in self.tls_policies.iteritems(): + if not tls_config.is_valid(): + return False + # make sure no unclaimed TLS policies have made their way in + if domain_suffix not in all_mx_hosts: + return False + return True + + +class TLSPolicy(BaseConfig): + + ENFORCE_MODES = ('enforce', 'log-only') + TLS_VERSIONS = ('TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3') + + config_properties = ['comment', 'enforce_mode', 'min_tls_version', + 'require_tls', 'require_valid_certificate'] + + def __init__(self, domain_suffix=None): + super(self.__class__, self).__init__() + self.domain_suffix = domain_suffix + #TODO add support for two designed but yet unsupported attrs + # self._data['accept-spki-hashs'] = None + # self._data['error-notification'] = None + + def from_json_dict(self, json_dict): + for key, val in json_dict.iteritems(): + if key == 'comment': + self.comment = val + elif key == 'enforce-mode': + self.enforce_mode = val + elif key == 'min-tls-version': + self.min_tls_version = val + elif key == 'require-tls': + self.require_tls = val + elif key == 'require-valid-certificate': + self.require_valid_certificate = val + else: + logger.warn('Unknown key %s' % key) + + def is_valid(self): + """Do simple check that config contains all required values. + + Should find a way to expose easily which config values + are required, at least place in error messages such that + incomplete configs will expose it. + """ + required_attrs = ('enforce-mode', 'min-tls-version', + 'require-tls') + values_set = [self._data.get(attr) for attr in required_attrs] + if not all(values_set): + return False + else: + return True + + def update(self, newer_policy, **kwargs): + if not kwargs.get('domain_suffix'): + kwargs['domain_suffix'] = self.domain_suffix + fresh_policy = super(self.__class__, self).update(newer_policy, + **kwargs) + logger.debug('from TLS child update %s' % kwargs) + return fresh_policy + + def merge(self, newer_policy, **kwargs): + logger.debug('from TLS child merge: %s' % kwargs) + fresh_policy = super(self.__class__, self).merge(newer_policy, + domain_suffix=self.domain_suffix) + return fresh_policy + + @property + def comment(self): + return self._data.get('comment') + + @comment.setter + def comment(self, value): + self._data['comment'] = verify_string(value, 'comment') + + @property + def enforce_mode(self): + return self._data.get('enforce-mode') + + @enforce_mode.setter + def enforce_mode(self, value): + self._data['enforce-mode'] = verify_member_of(value, self.ENFORCE_MODES, 'enforce-mode') + + @property + def min_tls_version(self): + return self._data.get('min-tls-version') + + @min_tls_version.setter + def min_tls_version(self, value): + """TODO: Should this be dealing only with strings processed by map ... lower()?""" + tls_versions = [ver.lower() for ver in self.TLS_VERSIONS] + tls_versions.extend(self.TLS_VERSIONS) + self._data['min-tls-version'] = verify_member_of(value, tls_versions, 'min-tls-version') + + @property + def require_tls(self): + return self._data.get('require-tls') + + @require_tls.setter + def require_tls(self, value): + self._data['require-tls'] = parse_bool_from_json(value, 'require-tls') + + @property + def require_valid_certificate(self): + return self._data.get('require-valid-certificate') + + @require_valid_certificate.setter + def require_valid_certificate(self, value): + self._data['require-valid-certificate'] = parse_bool_from_json(value, 'require-valid-certificate') + + +class AcceptableMX(BaseConfig): + """Holds acceptable MX domain suffixes for a single mail serving domain. + + Such as for gmail.com that single mail serving suffix domain is: + gmail-smtp-in.l.google.com. + + Configuration of the acceptable MX suffix domains must match up with TLS policies + for the suffix domains. + """ + def __init__(self, domain=None): + super(self.__class__, self).__init__() + self.domain = domain + self._data['accept-mx-domains'] = [] + + @property + def accept_mx_domains(self): + return self._data.get('accept-mx-domains') + + def add_acceptable_mx(self, domain_suffix): + unique_domain_suffixes = set(self._data['accept-mx-domains']) + unique_domain_suffixes.add(domain_suffix) + self._data['accept-mx-domains'] = list(unique_domain_suffixes) + + @property + def comment(self): + return self._data.get('comment') + + @comment.setter + def comment(self, value): + self._data['comment'] = verify_string(value, 'comment') + + def is_valid(self): + """Check to make sure there is one acceptable domain suffix. + + This will need to be updated once we can actually test and support + for more than one acceptable domain suffix. + + TODO: could make this object double check the data it is given with + DNS queries. + """ + if len(self._data['accept-mx-domains']) != 1: + return False + else: + return True + + def from_json_dict(self, json_dict): + for key, val in json_dict.iteritems(): + if key == 'accept-mx-domains': + if isinstance(val, list): + for domain_suffix in val: + self.add_acceptable_mx(domain_suffix) + else: + self.add_acceptable_mx(val) + elif key == 'comment': + self.comment = val + else: + logger.warn('warning: unknown key %s' % key) + + def update(self, newer_policy, **kwargs): + logger.debug('from MX child update got %s' % kwargs) + if not kwargs.get('domain'): + kwargs['domain'] = self.domain + fresh_policy = super(self.__class__, self).update(newer_policy, + **kwargs) + if kwargs.get('merge'): + new_accepted_mxs = set(self.accept_mx_domains) + new_accepted_mxs = new_accepted_mxs.union(newer_policy.accept_mx_domains) + else: + new_accepted_mxs = newer_policy.accept_mx_domains + for domain in new_accepted_mxs: + fresh_policy.add_acceptable_mx(domain) + + return fresh_policy + + def merge(self, newer_policy, **kwargs): + logger.debug('from MX child merge: %s' % kwargs) + fresh_policy = super(self.__class__, self).merge(newer_policy, + **kwargs) + return fresh_policy + + +class ConfigError(ValueError): + def __init__(self, message): + super(self.__class__, self).__init__(message) diff --git a/MTAConfigGenerator.py b/MTAConfigGenerator.py index a733ab27a..1c273cf94 100755 --- a/MTAConfigGenerator.py +++ b/MTAConfigGenerator.py @@ -33,6 +33,7 @@ class PostfixConfigGenerator(MTAConfigGenerator): self.postfix_cf_file = self.find_postfix_cf() self.wrangle_existing_config() self.set_domainwise_tls_policies() + #TODO make this optional for testing, etc. os.system("sudo service postfix reload") def ensure_cf_var(self, var, ideal, also_acceptable): @@ -120,33 +121,37 @@ class PostfixConfigGenerator(MTAConfigGenerator): def set_domainwise_tls_policies(self): self.policy_lines = [] - for address_domain, properties in self.policy_config.acceptable_mxs.items(): - mx_list = properties["accept-mx-domains"] + all_acceptable_mxs = self.policy_config.get_acceptable_mxs_dict() + for address_domain, properties in all_acceptable_mxs.items(): + mx_list = properties.accept_mx_domains if len(mx_list) > 1: - print "Lists of multiple accept-mx-domains not yet supported, skipping ", address_domain + print "Lists of multiple accept-mx-domains not yet supported." + print "Using MX %s for %s" % (mx_list[0], address_domain) + print "Ignoring: %s" % (', '.join(mx_list[1:])) mx_domain = mx_list[0] - mx_policy = self.policy_config.tls_policies[mx_domain] + mx_policy = self.policy_config.get_tls_policy(mx_domain) entry = address_domain + " encrypt" - if "min-tls-version" in mx_policy: - if mx_policy["min-tls-version"].lower() == "tlsv1": - entry += " protocols=!SSLv2,!SSLv3" - elif mx_policy["min-tls-version"].lower() == "tlsv1.1": - entry += " protocols=!SSLv2,!SSLv3,!TLSv1" - elif mx_policy["min-tls-version"].lower() == "tlsv1.2": - entry += " protocols=!SSLv2,!SSLv3,!TLSv1,!TLSv1.1" - else: - print mx_policy["min-tls-version"] + if mx_policy.min_tls_version.lower() == "tlsv1": + entry += " protocols=!SSLv2,!SSLv3" + elif mx_policy.min_tls_version.lower() == "tlsv1.1": + entry += " protocols=!SSLv2,!SSLv3,!TLSv1" + elif mx_policy.min_tls_version.lower() == "tlsv1.2": + entry += " protocols=!SSLv2,!SSLv3,!TLSv1,!TLSv1.1" + else: + print mx_policy.min_tls_version self.policy_lines.append(entry) f = open(self.policy_file, "w") f.write("\n".join(self.policy_lines) + "\n") f.close() + if __name__ == "__main__": - import ConfigParser + import Config as config if len(sys.argv) != 3: print "Usage: MTAConfigGenerator starttls-everywhere.json /etc/postfix" sys.exit(1) - c = ConfigParser.Config(sys.argv[1]) + c = config.Config() + c.load_from_json_file(sys.argv[1]) postfix_dir = sys.argv[2] pcgen = PostfixConfigGenerator(c, postfix_dir, fixup=True) diff --git a/PostfixLogSummary.py b/PostfixLogSummary.py index 0348432b0..f9e717f66 100755 --- a/PostfixLogSummary.py +++ b/PostfixLogSummary.py @@ -3,7 +3,7 @@ import re import sys import collections -import ConfigParser +import Config # TODO: There's more to be learned from postfix logs! Here's one sample # observed during failures from the sender vagrant vm: @@ -35,21 +35,24 @@ def get_counts(input, config): # Log lines for when a TLS connection was successfully established. These can # indicate the difference between Untrusted, Trusted, and Verified certs. connected_re = re.compile("([A-Za-z]+) TLS connection established to ([^[]*)") + mx_to_domain_mapping = config.get_mx_to_domain_policy_map() + for line in sys.stdin: deferred = deferred_re.search(line) connected = connected_re.search(line) if connected: - validation = result.group(1) - mx_hostname = result.group(2).lower() + validation = connected.group(1) + mx_hostname = connected.group(2).lower() if validation == "Trusted" or validation == "Verified": seen_trusted = True - address_domains = config.get_address_domains(mx_hostname) + address_domains = config.get_address_domains(mx_hostname, mx_to_domain_mapping) if address_domains: - for d in address_domains: - counts[d][validation] += 1 - counts[d]["all"] += 1 + domains_str = [ a.domain for a in address_domains ] + d = ', '.join(domains_str) + counts[d][validation] += 1 + counts[d]["all"] += 1 elif deferred: - mx_hostname = result.group(1).lower() + mx_hostname = deferred.group(1).lower() tls_deferred[mx_hostname] += 1 if not seen_trusted: # Postfix will only emit 'Trusted' if the certificate validates according to @@ -65,7 +68,11 @@ def print_summary(counts): print mx_hostname, validation, validation_count / validations["all"], "of", validations["all"] if __name__ == "__main__": - config = ConfigParser.Config("starttls-everywhere.json") + if len(sys.argv) != 2: + print "Usage: %s starttls-everywhere.json" % sys.argv[0] + sys.exit(1) + config = Config.Config() + config.load_from_json_file(sys.argv[1]) (counts, tls_deferred) = get_counts(sys.stdin, config) print_summary(counts) print tls_deferred diff --git a/TestConfig.py b/TestConfig.py new file mode 100644 index 000000000..f323554c0 --- /dev/null +++ b/TestConfig.py @@ -0,0 +1,131 @@ +import copy +import itertools +import logging +import unittest + +import Config + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler()) + + +class TestTLSPolicy(unittest.TestCase): + + def setUp(self): + self.old_config = Config.TLSPolicy(domain_suffix='.eff.org') + self.old_config.comment = 'Testing EFF.org TLS policy' + self.old_config.require_tls = True + self.old_config.require_valid_certificate = False + self.old_config.min_tls_version = 'TLSv1' + self.old_config.enforce_mode = 'log-only' + + self.new_config = Config.TLSPolicy(domain_suffix='.eff.org') + self.new_config.require_valid_certificate = True + self.new_config.min_tls_version = 'TLSv1.2' + self.new_config.enforce_mode = 'enforce' + + def testUpdateDropsOldSettings(self): + logger.debug('old: %s' % self.old_config) + logger.debug('new: %s' % self.new_config) + tls_policy = self.old_config.update(self.new_config) + logger.debug('just generated: %s' % tls_policy) + self.assertFalse(any([tls_policy.require_tls, tls_policy.comment])) + + def testMergeKeepsOldSettings(self): + logger.debug('old: %s' % self.old_config) + logger.debug('new: %s' % self.new_config) + tls_policy = self.old_config.merge(self.new_config, merge=True) + logger.debug('just generated: %s' % tls_policy) + self.assertTrue(all([tls_policy.require_tls, tls_policy.comment])) + + def testUpdateGetsNameSet(self): + tls_policy = self.old_config.update(self.new_config) + self.assertEquals(tls_policy.domain_suffix, self.old_config.domain_suffix) + + +class TestAcceptableMX(unittest.TestCase): + + def setUp(self): + self.old_config = Config.AcceptableMX(domain='eff.org') + self.old_config.add_acceptable_mx('.eff.org') + + def testUpdateDropsOldMXs(self): + new_bogus_mx = '.testing.eff.org' + new_config = Config.AcceptableMX(domain='eff.org') + new_config.add_acceptable_mx(new_bogus_mx) + updated_config = self.old_config.update(new_config) + self.assertNotIn('.eff.org', updated_config.accept_mx_domains) + + def testMergeKeepsOldMXs(self): + new_bogus_mx = '.testing.eff.org' + new_config = Config.AcceptableMX(domain='eff.org') + new_config.add_acceptable_mx(new_bogus_mx) + updated_config = self.old_config.merge(new_config) + self.assertListEqual(sorted(['.eff.org', '.testing.eff.org']), + sorted(updated_config.accept_mx_domains)) + + def testUpdateGetsNameSet(self): + new_policy = Config.AcceptableMX(domain=self.old_config.domain) + mx_policy = self.old_config.update(new_policy) + self.assertEquals(mx_policy.domain, self.old_config.domain) + + +class TestConfig(unittest.TestCase): + """Test entire configuration. + + Currently lower coverage is being obtained since string sets are + being compared rather than returned objects. Comparison logic for + the config objects isn't clear yet and proof that they function is enough. + """ + + def setUp(self): + self.config = Config.Config() + domain_policies = self.config._data['acceptable-mxs'] + self.mail_domains = ['gmail.com', 'yahoo.com', 'hotmail.com', '123.cn', 'qq.com'] + for domain in self.mail_domains: + new = Config.AcceptableMX(domain=domain) + new.add_acceptable_mx('.' + domain) + domain_policies[domain] = new + + def testGetAllMxItems(self): + """Make sure the basic use case of get_all_mx_items functions.""" + # [ ('.gmail.com', 'gmail.com'), ('.yahoo.com', 'yahoo.com'), ... ] + control_data = [ ('.' + domain, domain) for domain in self.mail_domains ] + test_data = [ (mx, p.domain) for mx, p in self.config.get_all_mx_items() ] + self.assertListEqual(sorted(test_data), sorted(control_data)) + + def testGetAllMxItemsMultiMX(self): + config = copy.deepcopy(self.config) + domain_policy = config.acceptable_mxs.get('gmail.com') + # deal with reality, mail.google.com + domain_policy.add_acceptable_mx('.mail.google.com') + control_data = [ ('.' + domain, domain) for domain in self.mail_domains ] + control_data.append(('.mail.google.com', 'gmail.com')) + test_data = [ (mx, p.domain) for mx, p in config.get_all_mx_items() ] + self.assertListEqual(sorted(test_data), sorted(control_data)) + + def testGetMXtoDomainPolicy(self): + control_data = dict([ ('.' + domain, set([domain])) + for domain in self.mail_domains ]) + test_data = {} + for mx, pset in self.config.get_mx_to_domain_policy_map().items(): + policy_list = [ p.domain for p in pset ] + test_data[mx] = set(policy_list) + self.assertDictEqual(test_data, control_data) + + def testGetMXtoDomainPolicyMultiMX(self): + config = copy.deepcopy(self.config) + domain_policy = config.acceptable_mxs.get('gmail.com') + domain_policy.add_acceptable_mx('.mail.google.com') + control_data = dict([ ('.' + domain, set([domain])) + for domain in self.mail_domains ]) + control_data['.mail.google.com'] = set(['gmail.com']) + test_data = {} + for mx, pset in config.get_mx_to_domain_policy_map().items(): + policy_list = [ p.domain for p in pset ] + test_data[mx] = set(policy_list) + self.assertDictEqual(test_data, control_data) + + +if __name__ == '__main__': + unittest.main() diff --git a/bigger_test_config.json b/bigger_test_config.json new file mode 100644 index 000000000..c3c23c455 --- /dev/null +++ b/bigger_test_config.json @@ -0,0 +1,36 @@ +{ + "timestamp": 1401414363, + "author": "Electronic Frontier Foundation https://eff.org", + "expires": "2015-08-01T12:00:00+08:00", + "tls-policies": { + ".yahoodns.net": { + "require-valid-certificate": true + }, + ".eff.org": { + "require-tls": true, + "min-tls-version": "TLSv1.1", + "enforce-mode": "enforce", + "accept-spki-hashes": [ + "sha1/5R0zeLx7EWRxqw6HRlgCRxNLHDo=", + "sha1/YlrkMlC6C4SJRZSVyRvnvoJ+8eM=" + ] + }, + ".google.com": { + "require-valid-certificate": true, + "min-tls-version": "TLSv1.1", + "enforce-mode": "log-only", + "error-notification": "https://google.com/post/reports/here" + } + }, + "acceptable-mxs": { + "yahoo.com": { + "accept-mx-domains": [".yahoodns.net"] + }, + "gmail.com": { + "accept-mx-domains": [".google.com"] + }, + "eff.org": { + "accept-mx-domains": [".eff.org"] + } + } +}