Make Repository a context manager, use decorators for wrapping withs

(Remote)Repository.close() is no longer a public API, but a private
one. It shall not be used from classes other than Repository
or its tests. The proper way is to use a context manager now. However,
for RPC/remote compatibility with Borg 1.0 it is kept unchanged.

Repositories are not opened by __init__ now anymore, it is done
by binding it to a context manager. (This SHOULD be compatible both ways
with remote, since opening the repo is handled by a RepositoryServer method)

Decorators @with_repository() and @with_archive simplify
context manager handling and remove unnecessary indentation.
This commit is contained in:
Marian Beermann 2016-03-23 00:41:15 +01:00
parent 77dfcbc31d
commit 7caec0187f
8 changed files with 312 additions and 256 deletions

View file

@ -40,18 +40,56 @@ UMASK_DEFAULT = 0o077
DASHES = '-' * 78
class ToggleAction(argparse.Action):
"""argparse action to handle "toggle" flags easily
def argument(args, str_or_bool):
    """Resolve a parameter that may be a literal bool or the name of an attribute on *args*.

    A str is treated as an attribute name and looked up on *args*;
    anything else (a bool) is returned unchanged.
    """
    return getattr(args, str_or_bool) if isinstance(str_or_bool, str) else str_or_bool
toggle flags are in the form of ``--foo``, ``--no-foo``.
the ``--no-foo`` argument still needs to be passed to the
``add_argument()`` call, but it simplifies the ``--no``
detection.
def with_repository(fake=False, create=False, lock=True, exclusive=False, manifest=True, cache=False):
"""
def __call__(self, parser, ns, values, option):
"""set the given flag to true unless ``--no`` is passed"""
setattr(ns, self.dest, not option.startswith('--no-'))
Method decorator for subcommand-handling methods: do_XYZ(self, args, repository, )
If a parameter (where allowed) is a str the attribute named of args is used instead.
:param fake: (str or bool) use None instead of repository, don't do anything else
:param create: create repository
:param lock: lock repository
:param exclusive: (str or bool) lock repository exclusively (for writing)
:param manifest: load manifest and key, pass them as keyword arguments
:param cache: open cache, pass it as keyword argument (implies manifest)
"""
def decorator(method):
@functools.wraps(method)
def wrapper(self, args, **kwargs):
location = args.location # note: 'location' must be always present in args
if argument(args, fake):
return method(self, args, repository=None, **kwargs)
elif location.proto == 'ssh':
repository = RemoteRepository(location, create=create, lock_wait=self.lock_wait, lock=lock, args=args)
else:
repository = Repository(location.path, create=create, exclusive=argument(args, exclusive),
lock_wait=self.lock_wait, lock=lock)
with repository:
if manifest or cache:
kwargs['manifest'], kwargs['key'] = Manifest.load(repository)
if cache:
with Cache(repository, kwargs['key'], kwargs['manifest'],
do_files=getattr(args, 'cache_files', False), lock_wait=self.lock_wait) as cache_:
return method(self, args, repository=repository, cache=cache_, **kwargs)
else:
return method(self, args, repository=repository, **kwargs)
return wrapper
return decorator
def with_archive(method):
    """Method decorator opening the archive named by ``args.location.archive``.

    Expects the wrapped do_XYZ method to receive ``repository``, ``key`` and
    ``manifest`` keyword arguments (as supplied by @with_repository) and passes
    the constructed Archive on as the ``archive`` keyword argument.
    """
    @functools.wraps(method)
    def wrapper(self, args, repository, key, manifest, **kwargs):
        # cache is optional: present only when @with_repository(cache=True) was used,
        # hence kwargs.get('cache') rather than a required parameter.
        archive = Archive(repository, key, manifest, args.location.archive,
                          numeric_owner=getattr(args, 'numeric_owner', False), cache=kwargs.get('cache'))
        return method(self, args, repository=repository, manifest=manifest, key=key, archive=archive, **kwargs)
    return wrapper
class Archiver:
@ -60,14 +98,6 @@ class Archiver:
self.exit_code = EXIT_SUCCESS
self.lock_wait = lock_wait
def open_repository(self, args, create=False, exclusive=False, lock=True):
location = args.location # note: 'location' must be always present in args
if location.proto == 'ssh':
repository = RemoteRepository(location, create=create, lock_wait=self.lock_wait, lock=lock, args=args)
else:
repository = Repository(location.path, create=create, exclusive=exclusive, lock_wait=self.lock_wait, lock=lock)
return repository
def print_error(self, msg, *args):
msg = args and msg % args or msg
self.exit_code = EXIT_ERROR
@ -126,10 +156,10 @@ class Archiver:
"""
return RepositoryServer(restrict_to_paths=args.restrict_to_paths).serve()
def do_init(self, args):
@with_repository(create=True, exclusive=True, manifest=False)
def do_init(self, args, repository):
"""Initialize an empty repository"""
logger.info('Initializing repository at "%s"' % args.location.canonical_path())
repository = self.open_repository(args, create=True, exclusive=True)
key = key_creator(repository, args)
manifest = Manifest(key, repository)
manifest.key = key
@ -139,9 +169,9 @@ class Archiver:
pass
return self.exit_code
def do_check(self, args):
@with_repository(exclusive='repair', manifest=False)
def do_check(self, args, repository):
"""Check repository consistency"""
repository = self.open_repository(args, exclusive=args.repair)
if args.repair:
msg = ("'check --repair' is an experimental feature that might result in data loss." +
"\n" +
@ -158,16 +188,15 @@ class Archiver:
return EXIT_WARNING
return EXIT_SUCCESS
def do_change_passphrase(self, args):
@with_repository()
def do_change_passphrase(self, args, repository, manifest, key):
"""Change repository key file passphrase"""
repository = self.open_repository(args)
manifest, key = Manifest.load(repository)
key.change_passphrase()
return EXIT_SUCCESS
def do_migrate_to_repokey(self, args):
@with_repository(manifest=False)
def do_migrate_to_repokey(self, args, repository):
"""Migrate passphrase -> repokey"""
repository = self.open_repository(args)
manifest_data = repository.get(Manifest.MANIFEST_ID)
key_old = PassphraseKey.detect(repository, manifest_data)
key_new = RepoKey(repository)
@ -180,7 +209,8 @@ class Archiver:
key_new.change_passphrase() # option to change key protection passphrase, save
return EXIT_SUCCESS
def do_create(self, args):
@with_repository(fake='dry_run')
def do_create(self, args, repository, manifest=None, key=None):
"""Create new archive"""
matcher = PatternMatcher(fallback=True)
if args.excludes:
@ -245,8 +275,6 @@ class Archiver:
dry_run = args.dry_run
t0 = datetime.utcnow()
if not dry_run:
repository = self.open_repository(args, exclusive=True)
manifest, key = Manifest.load(repository)
compr_args = dict(buffer=COMPR_BUFFER)
compr_args.update(args.compression)
key.compressor = Compressor(**compr_args)
@ -333,17 +361,15 @@ class Archiver:
status = '-' # dry run, item was not backed up
self.print_file_status(status, path)
def do_extract(self, args):
@with_repository()
@with_archive
def do_extract(self, args, repository, manifest, key, archive):
"""Extract archive contents"""
# be restrictive when restoring files, restore permissions later
if sys.getfilesystemencoding() == 'ascii':
logger.warning('Warning: File system encoding is "ascii", extracting non-ascii filenames will not be supported.')
if sys.platform.startswith(('linux', 'freebsd', 'netbsd', 'openbsd', 'darwin', )):
logger.warning('Hint: You likely need to fix your locale setup. E.g. install locales and use: LANG=en_US.UTF-8')
repository = self.open_repository(args)
manifest, key = Manifest.load(repository)
archive = Archive(repository, key, manifest, args.location.archive,
numeric_owner=args.numeric_owner)
matcher, include_patterns = self.build_matcher(args.excludes, args.paths)
@ -397,7 +423,9 @@ class Archiver:
self.print_warning("Include pattern '%s' never matched.", pattern)
return self.exit_code
def do_diff(self, args):
@with_repository()
@with_archive
def do_diff(self, args, repository, manifest, key, archive):
"""Diff contents of two archives"""
def format_bytes(count):
if count is None:
@ -493,9 +521,7 @@ class Archiver:
b'chunks': [],
}, deleted=True)
repository = self.open_repository(args)
manifest, key = Manifest.load(repository)
archive1 = Archive(repository, key, manifest, args.location.archive)
archive1 = archive
archive2 = Archive(repository, key, manifest, args.archive2)
can_compare_chunk_ids = archive1.metadata.get(b'chunker_params', False) == archive2.metadata.get(
@ -514,55 +540,52 @@ class Archiver:
self.print_warning("Include pattern '%s' never matched.", pattern)
return self.exit_code
def do_rename(self, args):
@with_repository(exclusive=True, cache=True)
@with_archive
def do_rename(self, args, repository, manifest, key, cache, archive):
"""Rename an existing archive"""
repository = self.open_repository(args, exclusive=True)
manifest, key = Manifest.load(repository)
with Cache(repository, key, manifest, lock_wait=self.lock_wait) as cache:
archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
archive.rename(args.name)
manifest.write()
repository.commit()
cache.commit()
archive.rename(args.name)
manifest.write()
repository.commit()
cache.commit()
return self.exit_code
def do_delete(self, args):
@with_repository(exclusive=True, cache=True)
def do_delete(self, args, repository, manifest, key, cache):
"""Delete an existing repository or archive"""
repository = self.open_repository(args, exclusive=True)
manifest, key = Manifest.load(repository)
with Cache(repository, key, manifest, do_files=args.cache_files, lock_wait=self.lock_wait) as cache:
if args.location.archive:
archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
stats = Statistics()
archive.delete(stats, progress=args.progress)
manifest.write()
repository.commit(save_space=args.save_space)
cache.commit()
logger.info("Archive deleted.")
if args.stats:
log_multi(DASHES,
stats.summary.format(label='Deleted data:', stats=stats),
str(cache),
DASHES)
else:
if not args.cache_only:
msg = []
msg.append("You requested to completely DELETE the repository *including* all archives it contains:")
for archive_info in manifest.list_archive_infos(sort_by='ts'):
msg.append(format_archive(archive_info))
msg.append("Type 'YES' if you understand this and want to continue: ")
msg = '\n'.join(msg)
if not yes(msg, false_msg="Aborting.", truish=('YES', ),
env_var_override='BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'):
self.exit_code = EXIT_ERROR
return self.exit_code
repository.destroy()
logger.info("Repository deleted.")
cache.destroy()
logger.info("Cache deleted.")
if args.location.archive:
archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
stats = Statistics()
archive.delete(stats, progress=args.progress)
manifest.write()
repository.commit(save_space=args.save_space)
cache.commit()
logger.info("Archive deleted.")
if args.stats:
log_multi(DASHES,
stats.summary.format(label='Deleted data:', stats=stats),
str(cache),
DASHES)
else:
if not args.cache_only:
msg = []
msg.append("You requested to completely DELETE the repository *including* all archives it contains:")
for archive_info in manifest.list_archive_infos(sort_by='ts'):
msg.append(format_archive(archive_info))
msg.append("Type 'YES' if you understand this and want to continue: ")
msg = '\n'.join(msg)
if not yes(msg, false_msg="Aborting.", truish=('YES', ),
env_var_override='BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'):
self.exit_code = EXIT_ERROR
return self.exit_code
repository.destroy()
logger.info("Repository deleted.")
cache.destroy()
logger.info("Cache deleted.")
return self.exit_code
def do_mount(self, args):
@with_repository()
def do_mount(self, args, repository, manifest, key):
"""Mount archive or an entire repository as a FUSE fileystem"""
try:
from .fuse import FuseOperations
@ -574,29 +597,23 @@ class Archiver:
self.print_error('%s: Mountpoint must be a writable directory' % args.mountpoint)
return self.exit_code
repository = self.open_repository(args)
try:
with cache_if_remote(repository) as cached_repo:
manifest, key = Manifest.load(repository)
if args.location.archive:
archive = Archive(repository, key, manifest, args.location.archive)
else:
archive = None
operations = FuseOperations(key, repository, manifest, archive, cached_repo)
logger.info("Mounting filesystem")
try:
operations.mount(args.mountpoint, args.options, args.foreground)
except RuntimeError:
# Relevant error message already printed to stderr by fuse
self.exit_code = EXIT_ERROR
finally:
repository.close()
with cache_if_remote(repository) as cached_repo:
if args.location.archive:
archive = Archive(repository, key, manifest, args.location.archive)
else:
archive = None
operations = FuseOperations(key, repository, manifest, archive, cached_repo)
logger.info("Mounting filesystem")
try:
operations.mount(args.mountpoint, args.options, args.foreground)
except RuntimeError:
# Relevant error message already printed to stderr by fuse
self.exit_code = EXIT_ERROR
return self.exit_code
def do_list(self, args):
@with_repository()
def do_list(self, args, repository, manifest, key):
"""List archive or repository contents"""
repository = self.open_repository(args)
manifest, key = Manifest.load(repository)
if args.location.archive:
matcher, _ = self.build_matcher(args.excludes, args.paths)
@ -620,7 +637,6 @@ class Archiver:
write = sys.stdout.buffer.write
for item in archive.iter_items(lambda item: matcher.match(item[b'path'])):
write(formatter.format_item(item).encode('utf-8', errors='surrogateescape'))
repository.close()
else:
for archive_info in manifest.list_archive_infos(sort_by='ts'):
if args.prefix and not archive_info.name.startswith(args.prefix):
@ -631,30 +647,27 @@ class Archiver:
print(format_archive(archive_info))
return self.exit_code
def do_info(self, args):
@with_repository(cache=True)
@with_archive
def do_info(self, args, repository, manifest, key, archive, cache):
"""Show archive details such as disk space used"""
repository = self.open_repository(args)
manifest, key = Manifest.load(repository)
with Cache(repository, key, manifest, do_files=args.cache_files, lock_wait=self.lock_wait) as cache:
archive = Archive(repository, key, manifest, args.location.archive, cache=cache)
stats = archive.calc_stats(cache)
print('Name:', archive.name)
print('Fingerprint: %s' % hexlify(archive.id).decode('ascii'))
print('Hostname:', archive.metadata[b'hostname'])
print('Username:', archive.metadata[b'username'])
print('Time (start): %s' % format_time(to_localtime(archive.ts)))
print('Time (end): %s' % format_time(to_localtime(archive.ts_end)))
print('Command line:', remove_surrogates(' '.join(archive.metadata[b'cmdline'])))
print('Number of files: %d' % stats.nfiles)
print()
print(str(stats))
print(str(cache))
stats = archive.calc_stats(cache)
print('Name:', archive.name)
print('Fingerprint: %s' % hexlify(archive.id).decode('ascii'))
print('Hostname:', archive.metadata[b'hostname'])
print('Username:', archive.metadata[b'username'])
print('Time (start): %s' % format_time(to_localtime(archive.ts)))
print('Time (end): %s' % format_time(to_localtime(archive.ts_end)))
print('Command line:', remove_surrogates(' '.join(archive.metadata[b'cmdline'])))
print('Number of files: %d' % stats.nfiles)
print()
print(str(stats))
print(str(cache))
return self.exit_code
def do_prune(self, args):
@with_repository()
def do_prune(self, args, repository, manifest, key):
"""Prune repository archives according to specified rules"""
repository = self.open_repository(args, exclusive=True)
manifest, key = Manifest.load(repository)
archives = manifest.list_archive_infos(sort_by='ts', reverse=True) # just a ArchiveInfo list
if args.hourly + args.daily + args.weekly + args.monthly + args.yearly == 0 and args.within is None:
self.print_error('At least one of the "keep-within", "keep-hourly", "keep-daily", "keep-weekly", '
@ -719,10 +732,9 @@ class Archiver:
print("warning: %s" % e)
return self.exit_code
def do_debug_dump_archive_items(self, args):
@with_repository()
def do_debug_dump_archive_items(self, args, repository, manifest, key):
"""dump (decrypted, decompressed) archive items metadata (not: data)"""
repository = self.open_repository(args)
manifest, key = Manifest.load(repository)
archive = Archive(repository, key, manifest, args.location.archive)
for i, item_id in enumerate(archive.metadata[b'items']):
data = key.decrypt(item_id, repository.get(item_id))
@ -733,10 +745,9 @@ class Archiver:
print('Done.')
return EXIT_SUCCESS
def do_debug_get_obj(self, args):
@with_repository(manifest=False)
def do_debug_get_obj(self, args, repository):
"""get object contents from the repository and write it into file"""
repository = self.open_repository(args)
manifest, key = Manifest.load(repository)
hex_id = args.id
try:
id = unhexlify(hex_id)
@ -753,10 +764,9 @@ class Archiver:
print("object %s fetched." % hex_id)
return EXIT_SUCCESS
def do_debug_put_obj(self, args):
@with_repository(manifest=False)
def do_debug_put_obj(self, args, repository):
"""put file(s) contents into the repository"""
repository = self.open_repository(args)
manifest, key = Manifest.load(repository)
for path in args.paths:
with open(path, "rb") as f:
data = f.read()
@ -766,10 +776,9 @@ class Archiver:
repository.commit()
return EXIT_SUCCESS
def do_debug_delete_obj(self, args):
@with_repository(manifest=False)
def do_debug_delete_obj(self, args, repository):
"""delete the objects with the given IDs from the repo"""
repository = self.open_repository(args)
manifest, key = Manifest.load(repository)
modified = False
for hex_id in args.ids:
try:
@ -788,14 +797,11 @@ class Archiver:
print('Done.')
return EXIT_SUCCESS
def do_break_lock(self, args):
@with_repository(lock=False, manifest=False)
def do_break_lock(self, args, repository):
"""Break the repository lock (e.g. in case it was left by a dead borg."""
repository = self.open_repository(args, lock=False)
try:
repository.break_lock()
Cache.break_lock(repository)
finally:
repository.close()
repository.break_lock()
Cache.break_lock(repository)
return self.exit_code
helptext = {}

View file

@ -77,6 +77,7 @@ class RepositoryServer: # pragma: no cover
if r:
data = os.read(stdin_fd, BUFSIZE)
if not data:
self.repository.close()
return
unpacker.feed(data)
for unpacked in unpacker:
@ -100,6 +101,7 @@ class RepositoryServer: # pragma: no cover
else:
os.write(stdout_fd, msgpack.packb((1, msgid, None, res)))
if es:
self.repository.close()
return
def negotiate(self, versions):
@ -117,6 +119,7 @@ class RepositoryServer: # pragma: no cover
else:
raise PathNotAllowed(path)
self.repository = Repository(path, create, lock_wait=lock_wait, lock=lock)
self.repository.__enter__() # clean exit handled by serve() method
return self.repository.id
@ -164,11 +167,21 @@ class RemoteRepository:
self.id = self.call('open', location.path, create, lock_wait, lock)
def __del__(self):
self.close()
if self.p:
self.close()
assert False, "cleanup happened in Repository.__del__"
def __repr__(self):
return '<%s %s>' % (self.__class__.__name__, self.location.canonical_path())
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
self.rollback()
self.close()
def borg_cmd(self, args, testing):
"""return a borg serve command line"""
# give some args/options to "borg serve" process as they were given to us
@ -392,6 +405,7 @@ class RepositoryCache(RepositoryNoCache):
super().__init__(repository)
tmppath = tempfile.mkdtemp(prefix='borg-tmp')
self.caching_repo = Repository(tmppath, create=True, exclusive=True)
self.caching_repo.__enter__() # handled by context manager in base class
def close(self):
if self.caching_repo is not None:

View file

@ -59,16 +59,31 @@ class Repository:
self.lock = None
self.index = None
self._active_txn = False
if create:
self.create(self.path)
self.open(self.path, exclusive, lock_wait=lock_wait, lock=lock)
self.lock_wait = lock_wait
self.do_lock = lock
self.do_create = create
self.exclusive = exclusive
def __del__(self):
self.close()
if self.lock:
self.close()
assert False, "cleanup happened in Repository.__del__"
def __repr__(self):
return '<%s %s>' % (self.__class__.__name__, self.path)
def __enter__(self):
if self.do_create:
self.do_create = False
self.create(self.path)
self.open(self.path, self.exclusive, lock_wait=self.lock_wait, lock=self.do_lock)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
self.rollback()
self.close()
def create(self, path):
"""Create a new empty repository at `path`
"""

View file

@ -8,6 +8,7 @@ import sysconfig
import time
import unittest
from ..xattr import get_all
from ..logger import setup_logging
try:
import llfuse
@ -30,6 +31,9 @@ else:
if sys.platform.startswith('netbsd'):
st_mtime_ns_round = -4 # only >1 microsecond resolution here?
# Ensure that the loggers exist for all tests
setup_logging()
class BaseTestCase(unittest.TestCase):
"""

View file

@ -367,7 +367,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
assert sto.st_atime_ns == atime * 1e9
def _extract_repository_id(self, path):
return Repository(self.repository_path).id
with Repository(self.repository_path) as repository:
return repository.id
def _set_repository_id(self, path, id):
config = ConfigParser(interpolation=None)
@ -375,7 +376,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
config.set('repository', 'id', hexlify(id).decode('ascii'))
with open(os.path.join(path, 'config'), 'w') as fd:
config.write(fd)
return Repository(self.repository_path).id
with Repository(self.repository_path) as repository:
return repository.id
def test_sparse_file(self):
# no sparse file support on Mac OS X
@ -745,8 +747,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
self.cmd('extract', '--dry-run', self.repository_location + '::test.3')
self.cmd('extract', '--dry-run', self.repository_location + '::test.4')
# Make sure both archives have been renamed
repository = Repository(self.repository_path)
manifest, key = Manifest.load(repository)
with Repository(self.repository_path) as repository:
manifest, key = Manifest.load(repository)
self.assert_equal(len(manifest.archives), 2)
self.assert_in('test.3', manifest.archives)
self.assert_in('test.4', manifest.archives)
@ -763,8 +765,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
self.cmd('extract', '--dry-run', self.repository_location + '::test.2')
self.cmd('delete', '--stats', self.repository_location + '::test.2')
# Make sure all data except the manifest has been deleted
repository = Repository(self.repository_path)
self.assert_equal(len(repository), 1)
with Repository(self.repository_path) as repository:
self.assert_equal(len(repository), 1)
def test_delete_repo(self):
self.create_regular_file('file1', size=1024 * 80)
@ -772,6 +774,11 @@ class ArchiverTestCase(ArchiverTestCaseBase):
self.cmd('init', self.repository_location)
self.cmd('create', self.repository_location + '::test', 'input')
self.cmd('create', self.repository_location + '::test.2', 'input')
os.environ['BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'] = 'no'
self.cmd('delete', self.repository_location, exit_code=2)
self.archiver.exit_code = 0
assert os.path.exists(self.repository_path)
os.environ['BORG_DELETE_I_KNOW_WHAT_I_AM_DOING'] = 'YES'
self.cmd('delete', self.repository_location)
# Make sure the repo is gone
self.assertFalse(os.path.exists(self.repository_path))
@ -810,8 +817,8 @@ class ArchiverTestCase(ArchiverTestCaseBase):
self.cmd('init', self.repository_location)
self.cmd('create', '--dry-run', self.repository_location + '::test', 'input')
# Make sure no archive has been created
repository = Repository(self.repository_path)
manifest, key = Manifest.load(repository)
with Repository(self.repository_path) as repository:
manifest, key = Manifest.load(repository)
self.assert_equal(len(manifest.archives), 0)
def test_progress(self):
@ -1045,17 +1052,17 @@ class ArchiverTestCase(ArchiverTestCaseBase):
used = set() # counter values already used
def verify_uniqueness():
repository = Repository(self.repository_path)
for key, _ in repository.open_index(repository.get_transaction_id()).iteritems():
data = repository.get(key)
hash = sha256(data).digest()
if hash not in seen:
seen.add(hash)
num_blocks = num_aes_blocks(len(data) - 41)
nonce = bytes_to_long(data[33:41])
for counter in range(nonce, nonce + num_blocks):
self.assert_not_in(counter, used)
used.add(counter)
with Repository(self.repository_path) as repository:
for key, _ in repository.open_index(repository.get_transaction_id()).iteritems():
data = repository.get(key)
hash = sha256(data).digest()
if hash not in seen:
seen.add(hash)
num_blocks = num_aes_blocks(len(data) - 41)
nonce = bytes_to_long(data[33:41])
for counter in range(nonce, nonce + num_blocks):
self.assert_not_in(counter, used)
used.add(counter)
self.create_test_files()
os.environ['BORG_PASSPHRASE'] = 'passphrase'
@ -1122,8 +1129,9 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
def open_archive(self, name):
repository = Repository(self.repository_path)
manifest, key = Manifest.load(repository)
archive = Archive(repository, key, manifest, name)
with repository:
manifest, key = Manifest.load(repository)
archive = Archive(repository, key, manifest, name)
return archive, repository
def test_check_usage(self):
@ -1141,35 +1149,39 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
def test_missing_file_chunk(self):
archive, repository = self.open_archive('archive1')
for item in archive.iter_items():
if item[b'path'].endswith('testsuite/archiver.py'):
repository.delete(item[b'chunks'][-1][0])
break
repository.commit()
with repository:
for item in archive.iter_items():
if item[b'path'].endswith('testsuite/archiver.py'):
repository.delete(item[b'chunks'][-1][0])
break
repository.commit()
self.cmd('check', self.repository_location, exit_code=1)
self.cmd('check', '--repair', self.repository_location, exit_code=0)
self.cmd('check', self.repository_location, exit_code=0)
def test_missing_archive_item_chunk(self):
archive, repository = self.open_archive('archive1')
repository.delete(archive.metadata[b'items'][-5])
repository.commit()
with repository:
repository.delete(archive.metadata[b'items'][-5])
repository.commit()
self.cmd('check', self.repository_location, exit_code=1)
self.cmd('check', '--repair', self.repository_location, exit_code=0)
self.cmd('check', self.repository_location, exit_code=0)
def test_missing_archive_metadata(self):
archive, repository = self.open_archive('archive1')
repository.delete(archive.id)
repository.commit()
with repository:
repository.delete(archive.id)
repository.commit()
self.cmd('check', self.repository_location, exit_code=1)
self.cmd('check', '--repair', self.repository_location, exit_code=0)
self.cmd('check', self.repository_location, exit_code=0)
def test_missing_manifest(self):
archive, repository = self.open_archive('archive1')
repository.delete(Manifest.MANIFEST_ID)
repository.commit()
with repository:
repository.delete(Manifest.MANIFEST_ID)
repository.commit()
self.cmd('check', self.repository_location, exit_code=1)
output = self.cmd('check', '-v', '--repair', self.repository_location, exit_code=0)
self.assert_in('archive1', output)
@ -1178,10 +1190,9 @@ class ArchiverCheckTestCase(ArchiverTestCaseBase):
def test_extra_chunks(self):
self.cmd('check', self.repository_location, exit_code=0)
repository = Repository(self.repository_location)
repository.put(b'01234567890123456789012345678901', b'xxxx')
repository.commit()
repository.close()
with Repository(self.repository_location) as repository:
repository.put(b'01234567890123456789012345678901', b'xxxx')
repository.commit()
self.cmd('check', self.repository_location, exit_code=1)
self.cmd('check', self.repository_location, exit_code=1)
self.cmd('check', '--repair', self.repository_location, exit_code=0)

View file

@ -21,6 +21,7 @@ class RepositoryTestCaseBase(BaseTestCase):
def setUp(self):
self.tmppath = tempfile.mkdtemp()
self.repository = self.open(create=True)
self.repository.__enter__()
def tearDown(self):
self.repository.close()
@ -43,13 +44,12 @@ class RepositoryTestCase(RepositoryTestCaseBase):
self.assert_raises(Repository.ObjectNotFound, lambda: self.repository.get(key50))
self.repository.commit()
self.repository.close()
repository2 = self.open()
self.assert_raises(Repository.ObjectNotFound, lambda: repository2.get(key50))
for x in range(100):
if x == 50:
continue
self.assert_equal(repository2.get(('%-32d' % x).encode('ascii')), b'SOMEDATA')
repository2.close()
with self.open() as repository2:
self.assert_raises(Repository.ObjectNotFound, lambda: repository2.get(key50))
for x in range(100):
if x == 50:
continue
self.assert_equal(repository2.get(('%-32d' % x).encode('ascii')), b'SOMEDATA')
def test2(self):
"""Test multiple sequential transactions
@ -100,13 +100,14 @@ class RepositoryTestCase(RepositoryTestCaseBase):
self.repository.close()
# replace
self.repository = self.open()
self.repository.put(b'00000000000000000000000000000000', b'bar')
self.repository.commit()
self.repository.close()
with self.repository:
self.repository.put(b'00000000000000000000000000000000', b'bar')
self.repository.commit()
# delete
self.repository = self.open()
self.repository.delete(b'00000000000000000000000000000000')
self.repository.commit()
with self.repository:
self.repository.delete(b'00000000000000000000000000000000')
self.repository.commit()
def test_list(self):
for x in range(100):
@ -139,8 +140,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
if name.startswith('index.'):
os.unlink(os.path.join(self.repository.path, name))
self.reopen()
self.assert_equal(len(self.repository), 3)
self.assert_equal(self.repository.check(), True)
with self.repository:
self.assert_equal(len(self.repository), 3)
self.assert_equal(self.repository.check(), True)
def test_crash_before_compact_segments(self):
self.add_keys()
@ -150,8 +152,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
except TypeError:
pass
self.reopen()
self.assert_equal(len(self.repository), 3)
self.assert_equal(self.repository.check(), True)
with self.repository:
self.assert_equal(len(self.repository), 3)
self.assert_equal(self.repository.check(), True)
def test_replay_of_readonly_repository(self):
self.add_keys()
@ -160,8 +163,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
os.unlink(os.path.join(self.repository.path, name))
with patch.object(UpgradableLock, 'upgrade', side_effect=LockFailed) as upgrade:
self.reopen()
self.assert_raises(LockFailed, lambda: len(self.repository))
upgrade.assert_called_once_with()
with self.repository:
self.assert_raises(LockFailed, lambda: len(self.repository))
upgrade.assert_called_once_with()
def test_crash_before_write_index(self):
self.add_keys()
@ -171,8 +175,9 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
except TypeError:
pass
self.reopen()
self.assert_equal(len(self.repository), 3)
self.assert_equal(self.repository.check(), True)
with self.repository:
self.assert_equal(len(self.repository), 3)
self.assert_equal(self.repository.check(), True)
def test_crash_before_deleting_compacted_segments(self):
self.add_keys()
@ -182,9 +187,10 @@ class RepositoryCommitTestCase(RepositoryTestCaseBase):
except TypeError:
pass
self.reopen()
self.assert_equal(len(self.repository), 3)
self.assert_equal(self.repository.check(), True)
self.assert_equal(len(self.repository), 3)
with self.repository:
self.assert_equal(len(self.repository), 3)
self.assert_equal(self.repository.check(), True)
self.assert_equal(len(self.repository), 3)
class RepositoryCheckTestCase(RepositoryTestCaseBase):
@ -313,8 +319,9 @@ class RepositoryCheckTestCase(RepositoryTestCaseBase):
self.repository.commit()
compact.assert_called_once_with(save_space=False)
self.reopen()
self.check(repair=True)
self.assert_equal(self.repository.get(bytes(32)), b'data2')
with self.repository:
self.check(repair=True)
self.assert_equal(self.repository.get(bytes(32)), b'data2')
class RemoteRepositoryTestCase(RepositoryTestCase):

View file

@ -23,11 +23,9 @@ def repo_valid(path):
:param path: the path to the repository
:returns: if borg can check the repository
"""
repository = Repository(str(path), create=False)
# can't check raises() because check() handles the error
state = repository.check()
repository.close()
return state
with Repository(str(path), create=False) as repository:
# can't check raises() because check() handles the error
return repository.check()
def key_valid(path):
@ -79,11 +77,11 @@ def test_convert_segments(tmpdir, attic_repo, inplace):
"""
# check should fail because of magic number
assert not repo_valid(tmpdir)
repo = AtticRepositoryUpgrader(str(tmpdir), create=False)
segments = [filename for i, filename in repo.io.segment_iterator()]
repo.close()
repo.convert_segments(segments, dryrun=False, inplace=inplace)
repo.convert_cache(dryrun=False)
repository = AtticRepositoryUpgrader(str(tmpdir), create=False)
with repository:
segments = [filename for i, filename in repository.io.segment_iterator()]
repository.convert_segments(segments, dryrun=False, inplace=inplace)
repository.convert_cache(dryrun=False)
assert repo_valid(tmpdir)
@ -138,9 +136,9 @@ def test_keys(tmpdir, attic_repo, attic_key_file):
define above)
:param attic_key_file: an attic.key.KeyfileKey (fixture created above)
"""
repository = AtticRepositoryUpgrader(str(tmpdir), create=False)
keyfile = AtticKeyfileKey.find_key_file(repository)
AtticRepositoryUpgrader.convert_keyfiles(keyfile, dryrun=False)
with AtticRepositoryUpgrader(str(tmpdir), create=False) as repository:
keyfile = AtticKeyfileKey.find_key_file(repository)
AtticRepositoryUpgrader.convert_keyfiles(keyfile, dryrun=False)
assert key_valid(attic_key_file.path)
@ -167,19 +165,19 @@ def test_convert_all(tmpdir, attic_repo, attic_key_file, inplace):
return stat_segment(path).st_ino
orig_inode = first_inode(attic_repo.path)
repo = AtticRepositoryUpgrader(str(tmpdir), create=False)
# replicate command dispatch, partly
os.umask(UMASK_DEFAULT)
backup = repo.upgrade(dryrun=False, inplace=inplace)
if inplace:
assert backup is None
assert first_inode(repo.path) == orig_inode
else:
assert backup
assert first_inode(repo.path) != first_inode(backup)
    # I have seen cases where the copied tree has world-readable
# permissions, which is wrong
assert stat_segment(backup).st_mode & UMASK_DEFAULT == 0
with AtticRepositoryUpgrader(str(tmpdir), create=False) as repository:
# replicate command dispatch, partly
os.umask(UMASK_DEFAULT)
backup = repository.upgrade(dryrun=False, inplace=inplace)
if inplace:
assert backup is None
assert first_inode(repository.path) == orig_inode
else:
assert backup
assert first_inode(repository.path) != first_inode(backup)
        # I have seen cases where the copied tree has world-readable
# permissions, which is wrong
assert stat_segment(backup).st_mode & UMASK_DEFAULT == 0
assert key_valid(attic_key_file.path)
assert repo_valid(tmpdir)

View file

@ -30,23 +30,23 @@ class AtticRepositoryUpgrader(Repository):
we nevertheless do the order in reverse, as we prefer to do
the fast stuff first, to improve interactivity.
"""
backup = None
if not inplace:
backup = '{}.upgrade-{:%Y-%m-%d-%H:%M:%S}'.format(self.path, datetime.datetime.now())
logger.info('making a hardlink copy in %s', backup)
if not dryrun:
shutil.copytree(self.path, backup, copy_function=os.link)
logger.info("opening attic repository with borg and converting")
# now lock the repo, after we have made the copy
self.lock = UpgradableLock(os.path.join(self.path, 'lock'), exclusive=True, timeout=1.0).acquire()
segments = [filename for i, filename in self.io.segment_iterator()]
try:
keyfile = self.find_attic_keyfile()
except KeyfileNotFoundError:
logger.warning("no key file found for repository")
else:
self.convert_keyfiles(keyfile, dryrun)
self.close()
with self:
backup = None
if not inplace:
backup = '{}.upgrade-{:%Y-%m-%d-%H:%M:%S}'.format(self.path, datetime.datetime.now())
logger.info('making a hardlink copy in %s', backup)
if not dryrun:
shutil.copytree(self.path, backup, copy_function=os.link)
logger.info("opening attic repository with borg and converting")
# now lock the repo, after we have made the copy
self.lock = UpgradableLock(os.path.join(self.path, 'lock'), exclusive=True, timeout=1.0).acquire()
segments = [filename for i, filename in self.io.segment_iterator()]
try:
keyfile = self.find_attic_keyfile()
except KeyfileNotFoundError:
logger.warning("no key file found for repository")
else:
self.convert_keyfiles(keyfile, dryrun)
# partial open: just hold on to the lock
self.lock = UpgradableLock(os.path.join(self.path, 'lock'),
exclusive=True).acquire()
@ -282,12 +282,13 @@ class BorgRepositoryUpgrader(Repository):
"""convert an old borg repository to a current borg repository
"""
logger.info("converting borg 0.xx to borg current")
try:
keyfile = self.find_borg0xx_keyfile()
except KeyfileNotFoundError:
logger.warning("no key file found for repository")
else:
self.move_keyfiles(keyfile, dryrun)
with self:
try:
keyfile = self.find_borg0xx_keyfile()
except KeyfileNotFoundError:
logger.warning("no key file found for repository")
else:
self.move_keyfiles(keyfile, dryrun)
def find_borg0xx_keyfile(self):
return Borg0xxKeyfileKey.find_key_file(self)