#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# The script relies on Python 2-only APIs (itertools.ifilter/imap,
# dict.viewitems/viewkeys), so the interpreter is pinned to "python2"
# rather than "python", which may resolve to Python 3 (see PEP 394).
from __future__ import unicode_literals

import os, sys
import argparse
# Command-line interface: which groups/hosts to unlock, how to log into them,
# and where the csync2 config and key are located locally and remotely.
parser = argparse.ArgumentParser(
	description='Provide private csync2 keys to specified groups.')

parser.add_argument('groups', nargs='*', default=list(),
	help='Host groups or hosts (if -1 option is specified) to unlock.'
		' List all configured sync-hosts/groups to stdout, if none is specified.')
parser.add_argument('-1', '--hosts', action='store_true',
	help='Interpret arguments as individual hosts.')
parser.add_argument('-u', '--user', default='root',
	help='User to log into hosts as (default: %(default)s).')

parser.add_argument('-c', '--csync2-conf', default='../csync2/cluster.cfg',
	help='Csync2 configuration file to parse groups from (empty to skip, default: %(default)s).')
parser.add_argument('-k', '--keyfile', default='../csync2/csync2.key',
	help='Csync2 authorization key to pipe to hosts (default: %(default)s).'
		' File with same basename plus .gpg extension is also checked.')
parser.add_argument('-r', '--remote-keyfile', default='/var/lib/csync2/privileged.key',
	help='Keyfile location on a remote host (default: %(default)s).')
parser.add_argument('--pipe-restart-delay', type=float, default=0.5,
	help='Delay between restarting keyfile-transfer'
		' ssh processes for a given host (default: %(default)ss).')

parser.add_argument('--debug', action='store_true', help='Verbose operation mode.')
argz = parser.parse_args()

# Root logger; --debug switches the threshold from INFO to DEBUG.
import logging
logging.basicConfig(level=logging.DEBUG if argz.debug else logging.INFO)
log = logging.getLogger()


import itertools as it, operator as op, functools as ft
from subprocess import Popen, PIPE, STDOUT
from collections import defaultdict

if not argz.hosts:
	# Build {group name: set of hosts} from the csync2 config.
	# A "group <name>" line makes <name> the current group. Every following
	# "host <h1> <h2> ...;" line adds its hosts to that group.
	# An empty --csync2-conf skips parsing, leaving no known groups.
	state, hosts = None, defaultdict(set)
	if argz.csync2_conf:
		with open(argz.csync2_conf) as conf:
			for line in it.ifilter(None, it.imap(op.methodcaller('strip'), conf)):
				cmd = line.split(None, 1)[0]
				# Every "group" line switches the current group, so later groups in
				# the file are not merged into the first one. The rstrip tolerates
				# "group name{" written without a space before the brace.
				if cmd == 'group': state = line.split()[1].rstrip('{')
				# Cut at the ";" terminator (dropping anything after it) before splitting.
				if state and cmd == 'host': hosts[state].update(line.split(';', 1)[0].split()[1:])
	if log.isEnabledFor(logging.DEBUG):
		log.debug('Found groups: {}'.format(dict(hosts.viewitems())))
	argz.hosts = set()
	for group in argz.groups:
		# membership test on the defaultdict does not create the key
		if group not in hosts:
			parser.error( 'Unknown group specified:'
				' {} (known: {})'.format(group, ', '.join(hosts.viewkeys())) )
		argz.hosts.update(hosts[group])
	# TODO: translation for raw hostnames here?

	if not argz.groups:
		# No groups requested: print the parsed groups and quit.
		from pprint import pprint
		pprint(dict((bytes(k),map(bytes, v)) for k,v in hosts.viewitems()))
		sys.exit()

# -1/--hosts: positional arguments are host names already.
else: argz.hosts = argz.groups

log.debug('Affected hosts: {}'.format(argz.hosts))


from twisted.internet import protocol, reactor, defer, utils, error
import signal


# Load the raw key: use the plain keyfile if present, otherwise decrypt
# "<keyfile>.gpg" with gpg. Abort when gpg fails, rather than distributing an
# empty or truncated key to every host.
if os.path.exists(argz.keyfile):
	with open(argz.keyfile, 'rb') as src: key = src.read()
