#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, print_function

import argparse
def _build_parser():
	'''Create and return the command-line parser for this resolver.'''
	p = argparse.ArgumentParser(description=(
		'DNS resolver that rewrites'
		' AAAA entries to the ones from supplied domain.' ))

	# Listening socket.
	p.add_argument('-i', '--ip', default='127.0.0.2',
		help='IP to listen on (default: %(default)s).')
	p.add_argument('-p', '--port', type=int, default=5353,
		help='UDP port number to listen on (default: %(default)s).')

	# AAAA rewriting: which name to substitute and how long to reuse its answer.
	p.add_argument('-r', '--aaaa-redirect', default='nx.fraggod.net',
		help='Name to query for AAAA instead of supplied one (default: %(default)s).')
	p.add_argument('-c', '--aaaa-cache', type=float, default=600,
		help='Time to cache redirected AAAA query response for (default: %(default)ss).')

	# Upstream servers for everything that is not AAAA.
	p.add_argument('-f', '--forwarders-file', default='/etc/djbdns/cache/servers/@',
		help='File with IP addresses of DNS servers to forward non-AAAA queries to (default: %(default)s).'
			' Format is one IP per line, empty lines and lines starting with hashmarks are skipped.')

	p.add_argument('--debug', action='store_true', help='Verbose operation mode.')
	return p

# Parsed before the heavier twisted imports below, so --help stays fast.
parser = _build_parser()
argz = parser.parse_args()


import itertools as it, operator as op, functools as ft
from twisted.application import service, internet
from twisted.internet import reactor, defer
from twisted.names import client, server, dns
from twisted.python import log
from time import time
import os, sys


class IPv4OnlyResolver(client.Resolver):
	'''Forwarding resolver with a substituted AAAA path.

	Lookups other than AAAA go through the stock client.Resolver machinery,
	using the server list built by parseConfig from the forwarders file.
	Every AAAA lookup, whatever name was asked for, is answered with the
	outcome of an AAAA lookup for argz.aaaa_redirect. That outcome is
	memoized for argz.aaaa_cache seconds.
	'''

	# (expiry unix timestamp, zero-arg callable returning a fresh Deferred).
	# The expiry starts at 0, so the first AAAA query always does a real lookup.
	_cache = 0, None

	def parseConfig(self, resolvConf):
		'''Replace client.Resolver's resolv.conf parsing with a plain list.

		resolvConf is iterated line by line, and each line is whitespace-stripped.
		Blank lines and lines starting with "#" are ignored. Every other line
		is used as a server address, paired with dns.PORT. The resulting list
		replaces self.dynServers entirely.
		'''
		stripped = (raw.strip() for raw in resolvConf)
		self.dynServers = [
			(addr, dns.PORT) for addr in stripped
			if addr and not addr.startswith('#') ]

	def cache_aaaa(self, result, success):
		'''Callback/errback that memoizes the redirect lookup's outcome.

		Stores a replayer that re-emits `result` until argz.aaaa_cache
		seconds from now. The replay is a success when `success` is true,
		and a failure otherwise. `result` is then returned unchanged, so the
		callback (or errback) chain carries on as before.

		NOTE(review): a failed lookup is cached for the same period as a
		successful one. Confirm this negative caching is intended.
		'''
		replay_with = defer.succeed if success else defer.fail
		expires_at = time() + argz.aaaa_cache
		self._cache = expires_at, ft.partial(replay_with, result)
		return result

	def lookupIPV6Address(self, name, timeout=None):
		'''Answer an AAAA query with the AAAA data for argz.aaaa_redirect.

		`name` is not used for the lookup. While the cached entry has not
		expired, the memoized outcome is replayed. Otherwise a new lookup
		is started, and its result or failure is cached via cache_aaaa.

		NOTE(review): the records are returned as produced for
		argz.aaaa_redirect. Nothing here renames them to `name`, so confirm
		that clients accept answers whose owner name differs from the question.
		'''
		expires_at, replay = self._cache
		if time() <= expires_at:
			return replay()
		log.msg('Performing new AAAA lookup for {}'.format(argz.aaaa_redirect))
		lookup = self._lookup(argz.aaaa_redirect, dns.IN, dns.AAAA, timeout)
		lookup.addCallback(self.cache_aaaa, True)
		lookup.addErrback(self.cache_aaaa, False)
		return lookup


import logging
def logger(event, out=sys.__stderr__,
		level=logging.WARNING if not argz.debug else 0 ):
	'''Legacy twisted.python.log observer that prints filtered events to `out`.

	The event's level comes from its "logLevel" key, defaulting to 0. Events
	flagged with "isError" but carrying no level are treated as
	logging.ERROR. Events below `level` are dropped. By default that leaves
	WARNING and above, including errors. With --debug the threshold drops
	to 0, so every event is printed.

	Each printed event has the form "<system> :: <text>". Embedded newlines
	in the text are followed by a tab so continuation lines are indented.
	Events whose text cannot be rendered are skipped.
	'''
	ev_level = event.get('logLevel', 0)
	# .get(): tolerate events that lack the "isError" key instead of raising KeyError.
	if not ev_level and event.get('isError'): ev_level = logging.ERROR
	if ev_level < level: return
	text = log.textFromEventDict(event)
	# textFromEventDict returns None when an event has no message, format or
	# failure to render. Calling .replace on that would raise inside the
	# observer, so such events are skipped, as twisted's own file observer does.
	if text is None: return
	print('{} :: {}'.format( event.get('system'),
		text.replace('\n', '\n\t') ), file=out)
log.addObserver(logger)

# Non-root only: if the optional fgc.caps module is importable, presumably
# activate the POSIX capabilities this process is permitted. Skipped when
# fgc is not installed.
if os.getuid() != 0:
	try:
		from fgc.caps import Caps
	except ImportError:
		pass
	else:
		Caps.from_process().activate().apply()

# A single IPv4OnlyResolver, fed from the forwarders file, answers all queries.
# It is served over UDP on the configured address, then control passes to the
# twisted reactor.
resolver_chain = [IPv4OnlyResolver(resolv=argz.forwarders_file)]
protocol = dns.DNSDatagramProtocol(
	server.DNSServerFactory(clients=resolver_chain) )
reactor.listenUDP(argz.port, protocol, interface=argz.ip)
reactor.run()
