knot-dns/scripts/redis_unalias.py

380 lines
12 KiB
Python
Executable file

#!/usr/bin/env python3
# Copyright (C) CZ.NIC, z.s.p.o. and contributors
# SPDX-License-Identifier: GPL-2.0-or-later
# For more information, see <https://www.knot-dns.cz/>
"""Script for resolving ALIAS records in zones stored in Redis using local resolver."""
# requirements redis[hiredis]
from argparse import ArgumentError, ArgumentParser
from contextlib import contextmanager
from enum import IntEnum
from re import sub
from redis import Redis
from redis.exceptions import ConnectionError, ResponseError, TimeoutError
from socket import AF_INET, AF_INET6, SOCK_DGRAM, gaierror, getaddrinfo, inet_pton
from sys import exit, stderr
class RRType(IntEnum):
A = 1
SOA = 6
AAAA = 28
ALIAS = 65401
class Stats:
new = 0
updated = 0
resolved = 0
ipv4 = 0
ipv6 = 0
not_found = 0
def __str__(self):
return \
f'Zones: {self.new + self.updated}\t(new: {self.new}, updated: {self.updated})\n' + \
f'Aliases: {self.resolved + self.not_found}\t(resolved: {self.resolved}, unknown: {self.not_found})\n' + \
f'Records: {self.ipv4 + self.ipv6}\t(A: {self.resolved}, AAAA: {self.ipv6})'
def arg_parser():
parser = ArgumentParser(
exit_on_error=False,
prog="redis-unalias.py",
description="Resolves ALIAS records in all zones of the input instance "
"into the output instance"
)
# required positional arguments
parser.add_argument(
"input_instance",
type=int,
help="zone instance that converts ALIAS from",
)
parser.add_argument(
"output_instance",
type=int,
help="zone instance that converts ALIAS to",
)
# redis connection
redis_group = parser.add_argument_group("connection")
redis_group.add_argument(
"-a", "--addr",
default="localhost",
help="redis-server address",
)
redis_group.add_argument(
"-p", "--port",
type=int,
default=6379,
help="redis-server port",
)
# TLS settings
tls_group = parser.add_argument_group("TLS")
tls_group.add_argument(
"-t", "--tls",
action="store_true",
help="use transport layer security (TLS)",
)
tls_group.add_argument(
"-C", "--tls-cert",
dest="tls_cert",
help="path to a client certificate",
)
tls_group.add_argument(
"-K", "--tls-key",
dest="tls_key",
help="path to a client key",
)
tls_group.add_argument(
"-A", "--tls-ca",
dest="tls_ca",
help="path to a trusted CA certificates used to verify the server",
)
tls_group.add_argument(
"-i", "--tls-insecure",
action="store_const",
const="none",
default="required",
help="disable client TLS validation",
)
# behavior
behavior_group = parser.add_argument_group("execution")
behavior_group.add_argument(
"-d", "--dry-run",
action="store_true",
help="print the transaction instead of commiting",
)
behavior_group.add_argument(
"-s", "--print-stats",
action="store_true",
help="print statistics at the end",
)
return parser
def bytes_to_int(bytes):
return int.from_bytes(bytes, signed=False)
def int_to_bytes(val):
return bytes([val])
def txn_to_str(txn):
return txn[0] * 10 + txn[1]
def dname_to_str(wire):
res = ""
dname_len = len(wire)
if dname_len == 0:
return res
label_len = 0
for i in range(0, dname_len):
if label_len == 0:
label_len = wire[i]
if len(res) > 0 or dname_len == 1:
res += '.'
continue
c = chr(wire[i])
if c.isalnum() or c == '-':
res += c
else:
res += f"\\{wire[i]:03}"
label_len -= 1
return res
def af_to_rtype(af):
if af == AF_INET:
return RRType.A.value
elif af == AF_INET6:
return RRType.AAAA.value
else:
raise Exception('Unsupported type')
def rdata_to_dname_list(dname):
processing = 0
out = []
while processing < len(dname):
size = int.from_bytes(dname[processing : processing + 2], "little", signed=False)
if size == 0:
break
wire = dname[processing + 2 : processing + size + 2]
out.append(dname_to_str(wire))
processing += size + 2
return out
def get_serial(rdata):
wire_serial = rdata[len(rdata) - 21 : len(rdata) - 17]
return int.from_bytes(wire_serial, byteorder='big', signed=False)
def set_serial(rdata, serial):
b = bytearray(rdata)
b[len(b) - 21 : len(b) - 17] = serial.to_bytes(4, 'big')
return bytes(b)
@contextmanager
def knot_zone_transaction(conn, zone, inst, dryrun):
txn = None
try:
instance = int_to_bytes(inst)
txn = conn.execute_command('KNOT_BIN.ZONE.BEGIN', zone, instance)
yield txn
except:
if txn:
conn.execute_command('KNOT_BIN.ZONE.ABORT', zone, txn)
raise
else:
if dryrun:
zone_str = dname_to_str(zone)
resp = conn.execute_command('KNOT.ZONE.LOAD', zone_str, txn_to_str(txn))
conn.execute_command('KNOT_BIN.ZONE.ABORT', zone, txn)
print(f'=== FULL {zone_str} ===')
for record in resp:
print(*[x.decode() for x in record], sep=' ')
else:
conn.execute_command('KNOT_BIN.ZONE.COMMIT', zone, txn)
@contextmanager
def knot_upd_transaction(conn, zone, inst, dryrun):
txn = None
try:
instance = int_to_bytes(inst)
txn = conn.execute_command('KNOT_BIN.UPD.BEGIN', zone, instance)
yield txn
except:
if txn:
conn.execute_command('KNOT_BIN.UPD.ABORT', zone, txn)
raise
else:
if dryrun:
zone_str = dname_to_str(zone)
resp = conn.execute_command('KNOT.UPD.DIFF', zone_str, txn_to_str(txn))
conn.execute_command('KNOT_BIN.UPD.ABORT', zone, txn)
print(f'=== UPDATE {zone_str} ===')
for diff in resp:
for rem in diff[0]:
print('- ', end='')
print(*[x.decode() for x in rem], sep=' ')
for add in diff[1]:
print('+ ', end='')
print(*[x.decode() for x in add], sep=' ')
else:
if conn.execute_command('KNOT_BIN.UPD.DIFF', zone, txn):
conn.execute_command('KNOT_BIN.UPD.COMMIT', zone, txn)
else:
conn.execute_command('KNOT_BIN.UPD.ABORT', zone, txn)
def store_zone_record(conn, zone, txn, record):
resp = conn.execute_command('KNOT_BIN.ZONE.STORE', zone, txn,
record[0], record[1], record[2], record[3], record[4], "M")
if resp != b'OK':
raise Exception("Failed to store record")
def resolve_zone_record(conn, zone, txn, record):
global stats
for dname in rdata_to_dname_list(record[4]):
try:
resp = getaddrinfo(dname, None, type=SOCK_DGRAM)
for r in resp:
new_record = record[0:1]
if r[0] == AF_INET or r[0] == AF_INET6:
bin = inet_pton(r[0], r[4][0])
size = len(bin).to_bytes(2, byteorder='little', signed=False)
new_record.extend([af_to_rtype(r[0]), record[2], 1, size + bin])
else:
continue
store_zone_record(conn, zone, txn, new_record)
if r[0] == AF_INET:
stats.ipv4 += 1
else:
stats.ipv6 += 1
stats.resolved += 1
except (gaierror, UnicodeEncodeError): # Not found - skip
stats.not_found += 1
def store_upd_record(conn, zone, txn, record):
resp = conn.execute_command('KNOT_BIN.UPD.ADD', zone, txn,
record[0], record[1], record[2], record[3], record[4], "M")
if resp != b'OK':
raise Exception("Failed to insert record")
def remove_upd_record(conn, zone, txn, record):
resp = conn.execute_command('KNOT_BIN.UPD.REM', zone, txn,
record[0], record[1], record[2], record[3], record[4])
if resp != b'OK':
raise Exception("Failed to delete record")
def resolve_upd_record(conn, zone, txn, record):
global stats
for dname in rdata_to_dname_list(record[4]):
try:
resp = getaddrinfo(dname, None, type=SOCK_DGRAM)
for r in resp:
new_record = record[0:1]
if r[0] == AF_INET or r[0] == AF_INET6:
bin = inet_pton(r[0], r[4][0])
size = len(bin).to_bytes(2, byteorder='little', signed=False)
new_record.extend([af_to_rtype(r[0]), record[2], 1, size + bin])
else:
continue
store_upd_record(conn, zone, txn, new_record)
if r[0] == AF_INET:
stats.ipv4 += 1
else:
stats.ipv6 += 1
stats.resolved += 1
except (gaierror, UnicodeEncodeError): # Not found - skip
stats.not_found += 1
def convert_zone_new(conn, zone, input, output, dryrun):
input_resp = conn.execute_command('KNOT_BIN.ZONE.LOAD', zone, int_to_bytes(input))
with knot_zone_transaction(conn, zone, output, dryrun) as txn:
for r in input_resp:
if r[1] == RRType.ALIAS:
resolve_zone_record(conn, zone, txn, r)
else:
store_zone_record(conn, zone, txn, r)
def convert_zone_existing(conn, zone, input, output, dryrun):
input_resp = conn.execute_command('KNOT_BIN.ZONE.LOAD', zone, int_to_bytes(input))
old_resp = conn.execute_command('KNOT_BIN.ZONE.LOAD', zone, int_to_bytes(output))
with knot_upd_transaction(conn, zone, output, dryrun) as txn:
for r in old_resp:
if r[1] == RRType.SOA:
current_soa = r
continue
remove_upd_record(conn, zone, txn, r)
for r in input_resp:
if r[1] == RRType.ALIAS:
resolve_upd_record(conn, zone, txn, r)
else:
if r[1] == RRType.SOA:
input_soa = r
continue
store_upd_record(conn, zone, txn, r)
if conn.execute_command('KNOT_BIN.UPD.DIFF', zone, txn):
remove_upd_record(conn, zone, txn, current_soa)
new_serial = (get_serial(current_soa[4]) + 1) % (2**32)
input_soa[4] = set_serial(input_soa[4], new_serial)
store_upd_record(conn, zone, txn, input_soa)
def convert_zone(conn, zone, input, output, dryrun):
global stats
if not zone[1]:
convert_zone_new(conn, zone[0], input, output, dryrun)
stats.new += 1
else:
convert_zone_existing(conn, zone[0], input, output, dryrun)
stats.updated += 1
def list_zones(conn, input, output):
input_mask = 1 << (input - 1)
output_mask = 1 << (output - 1)
resp = conn.execute_command('KNOT_BIN.ZONE.LIST')
filtered = filter(lambda x: (bytes_to_int(x[1]) & input_mask) != 0, resp)
return map(lambda x: (x[0], (bytes_to_int(x[1]) & output_mask) != 0), filtered)
def main():
global stats
stats = Stats()
args = arg_parser()
try:
conf = args.parse_args()
conn = Redis(
host=conf.addr,
port=conf.port,
ssl=conf.tls,
ssl_certfile=conf.tls_cert,
ssl_keyfile=conf.tls_key,
ssl_ca_certs=conf.tls_ca,
ssl_cert_reqs=conf.tls_insecure,
socket_timeout=5
)
for zone in list_zones(conn, conf.input_instance, conf.output_instance):
convert_zone(conn, zone, conf.input_instance, conf.output_instance, conf.dry_run)
except ConnectionError as e:
err = sub(r'^Error\s+-?\d+\s+', 'Error: ', e.args[0])
print(err, file=stderr)
exit(1)
except (ResponseError, TimeoutError) as e:
print("Error: " + e.args[0], file=stderr)
exit(1)
except ArgumentError as e:
print("Error: " + e.message, file=stderr)
args.print_help()
exit(1)
if conf.print_stats:
print("Statistics\n----------")
print(stats)
if __name__ == "__main__":
main()