else:
	gpg_path = '{}.gpg'.format(argz.keyfile)
	with open(gpg_path, 'rb') as src:
		gpg = Popen(['gpg', '--no-tty', '-qd'], stdin=src, stdout=PIPE)
		# communicate() reads all of stdout and waits for the process to exit
		key = gpg.communicate()[0]
	if gpg.returncode:
		log.error('Failed to decrypt keyfile {} (gpg exit code: {})'.format(gpg_path, gpg.returncode))
		sys.exit(1)
log.debug('Key size: {}'.format(len(key)))


def timeout(secs):
	'''Return a decorator that fails a Deferred-returning call after `secs` seconds.

	A wrapped call that returns a non-Deferred value passes it through as is.
	Otherwise the returned Deferred is raced against a reactor timer:
	- the call fires first: its result is returned, or its failure re-raised;
	- the timer fires first: the call's Deferred is cancel()ed and
	  twisted.internet.error.TimeoutError is raised.
	'''
	def decorator(func):
		@defer.inlineCallbacks
		def guarded(*args, **kwargs):
			pending = func(*args, **kwargs)
			if not isinstance(pending, defer.Deferred):
				defer.returnValue(pending)

			alarm = defer.Deferred()
			timer = reactor.callLater(secs, alarm.callback, None)

			try:
				# With fireOnOneCallback the list yields (result, index) of the winner.
				outcome, _index = yield defer.DeferredList(
					[pending, alarm], fireOnOneCallback=True, fireOnOneErrback=True, consumeErrors=True )
			except defer.FirstError as err:
				# The alarm only ever succeeds, so any failure comes from the call.
				assert err.index == 0
				timer.cancel()
				err.subFailure.raiseException()

			if alarm.called:
				pending.cancel()
				raise error.TimeoutError('elapsed time: {}s'.format(secs))

			timer.cancel()
			defer.returnValue(outcome)
		return guarded
	return decorator


class KeyTransfer(protocol.ProcessProtocol):
	'''Persistent "ssh <user>@<host> ... cat >fifo" link feeding the key to one host.

	When the child starts, the module-level `key` is written to its stdin, which is
	then closed. The remote command (built in respawn) creates a fifo at
	argz.remote_keyfile if needed and cats the key into it.

	When the child exits while the instance is still `persistent`, and close()
	has not been called, a new link to the same host is started after
	argz.pipe_restart_delay seconds.
	'''

	# Output captured from the child. Twisted hands over byte strings, so the
	# buffers must be bytes. With unicode_literals a plain '' would be unicode,
	# and appending non-ASCII output to it raises UnicodeDecodeError.
	_stdout = _stderr = b''
	# Instances whose subprocess is currently running (added in connectionMade).
	instances = set()
	# Set by close(). It stops respawns that were already scheduled via callLater.
	closing = False

	def __init__(self, host):
		# running fires with self once the child is started and the key is written.
		# done fires with self once the child has exited.
		self.running, self.done = defer.Deferred(), defer.Deferred()
		self.host, self.persistent = host, True
		# Remembered in connectionMade for logging. Twisted's transport may
		# already have reset transport.pid to None by the time processExited runs.
		self.pid = None

	def connectionMade(self):
		self.pid = self.transport.pid
		log.debug('Started subprocess (pid: {})'.format(self.pid))
		self.transport.write(key)
		self.transport.closeStdin()
		self.running.callback(self)
		self.instances.add(self)

	def outReceived(self, data): self._stdout += data
	def errReceived(self, data): self._stderr += data

	@staticmethod
	def _dump(label, data):
		'Format captured child output for the exit warning, or "" when there is none.'
		if not data: return ''
		return '\n--- {}: \n{}'.format(label, data.decode('utf-8', 'replace'))

	def processExited(self, stats):
		log.debug('Process termination event (pid: {})'.format(self.pid))
		code = stats.value.exitCode
		if code:
			log.warning( 'ssh link to host {} exited with code {}{}{}'.format(
				self.host, code,
				self._dump('stdout', self._stdout), self._dump('stderr', self._stderr) ))
		self._stdout = self._stderr = b''
		self.instances.discard(self)
		self.done.callback(self)
		if self.persistent and not self.closing:
			reactor.callLater(argz.pipe_restart_delay, self.respawn, self.host)

	@classmethod
	def close(cls):
		'''Send SIGTERM to every running link and disable any further respawns.

		Returns the list of `done` Deferreds of the links that were signaled.
		'''
		cls.closing = True
		deferreds = list()
		# Iterate over a snapshot, since exiting processes remove themselves from the set.
		for proc in list(cls.instances):
			log.debug('Signaling pid {}'.format(proc.transport.pid))
			proc.persistent = False
			try: proc.transport.signalProcess('TERM')
			except error.ProcessExitedAlready: pass
			else: deferreds.append(proc.done)
		return deferreds

	@classmethod
	def respawn(cls, host):
		'''Start a new link to `host`.

		Returns a Deferred that fires with the new instance once its child has
		started. After close() no process is started, and an already-fired
		Deferred (None) is returned instead.
		'''
		if cls.closing:
			log.debug('Not respawning key transfer for host {}: shutting down'.format(host))
			return defer.succeed(None)
		# Remote side: create the fifo if missing, then block writing stdin into it.
		cmd = 'ssh', '-T', '{}@{}'.format(argz.user, host),\
			"test -p '{0}' || mkfifo '{0}' && exec cat >'{0}'".format(argz.remote_keyfile)
		log.debug('Starting key transfer process for host {}: {}'.format(host, cmd))
		self = cls(host)
		reactor.spawnProcess(self, cmd[0], map(bytes, cmd), env=os.environ)
		return self.running


