diff --git a/letsencrypt-apache/letsencrypt_apache/configurator.py b/letsencrypt-apache/letsencrypt_apache/configurator.py index 9e9d606c2..7b1144e2e 100644 --- a/letsencrypt-apache/letsencrypt_apache/configurator.py +++ b/letsencrypt-apache/letsencrypt_apache/configurator.py @@ -15,6 +15,7 @@ from acme import challenges from letsencrypt import achallenges from letsencrypt import errors from letsencrypt import interfaces +from letsencrypt import le_util from letsencrypt_apache import augeas_configurator from letsencrypt_apache import constants @@ -92,17 +93,21 @@ class ApacheConfigurator(augeas_configurator.AugeasConfigurator): @classmethod def add_parser_arguments(cls, add): - add("server-root", default=constants.CLI_DEFAULTS["server_root"], - help="Apache server root directory.") add("ctl", default=constants.CLI_DEFAULTS["ctl"], help="Path to the 'apache2ctl' binary, used for 'configtest', " "retrieving the Apache2 version number, and initialization " "parameters.") + add("enmod", default=constants.CLI_DEFAULTS["enmod"], + help="Path to the Apache 'a2enmod' binary.") + add("dismod", default=constants.CLI_DEFAULTS["dismod"], + help="Path to the Apache 'a2enmod' binary.") add("init-script", default=constants.CLI_DEFAULTS["init_script"], help="Path to the Apache init script (used for server " "reload/restart).") add("le-vhost-ext", default=constants.CLI_DEFAULTS["le_vhost_ext"], help="SSL vhost configuration extension.") + add("server-root", default=constants.CLI_DEFAULTS["server_root"], + help="Apache server root directory.") def __init__(self, *args, **kwargs): """Initialize an Apache Configurator. @@ -942,12 +947,13 @@ class ApacheConfigurator(augeas_configurator.AugeasConfigurator): "Unsupported filesystem layout. " "sites-available/enabled expected.") - def enable_mod(self, mod_name): + def enable_mod(self, mod_name, temp=False): """Enables module in Apache. Both enables and restarts Apache so module is active. :param str mod_name: Name of the module to enable. (e.g. 'ssl') + :param bool temp: Whether or not this is a temporary action. """ # Support Debian specific setup @@ -958,7 +964,7 @@ class ApacheConfigurator(augeas_configurator.AugeasConfigurator): "Unsupported directory layout. You may try to enable mod %s " "and try again." % mod_name) - self._enable_mod_debian(mod_name) + self._enable_mod_debian(mod_name, temp) self.save_notes += "Enabled %s module in Apache" % mod_name logger.debug("Enabled Apache %s module", mod_name) @@ -970,39 +976,19 @@ class ApacheConfigurator(augeas_configurator.AugeasConfigurator): self.parser.modules.add(mod_name + "_module") self.parser.modules.add("mod_" + mod_name + ".c") - def _enable_mod_debian(self, mod_name): + def _enable_mod_debian(self, mod_name, temp): """Assumes mods-available, mods-enabled layout.""" - # TODO: This can be further updated to not require all files. - if mod_name == "ssl": - self._enable_mod_debian_files( - ["ssl.conf", "ssl.load"], "ssl_module") - elif mod_name == "rewrite": - self._enable_mod_debian_files(["rewrite.load"], "rewrite_module") - else: - raise errors.NotSupportedError + # Generate reversal command. + # Try to be safe here... check that we can probably reverse before + # applying enmod command + if not le_util.exe_exists(self.conf("dismod")): + raise errors.MisconfigurationError( + "Unable to find a2dismod, please make sure a2enmod and " + "a2dismod are configured correctly for letsencrypt.") - def _enable_mod_debian_files(self, filenames, mod_name): - """Move over all required files into mods-enabled.""" - mods_available = os.path.join(self.parser.root, "mods-available") - mods_enabled = os.path.join(self.parser.root, "mods-enabled") - - # Check to see all files are available. - for filename in filenames: - if not os.path.isfile(os.path.join(mods_available, filename)): - raise errors.NoInstallationError( - "Unable to enable module. Required files missing from " - "mods-available. %s" % str(filenames)) - - # Register and symlink files - for filename in filenames: - enabled_path = os.path.join(mods_enabled, filename) - if os.path.isfile(enabled_path): - logger.debug( - "Error - enabling module %s, filepath already exists " - "%s", mod_name, enabled_path) - raise errors.PluginError("Error enabling module %s" % mod_name) - self.reverter.register_file_creation(False, enabled_path) - os.symlink(os.path.join(mods_available, filename), enabled_path) + self.reverter.register_undo_command( + temp, [self.conf("dismod"), mod_name]) + le_util.run_script([self.conf("enmod"), mod_name]) def restart(self): """Restarts apache server. @@ -1018,25 +1004,13 @@ class ApacheConfigurator(augeas_configurator.AugeasConfigurator): def config_test(self): # pylint: disable=no-self-use """Check the configuration of Apache for errors. - :raises .errors.PluginError: If Unable to run apache2ctl :raises .errors.MisconfigurationError: If config_test fails """ try: - proc = subprocess.Popen( - [self.conf("ctl"), "configtest"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - stdout, stderr = proc.communicate() - except (OSError, ValueError): - logger.fatal("Unable to run /usr/sbin/apache2ctl configtest") - raise errors.PluginError("Unable to run apache2ctl") - - if proc.returncode != 0: - # Enter recovery routine... - logger.error("Apache Configtest failed\n%s\n%s", stdout, stderr) - raise errors.MisconfigurationError( - "Apache Configtest failure:\n%s\n%s" % (stdout, stderr)) + le_util.run_script([self.conf("ctl"), "configtest"]) + except errors.SubprocessError: + raise errors.MisconfigurationError("Config Test failed!") def get_version(self): """Return version of Apache Server. @@ -1050,17 +1024,13 @@ class ApacheConfigurator(augeas_configurator.AugeasConfigurator): """ try: - proc = subprocess.Popen( - [self.conf("ctl"), "-v"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - text = proc.communicate()[0] - except (OSError, ValueError): + stdout, _ = le_util.run_script([self.conf("ctl"), "-v"]) + except errors.SubprocessError: raise errors.PluginError( "Unable to run %s -v" % self.conf("ctl")) regex = re.compile(r"Apache/([0-9\.]*)", re.IGNORECASE) - matches = regex.findall(text) + matches = regex.findall(stdout) if len(matches) != 1: raise errors.PluginError("Unable to find Apache version") diff --git a/letsencrypt-apache/letsencrypt_apache/constants.py b/letsencrypt-apache/letsencrypt_apache/constants.py index 7e7e127f5..b38e898cf 100644 --- a/letsencrypt-apache/letsencrypt_apache/constants.py +++ b/letsencrypt-apache/letsencrypt_apache/constants.py @@ -6,6 +6,7 @@ CLI_DEFAULTS = dict( server_root="/etc/apache2", ctl="apache2ctl", enmod="a2enmod", + dismod="a2dismod", init_script="/etc/init.d/apache2", le_vhost_ext="-le-ssl.conf", ) diff --git a/letsencrypt-apache/letsencrypt_apache/tests/configurator_test.py b/letsencrypt-apache/letsencrypt_apache/tests/configurator_test.py index d318805a6..46073619a 100644 --- a/letsencrypt-apache/letsencrypt_apache/tests/configurator_test.py +++ b/letsencrypt-apache/letsencrypt_apache/tests/configurator_test.py @@ -162,50 +162,45 @@ class TwoVhost80Test(util.ApacheTest): self.assertTrue(self.config.is_site_enabled(self.vh_truth[2].filep)) self.assertTrue(self.config.is_site_enabled(self.vh_truth[3].filep)) + @mock.patch("letsencrypt.le_util.run_script") + @mock.patch("letsencrypt.le_util.exe_exists") @mock.patch("letsencrypt_apache.parser.subprocess.Popen") - def test_enable_mod(self, mock_popen): + def test_enable_mod(self, mock_popen, mock_exe_exists, mock_run_script): mock_popen().communicate.return_value = ("Define: DUMP_RUN_CFG", "") mock_popen().returncode = 0 + mock_exe_exists.return_value = True self.config.enable_mod("ssl") - for filename in ["ssl.conf", "ssl.load"]: - self.assertTrue( - os.path.isfile(os.path.join( - self.config.conf("server-root"), "mods-enabled", filename))) - self.assertTrue("ssl_module" in self.config.parser.modules) self.assertTrue("mod_ssl.c" in self.config.parser.modules) + self.assertTrue(mock_run_script.called) + def test_enable_mod_unsupported_dirs(self): shutil.rmtree(os.path.join(self.config.parser.root, "mods-enabled")) self.assertRaises( errors.NotSupportedError, self.config.enable_mod, "ssl") - def test_enable_mod_unsupported_mod(self): + @mock.patch("letsencrypt.le_util.exe_exists") + def test_enable_mod_no_disable(self, mock_exe_exists): + mock_exe_exists.return_value = False self.assertRaises( - errors.NotSupportedError, self.config.enable_mod, "unknown") - - def test_enable_mod_not_installed(self): - os.remove(os.path.join( - self.config.parser.root, "mods-available", "ssl.load")) - self.assertRaises( - errors.NoInstallationError, self.config.enable_mod, "ssl") - - def test_enable_mod_files_already_exist(self): - path = os.path.join(self.config.parser.root, "mods-enabled", "ssl.load") - open(path, "w").close() - self.assertRaises( - errors.PluginError, self.config.enable_mod, "ssl") + errors.MisconfigurationError, self.config.enable_mod, "ssl") + @mock.patch("letsencrypt.le_util.run_script") + @mock.patch("letsencrypt.le_util.exe_exists") @mock.patch("letsencrypt_apache.parser.subprocess.Popen") - def test_enable_site(self, mock_popen): + def test_enable_site(self, mock_popen, mock_exe_exists, mock_run_script): mock_popen().returncode = 0 mock_popen().communicate.return_value = ("Define: DUMP_RUN_CFG", "") + mock_exe_exists.return_value = True # Default 443 vhost self.assertFalse(self.vh_truth[1].enabled) self.config.enable_site(self.vh_truth[1]) self.assertTrue(self.vh_truth[1].enabled) + # Mod enabled + self.assertTrue(mock_run_script.called) # Go again to make sure nothing fails self.config.enable_site(self.vh_truth[1]) @@ -216,10 +211,9 @@ class TwoVhost80Test(util.ApacheTest): self.config.enable_site, obj.VirtualHost("asdf", "afsaf", set(), False, False)) - @mock.patch("letsencrypt_apache.parser.subprocess.Popen") - def test_deploy_cert(self, mock_popen): - mock_popen().returncode = 0 - mock_popen().communicate.return_value = ("Define: DUMP_RUN_CFG", "") + def test_deploy_cert(self): + self.config.parser.modules.add("ssl_module") + self.config.parser.modules.add("mod_ssl.c") # Get the default 443 vhost self.config.assoc["random.demo"] = self.vh_truth[1] @@ -399,25 +393,25 @@ class TwoVhost80Test(util.ApacheTest): self.config.cleanup([achall1, achall2]) self.assertTrue(mock_restart.called) - @mock.patch("letsencrypt_apache.configurator.subprocess.Popen") - def test_get_version(self, mock_popen): - mock_popen().communicate.return_value = ( + @mock.patch("letsencrypt.le_util.run_script") + def test_get_version(self, mock_script): + mock_script.return_value = ( "Server Version: Apache/2.4.2 (Debian)", "") self.assertEqual(self.config.get_version(), (2, 4, 2)) - mock_popen().communicate.return_value = ( + mock_script.return_value = ( "Server Version: Apache/2 (Linux)", "") self.assertEqual(self.config.get_version(), (2,)) - mock_popen().communicate.return_value = ( + mock_script.return_value = ( "Server Version: Apache (Debian)", "") self.assertRaises(errors.PluginError, self.config.get_version) - mock_popen().communicate.return_value = ( + mock_script.return_value = ( "Server Version: Apache/2.3{0} Apache/2.4.7".format(os.linesep), "") self.assertRaises(errors.PluginError, self.config.get_version) - mock_popen.side_effect = OSError("Can't find program") + mock_script.side_effect = errors.SubprocessError("Can't find program") self.assertRaises(errors.PluginError, self.config.get_version) @mock.patch("letsencrypt_apache.configurator.subprocess.Popen") @@ -441,23 +435,13 @@ class TwoVhost80Test(util.ApacheTest): self.assertRaises(errors.MisconfigurationError, self.config.restart) - @mock.patch("letsencrypt_apache.configurator.subprocess.Popen") - def test_config_test(self, mock_popen): - mock_popen().communicate.return_value = ("a", "b") - mock_popen().returncode = 0 - + @mock.patch("letsencrypt.le_util.run_script") + def test_config_test(self, _): self.config.config_test() - @mock.patch("letsencrypt_apache.configurator.subprocess.Popen") - def test_config_test_bad_process(self, mock_popen): - mock_popen.side_effect = ValueError - - self.assertRaises(errors.PluginError, self.config.config_test) - - @mock.patch("letsencrypt_apache.configurator.subprocess.Popen") - def test_config_test_failure(self, mock_popen): - mock_popen().communicate.return_value = ("", "") - mock_popen().returncode = -1 + @mock.patch("letsencrypt.le_util.run_script") + def test_config_test_bad_process(self, mock_run_script): + mock_run_script.side_effect = errors.SubprocessError self.assertRaises(errors.MisconfigurationError, self.config.config_test) @@ -497,9 +481,11 @@ class TwoVhost80Test(util.ApacheTest): errors.PluginError, self.config.enhance, "letsencrypt.demo", "unknown_enhancement") - @mock.patch("letsencrypt_apache.parser." - "ApacheParser.update_runtime_variables") - def test_redirect_well_formed_http(self, _): + @mock.patch("letsencrypt.le_util.run_script") + @mock.patch("letsencrypt.le_util.exe_exists") + def test_redirect_well_formed_http(self, mock_exe, _): + self.config.parser.update_runtime_variables = mock.Mock() + mock_exe.return_value = True # This will create an ssl vhost for letsencrypt.demo self.config.enhance("letsencrypt.demo", "redirect") diff --git a/letsencrypt-apache/letsencrypt_apache/tests/dvsni_test.py b/letsencrypt-apache/letsencrypt_apache/tests/dvsni_test.py index ff13fef7b..329a5439b 100644 --- a/letsencrypt-apache/letsencrypt_apache/tests/dvsni_test.py +++ b/letsencrypt-apache/letsencrypt_apache/tests/dvsni_test.py @@ -36,10 +36,11 @@ class DvsniPerformTest(util.ApacheTest): resp = self.sni.perform() self.assertEqual(len(resp), 0) - @mock.patch("letsencrypt_apache.parser.subprocess.Popen") - def test_perform1(self, mock_popen): - mock_popen().communicate.return_value = ("Define: DUMP_RUN_CFG", "") - mock_popen().returncode = 0 + @mock.patch("letsencrypt.le_util.exe_exists") + @mock.patch("letsencrypt.le_util.run_script") + def test_perform1(self, _, mock_exists): + mock_exists.return_value = True + self.sni.configurator.parser.update_runtime_variables = mock.Mock() achall = self.achalls[0] self.sni.add_chall(achall) diff --git a/letsencrypt/errors.py b/letsencrypt/errors.py index 5cc45f000..82331fced 100644 --- a/letsencrypt/errors.py +++ b/letsencrypt/errors.py @@ -5,6 +5,10 @@ class Error(Exception): """Generic Let's Encrypt client error.""" +class SubprocessError(Error): + """Subprocess handling error.""" + + class AccountStorageError(Error): """Generic `.AccountStorage` error.""" diff --git a/letsencrypt/le_util.py b/letsencrypt/le_util.py index e525a333c..af8a56ef5 100644 --- a/letsencrypt/le_util.py +++ b/letsencrypt/le_util.py @@ -4,6 +4,7 @@ import errno import logging import os import re +import subprocess import stat from letsencrypt import errors @@ -17,6 +18,57 @@ Key = collections.namedtuple("Key", "file pem") CSR = collections.namedtuple("CSR", "file data form") +def run_script(params): + """Run the script with the given params. + + :param list params: List of parameters to pass to Popen + + """ + try: + proc = subprocess.Popen(params, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + except (OSError, ValueError): + msg = "Unable to run the command: %s" % " ".join(params) + logger.error(msg) + raise errors.SubprocessError(msg) + + stdout, stderr = proc.communicate() + + if proc.returncode != 0: + msg = "Error while running %s.\n%s\n%s" % ( + " ".join(params), stdout, stderr) + # Enter recovery routine... + logger.error(msg) + raise errors.SubprocessError(msg) + + return stdout, stderr + + +def exe_exists(exe): + """Determine whether path/name refers to an executable. + + :param str exe: Executable path or name + + :returns: If exe is a valid executable + :rtype: bool + + """ + def is_exe(path): + """Determine if path is an exe.""" + return os.path.isfile(path) and os.access(path, os.X_OK) + + path, _ = os.path.split(exe) + if path: + return is_exe(exe) + else: + for path in os.environ["PATH"].split(os.pathsep): + if is_exe(os.path.join(path, exe)): + return True + + return False + def make_or_verify_dir(directory, mode=0o755, uid=0): """Make sure directory exists with proper permissions. diff --git a/letsencrypt/reverter.py b/letsencrypt/reverter.py index a31281a5b..03d62ce13 100644 --- a/letsencrypt/reverter.py +++ b/letsencrypt/reverter.py @@ -1,4 +1,5 @@ """Reverter class saves configuration checkpoints and allows for recovery.""" +import csv import logging import os import shutil @@ -20,6 +21,8 @@ logger = logging.getLogger(__name__) class Reverter(object): """Reverter Class - save and revert configuration checkpoints. + .. note:: Consider moving everything over to CSV format. + :param config: Configuration. :type config: :class:`letsencrypt.interfaces.IConfig` @@ -101,6 +104,7 @@ class Reverter(object): if not backups: logger.info("The Let's Encrypt client has not saved any backups " "of your configuration") + return # Make sure there isn't anything unexpected in the backup folder # There should only be timestamped (float) directories @@ -204,7 +208,7 @@ class Reverter(object): notes_fd.write(save_notes) def _read_and_append(self, filepath): # pylint: disable=no-self-use - """Reads the file lines and returns a fd. + """Reads the file lines and returns a file obj. Read the file returning the lines, and a pointer to the end of the file. @@ -230,6 +234,10 @@ class Reverter(object): :raises errors.ReverterError: If unable to recover checkpoint """ + # Undo all commands + if os.path.isfile(os.path.join(cp_dir, "COMMANDS")): + self._run_undo_commands(os.path.join(cp_dir, "COMMANDS")) + # Revert all changed files if os.path.isfile(os.path.join(cp_dir, "FILEPATHS")): try: with open(os.path.join(cp_dir, "FILEPATHS")) as paths_fd: @@ -254,6 +262,17 @@ class Reverter(object): raise errors.ReverterError( "Unable to remove directory: %s" % cp_dir) + def _run_undo_commands(self, filepath): # pylint: disable=no-self-use + """Run all commands in a file.""" + with open(filepath, 'rb') as csvfile: + csvreader = csv.reader(csvfile) + for command in reversed(list(csvreader)): + try: + le_util.run_script(command) + except errors.SubprocessError: + logger.error( + "Unable to run undo command: %s", " ".join(command)) + def _check_tempfile_saves(self, save_files): """Verify save isn't overwriting any temporary files. @@ -306,13 +325,7 @@ class Reverter(object): raise errors.ReverterError( "Forgot to provide files to registration call") - if temporary: - cp_dir = self.config.temp_checkpoint_dir - else: - cp_dir = self.config.in_progress_dir - - le_util.make_or_verify_dir( - cp_dir, constants.CONFIG_DIRS_MODE, os.geteuid()) + cp_dir = self._get_cp_dir(temporary) # Append all new files (that aren't already registered) new_fd = None @@ -331,6 +344,53 @@ class Reverter(object): if new_fd is not None: new_fd.close() + def register_undo_command(self, temporary, command): + """Register a command to be run to undo actions taken. + + .. warning:: This function does not enforce order of operations in terms + of file modification vs. command registration. All undo commands + are run first before all normal files are reverted to their previous + state. If you need to maintain strict order, you may create + checkpoints before and after the the command registration. This + function may be improved in the future based on demand. + + :param bool temporary: Whether the command should be saved in the + IN_PROGRESS or TEMPORARY checkpoints. + :param command: Command to be run. + :type command: list of str + + """ + commands_fp = os.path.join(self._get_cp_dir(temporary), "COMMANDS") + command_file = None + try: + if os.path.isfile(commands_fp): + command_file = open(commands_fp, "ab") + else: + command_file = open(commands_fp, "wb") + + csvwriter = csv.writer(command_file) + csvwriter.writerow(command) + + except (IOError, OSError): + logger.error("Unable to register undo command") + raise errors.ReverterError( + "Unable to register undo command.") + finally: + if command_file is not None: + command_file.close() + + def _get_cp_dir(self, temporary): + """Return the proper reverter directory.""" + if temporary: + cp_dir = self.config.temp_checkpoint_dir + else: + cp_dir = self.config.in_progress_dir + + le_util.make_or_verify_dir( + cp_dir, constants.CONFIG_DIRS_MODE, os.geteuid()) + + return cp_dir + def recovery_routine(self): """Revert configuration to most recent finalized checkpoint. diff --git a/letsencrypt/tests/le_util_test.py b/letsencrypt/tests/le_util_test.py index 1ecc1ea16..6a6ad3a54 100644 --- a/letsencrypt/tests/le_util_test.py +++ b/letsencrypt/tests/le_util_test.py @@ -11,6 +11,67 @@ import mock from letsencrypt import errors +class RunScriptTest(unittest.TestCase): + """Tests for letsencrypt.le_util.run_script.""" + @classmethod + def _call(cls, params): + from letsencrypt.le_util import run_script + return run_script(params) + + @mock.patch("letsencrypt.le_util.subprocess.Popen") + def test_default(self, mock_popen): + """These will be changed soon enough with reload.""" + mock_popen().returncode = 0 + mock_popen().communicate.return_value = ("stdout", "stderr") + + out, err = self._call(["test"]) + self.assertEqual(out, "stdout") + self.assertEqual(err, "stderr") + + @mock.patch("letsencrypt.le_util.subprocess.Popen") + def test_bad_process(self, mock_popen): + mock_popen.side_effect = OSError + + self.assertRaises(errors.SubprocessError, self._call, ["test"]) + + @mock.patch("letsencrypt.le_util.subprocess.Popen") + def test_failure(self, mock_popen): + mock_popen().communicate.return_value = ("", "") + mock_popen().returncode = 1 + + self.assertRaises(errors.SubprocessError, self._call, ["test"]) + + +class ExeExistsTest(unittest.TestCase): + """Tests for letsencrypt.le_util.exe_exists.""" + + @classmethod + def _call(cls, exe): + from letsencrypt.le_util import exe_exists + return exe_exists(exe) + + @mock.patch("letsencrypt.le_util.os.path.isfile") + @mock.patch("letsencrypt.le_util.os.access") + def test_full_path(self, mock_access, mock_isfile): + mock_access.return_value = True + mock_isfile.return_value = True + self.assertTrue(self._call("/path/to/exe")) + + @mock.patch("letsencrypt.le_util.os.path.isfile") + @mock.patch("letsencrypt.le_util.os.access") + def test_on_path(self, mock_access, mock_isfile): + mock_access.return_value = True + mock_isfile.return_value = True + self.assertTrue(self._call("exe")) + + @mock.patch("letsencrypt.le_util.os.path.isfile") + @mock.patch("letsencrypt.le_util.os.access") + def test_not_found(self, mock_access, mock_isfile): + mock_access.return_value = False + mock_isfile.return_value = True + self.assertFalse(self._call("exe")) + + class MakeOrVerifyDirTest(unittest.TestCase): """Tests for letsencrypt.le_util.make_or_verify_dir. diff --git a/letsencrypt/tests/reverter_test.py b/letsencrypt/tests/reverter_test.py index da57bf8dc..d568d2aef 100644 --- a/letsencrypt/tests/reverter_test.py +++ b/letsencrypt/tests/reverter_test.py @@ -1,4 +1,6 @@ """Test letsencrypt.reverter.""" +import csv +import itertools import logging import os import shutil @@ -11,7 +13,7 @@ from letsencrypt import errors class ReverterCheckpointLocalTest(unittest.TestCase): - # pylint: disable=too-many-instance-attributes + # pylint: disable=too-many-instance-attributes, too-many-public-methods """Test the Reverter Class.""" def setUp(self): from letsencrypt.reverter import Reverter @@ -126,6 +128,42 @@ class ReverterCheckpointLocalTest(unittest.TestCase): errors.ReverterError, self.reverter.register_file_creation, "filepath") + def test_register_undo_command(self): + coms = [ + ["a2dismod", "ssl"], + ["a2dismod", "rewrite"], + ["cleanslate"] + ] + for com in coms: + self.reverter.register_undo_command(True, com) + + act_coms = get_undo_commands(self.config.temp_checkpoint_dir) + + for a_com, com in itertools.izip(act_coms, coms): + self.assertEqual(a_com, com) + + def test_bad_register_undo_command(self): + m_open = mock.mock_open() + with mock.patch("letsencrypt.reverter.open", m_open, create=True): + m_open.side_effect = OSError("bad open") + self.assertRaises( + errors.ReverterError, self.reverter.register_undo_command, + True, ["command"]) + + @mock.patch("letsencrypt.le_util.run_script") + def test_run_undo_commands(self, mock_run): + mock_run.side_effect = ["", errors.SubprocessError] + coms = [ + ["invalid_command"], + ["a2dismod", "ssl"], + ] + for com in coms: + self.reverter.register_undo_command(True, com) + + self.reverter.revert_temporary_config() + + self.assertEqual(mock_run.call_count, 2) + def test_recovery_routine_in_progress_failure(self): self.reverter.add_to_checkpoint(self.sets[0], "perm save") @@ -390,9 +428,9 @@ def setup_test_files(): dir2 = tempfile.mkdtemp("dir2") config1 = os.path.join(dir1, "config.txt") config2 = os.path.join(dir2, "config.txt") - with open(config1, 'w') as file_fd: + with open(config1, "w") as file_fd: file_fd.write("directive-dir1") - with open(config2, 'w') as file_fd: + with open(config2, "w") as file_fd: file_fd.write("directive-dir2") sets = [set([config1]), @@ -404,30 +442,35 @@ def setup_test_files(): def get_save_notes(dire): """Read save notes""" - return read_in(os.path.join(dire, 'CHANGES_SINCE')) + return read_in(os.path.join(dire, "CHANGES_SINCE")) def get_filepaths(dire): """Get Filepaths""" - return read_in(os.path.join(dire, 'FILEPATHS')) + return read_in(os.path.join(dire, "FILEPATHS")) def get_new_files(dire): """Get new files.""" - return read_in(os.path.join(dire, 'NEW_FILES')).splitlines() + return read_in(os.path.join(dire, "NEW_FILES")).splitlines() + + +def get_undo_commands(dire): + """Get new files.""" + return csv.reader(open(os.path.join(dire, "COMMANDS"))) def read_in(path): """Read in a file, return the str""" - with open(path, 'r') as file_fd: + with open(path, "r") as file_fd: return file_fd.read() def update_file(filename, string): """Update a file with a new value.""" - with open(filename, 'w') as file_fd: + with open(filename, "w") as file_fd: file_fd.write(string) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() # pragma: no cover