diff --git a/CHANGES b/CHANGES index bf8d4c5e7..14b576199 100644 --- a/CHANGES +++ b/CHANGES @@ -8,6 +8,7 @@ Version 0.9 (feature release, released on X) +- Remote repository speed and reliability improvements. - Fix sorting of segment names to ignore NFS left over files. (#17) - Fix incorrect display of time (#13) - Improved error handling / reporting. (#12) diff --git a/attic/archive.py b/attic/archive.py index fac656cc1..fe2373866 100644 --- a/attic/archive.py +++ b/attic/archive.py @@ -23,54 +23,64 @@ has_mtime_ns = sys.version >= '3.3' has_lchmod = hasattr(os, 'lchmod') -class ItemIter(object): +class DownloadPipeline: - def __init__(self, unpacker, filter): - self.unpacker = iter(unpacker) - self.filter = filter - self.stack = [] - self.peeks = 0 - self._peek_iter = None + def __init__(self, repository, key): + self.repository = repository + self.key = key - def __iter__(self): - return self + def unpack_many(self, ids, filter=None): + unpacker = msgpack.Unpacker(use_list=False) + for data in self.fetch_many(ids): + unpacker.feed(data) + items = [decode_dict(item, (b'path', b'source', b'user', b'group')) for item in unpacker] + if filter: + items = [item for item in items if filter(item)] + for item in items: + if b'chunks' in item: + self.repository.preload([c[0] for c in item[b'chunks']]) + for item in items: + yield item - def __next__(self): - if self.stack: - item = self.stack.pop(0) - else: - self._peek_iter = None - item = self.get_next() - self.peeks = max(0, self.peeks - len(item.get(b'chunks', []))) - return item - - def get_next(self): - while True: - n = next(self.unpacker) - decode_dict(n, (b'path', b'source', b'user', b'group')) - if not self.filter or self.filter(n): - return n - - def peek(self): - while True: - while not self._peek_iter: - if self.peeks > 100: - raise StopIteration - _peek = self.get_next() - self.stack.append(_peek) - if b'chunks' in _peek: - self._peek_iter = iter(_peek[b'chunks']) - else: - self._peek_iter = None - try: - item = next(self._peek_iter) - self.peeks += 1 - return item - except StopIteration: - self._peek_iter = None + def fetch_many(self, ids, is_preloaded=False): + for id_, data in zip_longest(ids, self.repository.get_many(ids, is_preloaded=is_preloaded)): + yield self.key.decrypt(id_, data) -class Archive(object): +class ChunkBuffer: + BUFFER_SIZE = 1 * 1024 * 1024 + + def __init__(self, cache, key, stats): + self.buffer = BytesIO() + self.packer = msgpack.Packer(unicode_errors='surrogateescape') + self.cache = cache + self.chunks = [] + self.key = key + self.stats = stats + + def add(self, item): + self.buffer.write(self.packer.pack(item)) + + def flush(self, flush=False): + if self.buffer.tell() == 0: + return + self.buffer.seek(0) + chunks = list(bytes(s) for s in chunkify(self.buffer, WINDOW_SIZE, CHUNK_MASK, CHUNK_MIN, self.key.chunk_seed)) + self.buffer.seek(0) + self.buffer.truncate(0) + # Leave the last parital chunk in the buffer unless flush is True + end = None if flush or len(chunks) == 1 else -1 + for chunk in chunks[:end]: + id_, _, _ = self.cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats) + self.chunks.append(id_) + if end == -1: + self.buffer.write(chunks[-1]) + + def is_full(self): + return self.buffer.tell() > self.BUFFER_SIZE + + +class Archive: class DoesNotExist(Error): """Archive {} does not exist""" @@ -85,13 +95,13 @@ class Archive(object): self.repository = repository self.cache = cache self.manifest = manifest - self.items = BytesIO() - self.items_ids = [] self.hard_links = {} self.stats = Statistics() self.name = name self.checkpoint_interval = checkpoint_interval self.numeric_owner = numeric_owner + self.items_buffer = ChunkBuffer(self.cache, self.key, self.stats) + self.pipeline = DownloadPipeline(self.repository, self.key) if create: if name in manifest.archives: raise self.AlreadyExists(name) @@ -128,44 +138,17 @@ class Archive(object): return 'Archive(%r)' % self.name def iter_items(self, filter=None): - unpacker = msgpack.Unpacker(use_list=False) - i = 0 - n = 20 - while True: - items = self.metadata[b'items'][i:i + n] - i += n - if not items: - break - for id, chunk in [(id, chunk) for id, chunk in zip_longest(items, self.repository.get_many(items))]: - unpacker.feed(self.key.decrypt(id, chunk)) - iter = ItemIter(unpacker, filter) - for item in iter: - yield item, iter.peek + for item in self.pipeline.unpack_many(self.metadata[b'items'], filter=filter): + yield item, None def add_item(self, item): - self.items.write(msgpack.packb(item, unicode_errors='surrogateescape')) + self.items_buffer.add(item) now = time.time() if now - self.last_checkpoint > self.checkpoint_interval: self.last_checkpoint = now self.write_checkpoint() - if self.items.tell() > ITEMS_BUFFER: - self.flush_items() - - def flush_items(self, flush=False): - if self.items.tell() == 0: - return - self.items.seek(0) - chunks = list(bytes(s) for s in chunkify(self.items, WINDOW_SIZE, CHUNK_MASK, CHUNK_MIN, self.key.chunk_seed)) - self.items.seek(0) - self.items.truncate() - for chunk in chunks[:-1]: - id, _, _ = self.cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats) - self.items_ids.append(id) - if flush or len(chunks) == 1: - id, _, _ = self.cache.add_chunk(self.key.id_hash(chunks[-1]), chunks[-1], self.stats) - self.items_ids.append(id) - else: - self.items.write(chunks[-1]) + if self.items_buffer.is_full(): + self.items_buffer.flush() def write_checkpoint(self): self.save(self.checkpoint_name) @@ -176,11 +159,11 @@ class Archive(object): name = name or self.name if name in self.manifest.archives: raise self.AlreadyExists(name) - self.flush_items(flush=True) + self.items_buffer.flush(flush=True) metadata = { 'version': 1, 'name': name, - 'items': self.items_ids, + 'items': self.items_buffer.chunks, 'cmdline': sys.argv, 'hostname': socket.gethostname(), 'username': getuser(), @@ -199,6 +182,9 @@ class Archive(object): count, size, csize = self.cache.chunks[id] stats.update(size, csize, count == 1) self.cache.chunks[id] = count - 1, size, csize + def add_file_chunks(chunks): + for id, _, _ in chunks: + add(id) # This function is a bit evil since it abuses the cache to calculate # the stats. The cache transaction must be rolled back afterwards unpacker = msgpack.Unpacker(use_list=False) @@ -209,12 +195,9 @@ class Archive(object): add(id) unpacker.feed(self.key.decrypt(id, chunk)) for item in unpacker: - try: - for id, size, csize in item[b'chunks']: - add(id) + if b'chunks' in item: stats.nfiles += 1 - except KeyError: - pass + add_file_chunks(item[b'chunks']) cache.rollback() return stats @@ -249,9 +232,8 @@ class Archive(object): os.link(source, path) else: with open(path, 'wb') as fd: - ids = [id for id, size, csize in item[b'chunks']] - for id, chunk in zip_longest(ids, self.repository.get_many(ids, peek)): - data = self.key.decrypt(id, chunk) + ids = [c[0] for c in item[b'chunks']] + for data in self.pipeline.fetch_many(ids, is_preloaded=True): fd.write(data) fd.flush() self.restore_attrs(path, item, fd=fd.fileno()) @@ -314,8 +296,8 @@ class Archive(object): start(item) ids = [id for id, size, csize in item[b'chunks']] try: - for id, chunk in zip_longest(ids, self.repository.get_many(ids, peek)): - self.key.decrypt(id, chunk) + for _ in self.pipeline.fetch_many(ids, is_preloaded=True): + pass except Exception: result(item, False) return @@ -323,15 +305,14 @@ class Archive(object): def delete(self, cache): unpacker = msgpack.Unpacker(use_list=False) - for id in self.metadata[b'items']: - unpacker.feed(self.key.decrypt(id, self.repository.get(id))) + for id_, data in zip_longest(self.metadata[b'items'], self.repository.get_many(self.metadata[b'items'])): + unpacker.feed(self.key.decrypt(id_, data)) + self.cache.chunk_decref(id_) for item in unpacker: - try: + if b'chunks' in item: for chunk_id, size, csize in item[b'chunks']: self.cache.chunk_decref(chunk_id) - except KeyError: - pass - self.cache.chunk_decref(id) + self.cache.chunk_decref(self.id) del self.manifest.archives[self.name] self.manifest.write() @@ -385,19 +366,18 @@ class Archive(object): chunks = None if ids is not None: # Make sure all ids are available - for id in ids: - if not cache.seen_chunk(id): + for id_ in ids: + if not cache.seen_chunk(id_): break else: - chunks = [cache.chunk_incref(id, self.stats) for id in ids] + chunks = [cache.chunk_incref(id_, self.stats) for id_ in ids] # Only chunkify the file if needed if chunks is None: with open(path, 'rb') as fd: chunks = [] for chunk in chunkify(fd, WINDOW_SIZE, CHUNK_MASK, CHUNK_MIN, self.key.chunk_seed): chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats)) - ids = [id for id, _, _ in chunks] - cache.memorize_file(path_hash, st, ids) + cache.memorize_file(path_hash, st, [c[0] for c in chunks]) item = {b'path': safe_path, b'chunks': chunks} item.update(self.stat_attrs(st, path)) self.stats.nfiles += 1 diff --git a/attic/cache.py b/attic/cache.py index e1a475761..65a2c7fb8 100644 --- a/attic/cache.py +++ b/attic/cache.py @@ -154,16 +154,14 @@ class Cache(object): archive = msgpack.unpackb(data) decode_dict(archive, (b'name', b'hostname', b'username', b'time')) # fixme: argv print('Analyzing archive:', archive[b'name']) - for id, chunk in zip_longest(archive[b'items'], self.repository.get_many(archive[b'items'])): - data = self.key.decrypt(id, chunk) - add(id, len(data), len(chunk)) + for id_, chunk in zip_longest(archive[b'items'], self.repository.get_many(archive[b'items'])): + data = self.key.decrypt(id_, chunk) + add(id_, len(data), len(chunk)) unpacker.feed(data) for item in unpacker: - try: - for id, size, csize in item[b'chunks']: - add(id, size, csize) - except KeyError: - pass + if b'chunks' in item: + for id_, size, csize in item[b'chunks']: + add(id_, size, csize) def add_chunk(self, id, data, stats): if not self.txn_active: diff --git a/attic/remote.py b/attic/remote.py index b8cefcc68..1e46cf034 100644 --- a/attic/remote.py +++ b/attic/remote.py @@ -7,7 +7,6 @@ import sys from .helpers import Error from .repository import Repository -from .lrucache import LRUCache BUFSIZE = 10 * 1024 * 1024 @@ -71,19 +70,19 @@ class RemoteRepository(object): self.name = name def __init__(self, location, create=False): + self.preload_ids = [] + self.msgid = 0 + self.to_send = b'' + self.cache = {} + self.ignore_responses = set() + self.responses = {} + self.unpacker = msgpack.Unpacker(use_list=False) self.repository_url = '%s@%s:%s' % (location.user, location.host, location.path) self.p = None - self.cache = LRUCache(256) - self.to_send = b'' - self.extra = {} - self.pending = {} - self.unpacker = msgpack.Unpacker(use_list=False) - self.msgid = 0 - self.received_msgid = 0 if location.host == '__testsuite__': args = [sys.executable, '-m', 'attic.archiver', 'serve'] else: - args = ['ssh',] + args = ['ssh'] if location.port: args += ['-p', str(location.port)] if location.user: @@ -99,11 +98,11 @@ class RemoteRepository(object): self.r_fds = [self.stdout_fd] self.x_fds = [self.stdin_fd, self.stdout_fd] - version = self.call('negotiate', (1,)) + version = self.call('negotiate', 1) if version != 1: raise Exception('Server insisted on using unsupported protocol version %d' % version) try: - self.id = self.call('open', (location.path, create)) + self.id = self.call('open', location.path, create) except self.RPCError as e: if e.name == b'DoesNotExist': raise Repository.DoesNotExist(self.repository_url) @@ -113,11 +112,33 @@ class RemoteRepository(object): def __del__(self): self.close() - def call(self, cmd, args, wait=True): - self.msgid += 1 - to_send = msgpack.packb((1, self.msgid, cmd, args)) + def call(self, cmd, *args, **kw): + for resp in self.call_many(cmd, [args], **kw): + return resp + + def call_many(self, cmd, calls, wait=True, is_preloaded=False): + def fetch_from_cache(args): + msgid = self.cache[args].pop(0) + if not self.cache[args]: + del self.cache[args] + return msgid + + calls = list(calls) + waiting_for = [] w_fds = [self.stdin_fd] - while wait or to_send: + while wait or calls: + while waiting_for: + try: + error, res = self.responses.pop(waiting_for[0]) + waiting_for.pop(0) + if error: + raise self.RPCError(error) + else: + yield res + if not waiting_for and not calls: + return + except KeyError: + break r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1) if x: raise Exception('FD exception occured') @@ -127,147 +148,60 @@ class RemoteRepository(object): raise ConnectionClosed() self.unpacker.feed(data) for type, msgid, error, res in self.unpacker: - if msgid == self.msgid: - self.received_msgid = msgid - if error: - raise self.RPCError(error) - else: - return res + if msgid in self.ignore_responses: + self.ignore_responses.remove(msgid) else: - args = self.pending.pop(msgid, None) - if args is not None: - self.cache[args] = msgid, res, error + self.responses[msgid] = error, res if w: - if to_send: - n = os.write(self.stdin_fd, to_send) - assert n > 0 - to_send = memoryview(to_send)[n:] - if not to_send: - w_fds = [] + while not self.to_send and (calls or self.preload_ids) and len(waiting_for) < 100: + if calls: + if is_preloaded: + if calls[0] in self.cache: + waiting_for.append(fetch_from_cache(calls.pop(0))) + else: + args = calls.pop(0) + if cmd == 'get' and args in self.cache: + waiting_for.append(fetch_from_cache(args)) + else: + self.msgid += 1 + waiting_for.append(self.msgid) + self.to_send = msgpack.packb((1, self.msgid, cmd, args)) + if not self.to_send and self.preload_ids: + args = (self.preload_ids.pop(0),) + self.msgid += 1 + self.cache.setdefault(args, []).append(self.msgid) + self.to_send = msgpack.packb((1, self.msgid, cmd, args)) - def _read(self): - data = os.read(self.stdout_fd, BUFSIZE) - if not data: - raise Exception('Remote host closed connection') - self.unpacker.feed(data) - to_yield = [] - for type, msgid, error, res in self.unpacker: - self.received_msgid = msgid - args = self.pending.pop(msgid, None) - if args is not None: - self.cache[args] = msgid, res, error - for args, resp, error in self.extra.pop(msgid, []): - if not resp and not error: - resp, error = self.cache[args][1:] - to_yield.append((resp, error)) - for res, error in to_yield: - if error: - raise self.RPCError(error) - else: - yield res - - def gen_request(self, cmd, argsv, wait): - data = [] - m = self.received_msgid - for args in argsv: - # Make sure to invalidate any existing cache entries for non-get requests - if not args in self.cache: - self.msgid += 1 - msgid = self.msgid - self.pending[msgid] = args - self.cache[args] = msgid, None, None - data.append(msgpack.packb((1, msgid, cmd, args))) - if wait: - msgid, resp, error = self.cache[args] - m = max(m, msgid) - self.extra.setdefault(m, []).append((args, resp, error)) - return b''.join(data) - - def gen_cache_requests(self, cmd, peek): - data = [] - while True: - try: - args = (peek()[0],) - except StopIteration: - break - if args in self.cache: - continue - self.msgid += 1 - msgid = self.msgid - self.pending[msgid] = args - self.cache[args] = msgid, None, None - data.append(msgpack.packb((1, msgid, cmd, args))) - return b''.join(data) - - def call_multi(self, cmd, argsv, wait=True, peek=None): - w_fds = [self.stdin_fd] - left = len(argsv) - data = self.gen_request(cmd, argsv, wait) - self.to_send += data - for args, resp, error in self.extra.pop(self.received_msgid, []): - left -= 1 - if not resp and not error: - resp, error = self.cache[args][1:] - if error: - raise self.RPCError(error) - else: - yield resp - while left: - r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1) - if x: - raise Exception('FD exception occured') - if r: - for res in self._read(): - left -= 1 - yield res - if w: - if not self.to_send and peek: - self.to_send = self.gen_cache_requests(cmd, peek) if self.to_send: - n = os.write(self.stdin_fd, self.to_send) - assert n > 0 -# self.to_send = memoryview(self.to_send)[n:] - self.to_send = self.to_send[n:] - else: + self.to_send = self.to_send[os.write(self.stdin_fd, self.to_send):] + if not self.to_send and not (calls or self.preload_ids): w_fds = [] - if not wait: - return + self.ignore_responses |= set(waiting_for) def commit(self, *args): - self.call('commit', args) + return self.call('commit') def rollback(self, *args): - self.cache.clear() - self.pending.clear() - self.extra.clear() - return self.call('rollback', args) + return self.call('rollback') - def get(self, id): + def get(self, id_): + for resp in self.get_many([id_]): + return resp + + def get_many(self, ids, is_preloaded=False): try: - for res in self.call_multi('get', [(id, )]): - return res + for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded): + yield resp except self.RPCError as e: if e.name == b'DoesNotExist': raise Repository.DoesNotExist(self.repository_url) raise - def get_many(self, ids, peek=None): - return self.call_multi('get', [(id, ) for id in ids], peek=peek) + def put(self, id_, data, wait=True): + return self.call('put', id_, data, wait=wait) - def _invalidate(self, id): - key = (id, ) - if key in self.cache: - self.pending.pop(self.cache.pop(key)[0], None) - - def put(self, id, data, wait=True): - resp = self.call('put', (id, data), wait=wait) - self._invalidate(id) - return resp - - def delete(self, id, wait=True): - resp = self.call('delete', (id, ), wait=wait) - self._invalidate(id) - return resp + def delete(self, id_, wait=True): + return self.call('delete', id_, wait=wait) def close(self): if self.p: @@ -275,3 +209,6 @@ class RemoteRepository(object): self.p.stdout.close() self.p.wait() self.p = None + + def preload(self, ids): + self.preload_ids += ids diff --git a/attic/repository.py b/attic/repository.py index 04920a6dd..6555e9a8a 100644 --- a/attic/repository.py +++ b/attic/repository.py @@ -220,9 +220,9 @@ class Repository(object): except KeyError: raise self.DoesNotExist(self.path) - def get_many(self, ids, peek=None): - for id in ids: - yield self.get(id) + def get_many(self, ids, is_preloaded=False): + for id_ in ids: + yield self.get(id_) def put(self, id, data, wait=True): if not self._active_txn: @@ -261,6 +261,10 @@ class Repository(object): def add_callback(self, cb, data): cb(None, None, data) + def preload(self, ids): + """Preload objects (only applies to remote repositories + """ + class LoggedIO(object): diff --git a/attic/testsuite/archive.py b/attic/testsuite/archive.py new file mode 100644 index 000000000..25e8bb64e --- /dev/null +++ b/attic/testsuite/archive.py @@ -0,0 +1,32 @@ +import msgpack +from attic.testsuite import AtticTestCase +from attic.archive import ChunkBuffer +from attic.key import PlaintextKey + + +class MockCache: + + def __init__(self): + self.objects = {} + + def add_chunk(self, id, data, stats=None): + self.objects[id] = data + return id, len(data), len(data) + + +class ChunkBufferTestCase(AtticTestCase): + + def test(self): + data = [{b'foo': 1}, {b'bar': 2}] + cache = MockCache() + key = PlaintextKey() + chunks = ChunkBuffer(cache, key, None) + for d in data: + chunks.add(d) + chunks.flush() + chunks.flush(flush=True) + self.assert_equal(len(chunks.chunks), 2) + unpacker = msgpack.Unpacker() + for id in chunks.chunks: + unpacker.feed(cache.objects[id]) + self.assert_equal(data, list(unpacker))