@defer.inlineCallbacks
def spawn(host_list):
	'''Start one KeyTransfer link per host and wait until all have fired.

	"Connected" here means each link's `running` Deferred has fired. That
	happens once the local ssh child has started and the key has been written
	to its stdin.
	'''
	log.debug('Starting up')
	startups = [KeyTransfer.respawn(host) for host in host_list]
	yield defer.DeferredList(startups)
	log.info('Connected to all hosts')

@defer.inlineCallbacks
def terminate(sig=None, frm=None):
	'''Shut down all key links, drain the remote fifos, then stop the reactor.

	The signature matches signal handlers; `sig` and `frm` are unused.

	1. KeyTransfer.close() signals every running ssh link and waits for it to exit.
	2. For each closed host, "ssh <host> cat <remote_keyfile>" is run repeatedly.
	   Each read is limited to 2s. The loop stops on empty output, a non-empty
	   stderr, a non-zero exit code, or a timeout. Presumably this consumes key
	   data still held by a remote writer blocked on the fifo.
	3. The reactor is stopped if it is still running.
	'''
	log.debug('Terminating')

	closed = yield defer.DeferredList(KeyTransfer.close())
	hosts = list(proc.host for res, proc in closed)

	log.debug('Scraping sockets for hosts: {}'.format(hosts))

	@defer.inlineCallbacks
	def cleaner(host):
		'Read the remote fifo of `host` until a read returns nothing or fails.'
		iteration = 1
		while True:
			log.debug('Starting scraper for host "{}", iteration {}'.format(host, iteration))
			try:
				# getProcessOutputAndValue fires with (stdout, stderr, exit code)
				out, err, code = yield timeout(2)(utils.getProcessOutputAndValue)( 'ssh',
					['-T', '{}@{}'.format(argz.user, host), "exec cat '{0}'".format(argz.remote_keyfile)], env=os.environ )
			except error.TimeoutError: break
			if err or code or not out: break
			iteration += 1

	# consumeErrors: report per-host failures here instead of leaving unhandled
	# errors in Deferreds. One failing host must not abort shutdown.
	results = yield defer.DeferredList(map(cleaner, hosts), consumeErrors=True)
	for host, (ok, result) in zip(hosts, results):
		if not ok:
			log.warning('Failed to scrape keyfile on host {}: {}'.format(host, result.getErrorMessage()))

	log.info('Exited cleanly')
	if reactor.running: reactor.stop()


spawn(argz.hosts)
# Do not run Twisted code inside the raw signal handler, which can interrupt
# the reactor at an arbitrary point. Hand the shutdown to the reactor loop
# instead, the same way Twisted's own SIGINT handler schedules reactor.stop.
signal.signal(signal.SIGINT, lambda sig, frm: reactor.callFromThread(terminate))
reactor.run()
