#!/usr/bin/env python
# -*- coding: utf-8 -*-

from optparse import OptionParser
# Command-line switches. Every option is a plain boolean flag, so they are
# registered from one table of (short flag, long flag, dest, help text).
parser = OptionParser(usage='%prog [options]')
for short_flag, long_flag, dest_name, help_text in (
		('-n', '--no-revert', 'no_revert',
			'do not schedule tables revert (in case of ssh lock),'
			' not recommended, since you never know when firewall may lock itself up'),
		('-s', '--summary', 'summary',
			'show diff between old and new tables afterwards'),
		('-d', '--dump', 'dump',
			'no changes, just dump resulting tables to stdout') ):
	parser.add_option(short_flag, long_flag,
		action='store_true', dest=dest_name, help=help_text)
optz, argz = parser.parse_args()


from string import whitespace as spaces
from subprocess import Popen, PIPE
import os, sys, yaml, re

import logging as log
# Root logger shows INFO and above ('INFO' is resolved to log.INFO)
log.basicConfig(level='INFO')


# Lowercase names treated as built-in chains: these keep their configured
# policy and get upper-cased names, every other chain is declared with '-'.
builtins = set([ 'input', 'forward', 'output',
	'prerouting', 'mangle', 'postrouting' ])

# Options that need an explicit match module: option regex -> module.
# Turned into (compiled regex, replacement) pairs; the replacement inserts the
# module right before the (optionally '!'-negated) option, e.g.
# ' --state NEW' -> ' -m state --state NEW'.
extents = {
	'--mac-source':  '-m mac',
	'--state':  '-m state',
	'--src-range':  '-m iprange',
	'--dst-range':  '-m iprange',
	r'--dport (\S+,)+\S+':  '-m multiport',
	'--uid-owner':  '-m owner' }
extents = list( (re.compile(r'(?<=\s)((! )?'+k+')'), '%s \\1'%v)
	for k,v in extents.items() )

# Protocol extension: 'tcp/udp' after '-p', or 'a/b' number lists after
# '...port ', make a rule get cloned once per alternative.
pex = re.compile(r'(?<=-p\s)((\w+/)+\w+)'),\
	re.compile(r'(?<=port\s)((\d+/)+\d+)') # protocol extension
# IP version mark: standalone '-v4' / '-v6' token put into a rule by hand;
# groups are (token with leading whitespace, 'v4'/'v6').
vmark = re.compile(r'(\s*-(v[46]))(?=\s|$)') # IP version mark

# Config file: this script's path without extension, plus '.yaml'.
with open(os.path.realpath(os.path.splitext(__file__)[0])+'.yaml') as cfg_file:
	cfgs = cfg_file.read()
# Join backslash-continued lines (with surrounding indentation) into one
# space, so long rules can be wrapped. The joined text is kept in cfgs, as
# svc handling re-scans it later.
cfgs = re.sub(re.compile(r' *\\\n\s*', re.M), ' ', cfgs)
# safe_load: plain YAML data only -- yaml.load without an explicit Loader can
# build arbitrary Python objects and is refused outright by PyYAML >= 6.
cfg = yaml.safe_load(cfgs)


class Tables:
	'''Collects generated iptables-restore lines for the IPv4 and IPv6 tables.

	append() routes each line to one or both tables. Comment lines ('#...')
	are held back per table and written only right before the next
	non-comment line landing in that same table: a newer comment replaces an
	unflushed older one, and comments with no following line are dropped.
	'''

	v4_ext = v6_ext = None # pending comment per table (to skip repeating comments)
	v4_mark = re.compile(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}')
	v6_mark = re.compile(r'[a-f0-9]{0,4}::([a-f0-9]{1,4}|/)') # far from perfect, but should do
	mark = None # one-shot 'v4'/'v6' override for the next hand-marked rule

	def __init__(self):
		# Per-instance line buffers; as class attributes they would be shared
		# by every Tables object.
		self.v4, self.v6 = list(), list()

	def append(self, rules, v=None):
		'''Add a line (or newline-joined lines) to the v4 and/or v6 table.

		v -- 'v4' or 'v6' forces the table. Otherwise a pending hand-made
		  mark is used (consumed by the first non-comment line), or the
		  version is guessed: IPv4-looking addresses only -> v4, IPv6-looking
		  only -> v6, neither -> both tables.
		'''
		is_comment = rules.startswith('#')
		if v:
			# Explicit version wins; drop any pending mark so it can't leak onto
			# an unrelated later line (like the next table's header).
			self.mark = None
		elif self.mark and not is_comment:
			# Rule was hand-marked with proto version. Comment headers flushed
			# right before it must not eat the mark meant for the rule itself.
			v, self.mark = self.mark, None
		else: # auto-determine if it's valid for each table
			if not self.v6_mark.search(rules): v = 'v4'
			if not self.v4_mark.search(rules):
				v = None if v else 'v6' # empty value means both tables

		for ver in ((v,) if v else ('v4', 'v6')):
			ext_attr = ver + '_ext'
			if is_comment:
				setattr(self, ext_attr, rules)
				continue
			table = getattr(self, ver)
			ext = getattr(self, ext_attr)
			if ext: # flush the comment this line belongs under
				table.append(ext)
				setattr(self, ext_attr, None)
			table.append(rules)

	def fetch(self, v=None):
		'''Return the newline-joined text of table v ('v4'/'v6'),
		or a (v4_text, v6_text) pair when v is not given.'''
		join = '\n'.join
		if v: return join(getattr(self, v))
		return join(self.v4), join(self.v6)

core = Tables() # accumulator for all generated tables


def chainspec(chain):
	'''Split a chain key into its clean name, policy and interface prefix.

	Keys look like 'input', 'input-lan/-', 'input/6' or 'input/+'.
	The part after '/' is the policy spec (cfg['policy'] when absent):
	  '-' or an empty/false value -> 'DROP';
	  '4' -> ('ACCEPT', 'DROP') and '6' -> ('DROP', 'ACCEPT'), i.e.
	    (v4 policy, v6 policy);
	  'accept' / 'drop' in any case -> that target;
	  anything else -> 'ACCEPT'.
	An 'input-<iface>' / 'output-<iface>' name (any case) yields the
	('-i', iface) / ('-o', iface) prefix for all rules of the chain (useful
	in svc rules); other names give an empty prefix.

	Returns (chain, policy, pre).
	Raises ValueError for a numeric policy other than 4 or 6.
	'''
	if '/' in chain: chain, policy = chain.split('/', 1)
	else: policy = cfg['policy']

	if not policy or policy == '-': policy = 'DROP'
	else:
		# cfg['policy'] comes from YAML and may be parsed as an int/bool
		policy = str(policy)
		if policy.isdigit():
			if policy == '4': policy = ('ACCEPT', 'DROP')
			elif policy == '6': policy = ('DROP', 'ACCEPT')
			else: raise ValueError('Incorrect policy specification')
		elif policy.upper() in ('ACCEPT', 'DROP'):
			# explicit target name is honored, so 'DROP' can't fail open
			policy = policy.upper()
		else: policy = 'ACCEPT'

	pre = ()
	if '-' in chain: # like 'input-lan', for chain-global interface specification
		lowered = chain.lower() # case-insensitive, like the builtins checks
		if lowered.startswith('input'): flag = '-i'
		elif lowered.startswith('output'): flag = '-o'
		else: flag = None
		if flag:
			chain, iface = chain.split('-', 1)
			pre = (flag, iface)

	return chain, policy, pre


for table,chainz in cfg['tablez'].items():
	# nat goes to the v4 table only; other tables are routed per line by core.append
	if table != 'nat': add = core.append
	else: add = lambda x: core.append(x, 'v4')
	add('*'+table) # table header (like '*filter')

	try: svc = chainz.pop('svc')
	except KeyError: svc = {}

	# Form chainspec / initial rules, giving chains 'clean' names, separated from chainspec.
	# Several keys may resolve to one chain (like 'input-lan' and 'input-wan'
	# -> 'input'); their rule sets are merged rather than silently replacing
	# each other. Keys with an explicit '/policy' are handled first, so theirs
	# is the policy kept; sorting also makes the merge order stable.
	specs = dict()
	for key in sorted(chainz, key=lambda k: ('/' not in k, k)):
		chain, policy, pre = chainspec(key)
		rulez = chainz[key]
		rulez = [rulez] if isinstance(rulez, str) else rulez
		if chain in specs: specs[chain][1].append((pre, rulez))
		else: specs[chain] = policy, [(pre, rulez)]
	chainz = specs

	# Extend chains w/ svc rules, if any. The raw config text of this table is
	# scanned so svc sections are used in the order they appear there
	# (presumably because the loaded dict loses key order).
	if svc:
		cfgt = re.findall( r'\n(\s+)'+re.escape(table)+r':(.+?)\n((\1)\S+:.*|$)',
			cfgs, re.S )[0][1]
		ih = {} # indentation -> svc names found at it
		for name in svc:
			for i in re.findall(r'^(\s+)'+re.escape(name)+':', cfgt, re.M):
				i = i.lstrip('\n')
				names = ih.setdefault(i, [])
				if name not in names: names.append(name)
		# svc entries are taken at the indentation shared by most svc names
		indent, ih = sorted(ih.items(), key=lambda x: len(x[1]), reverse=True)[0]
		for name in re.findall('^'+indent+r'(\S+):', cfgt, re.M):
			if name not in ih: continue
			try: spec = svc[name].items() # full specification (dict w/ chain and rules list)
			except AttributeError: spec = [('input', svc[name])] # it's just a list of rules, defaults to input chain
			for chain,rulez in spec:
				chain, policy, pre = chainspec(chain) # policy here is silently ignored
				rulez = [rulez] if isinstance(rulez, str) else rulez
				chainz[chain][1].append((None, name)) # section comment marker
				chainz[chain][1].append((pre, rulez))

	# Form actual tables; user chains sort before builtins (presumably so they
	# are declared before builtin rules jump to them)
	chainz = sorted(chainz.items(), key=lambda x: x[0].lower() in builtins)
	for name,chain in chainz:
		policy,ruleset = chain
		if name.lower() in builtins: name = name.upper()
		else: policy = '-'

		# Policy header (like ':INPUT ACCEPT [0:0]')
		policy_gen = lambda policy: '\n:%s %s '%(name, policy.upper()) + '[0:0]\n'
		try:
			v4,v6 = policy # per-version policies, like ('ACCEPT', 'DROP')
			core.append(policy_gen(v4), 'v4')
			core.append(policy_gen(v6), 'v6')
		except (TypeError, ValueError): add(policy_gen(policy))

		header = None
		for base,rulez in ruleset:
			if not rulez: continue
			if base is None: # svc section comment: rulez is the svc name, stored till first valid rule
				header = '# '+rulez
				continue

			for rule in rulez: # rule mangling
				# Rule base: interface prefix, plus state extension for stateful INPUT port rules
				if cfg['stateful'] and rule and '--state'\
						not in rule and  name == 'INPUT' and '--dport' in rule:
					pre = base + ('--state', 'NEW')
				else: pre = base

				try: # check rule for magical, inserted by hand, proto marks
					marked, core.mark = vmark.findall(rule)[0]
				except (IndexError, TypeError): pass
				else: rule = rule.replace(marked, '') # Strip magic

				# Final rules (like '-A INPUT -j DROP')
				if not rule: rule = ('-j', 'DROP')
				elif len(rule) == 1:
					if rule == 'x': rule = ('-j', 'REJECT')
					elif rule == '<': rule = ('-j', 'RETURN')
					else: rule = ('-j', 'ACCEPT')
				# Rule actions
				elif rule.endswith(' x'): rule = (rule[:-2], '-j', 'REJECT')
				elif rule.endswith(' -'): rule = (rule[:-2], '-j', 'DROP')
				elif rule.endswith(' <'): rule = (rule[:-2], '-j', 'RETURN')
				elif rule.endswith(' |'): rule = (rule[:-2],)
				elif '-j ' not in rule: rule = (rule, '-j', 'ACCEPT')
				# Full rule, no action mangling is necessary
				else: rule = (rule,)

				rule = ' '.join(('-A', name) + pre + rule) # rule composition
				for pattern, repl in extents: # rule extension (for example, adds '-m ...', where necessary)
					# repl is like '-m mac \1': skip when that match module is already
					# in the rule, as a second, option-less '-m <module>' is at best redundant
					if repl.rsplit(' ', 1)[0] in rule: continue
					rule = pattern.sub(repl, rule)

				# Protocol extension (clone rule for each proto / port alternative)
				rules = [rule]
				for ex in pex:
					try:
						rules = list( ex.sub(alt, r) for r in rules
							for alt in ex.search(r).groups()[0].split('/') )
					except AttributeError: pass # no matches
				rule = '\n'.join(rules)

				if header: # flush header, since section isn't empty
					add(header)
					header = None

				add(rule) # ta da!

	add('\nCOMMIT\n\n') # table end marker


# Ignore SIGHUP (in case of SSH break), so losing the controlling terminal
# doesn't kill the script midway through pushing tables / scheduling reverts.
# NOTE(review): an ignored disposition is normally inherited across exec, so
# the Popen'd push/at commands presumably ignore SIGHUP too -- confirm if relied on.
import signal
signal.signal(signal.SIGHUP, signal.SIG_IGN) # TODO: add instant-restore as a sighup handler?


def pull_table(v):
	'''Dump the live table for IP version v ('v4' or 'v6').

	Runs the cfg['fs']['bin'][v+'_pull'] command and returns a pair: the raw
	dump, and its "essence" for change detection -- lines stripped of
	surrounding whitespace, comment lines dropped, and packet/byte counters
	cut off chain declarations (':CHAIN POLICY [pkts:bytes]'). The chain and
	its policy are kept, so a policy-only change still counts as a change.
	'''
	proc = Popen( cfg['fs']['bin'][v+'_pull'],
		stdout=PIPE, stderr=sys.stderr, universal_newlines=True )
	table = proc.communicate()[0] # reads all output and reaps the child
	stripped = list()
	for line in table.splitlines():
		line = line.strip(spaces)
		if line.startswith('#'): continue
		if line.startswith(':'): # keep ':CHAIN POLICY', drop '[pkts:bytes]'
			line = line.rsplit('[', 1)[0].rstrip(spaces)
		stripped.append(line)
	return table, '\n'.join(stripped)


for v in ('v4', 'v6'):
	if optz.dump: # no changes, just show the generated table
		log.info('%s table:'%v)
		sys.stdout.write(core.fetch(v)+'\n\n')
		continue

	# Pull the old table, to check if it's similar to new one (no backup needed in that case)
	old_table, old_essence = pull_table(v)

	# Push new table; communicate() writes it, closes stdin and waits for the
	# process to digest and apply it
	iptables = Popen( cfg['fs']['bin'][v+'_push'],
		stdin=PIPE, stdout=sys.stdout, stderr=sys.stderr, universal_newlines=True )
	iptables.communicate(core.fetch(v))
	if iptables.returncode:
		log.error('%s table push failed (exit code %s)', v, iptables.returncode)

	# Pull new table in iptables-save format, to compare against old one
	new_table, new_essence = pull_table(v)
	if old_essence == new_essence: continue # nothing changed: no backup, no revert

	# Backup old table in backup.0 slot, rotating the rest of them (oldest
	# slot is overwritten). Slots are walked by number, not by sorted path
	# string, which would order 'x.10' before 'x.9'. At least one slot is
	# always used, since the revert below restores from it.
	bakz = cfg['fs']['bakz']
	slots = list(bakz[v]%n for n in range(max(int(bakz['keep']), 1)))
	for n in range(len(slots) - 1, 0, -1):
		if os.path.exists(slots[n-1]): os.rename(slots[n-1], slots[n])
	backup = slots[0]
	with open(backup, 'w') as dst: dst.write(old_table)

	# Generate diff, if requested
	if optz.summary:
		import tempfile
		log.info('%s table:'%v)
		# mkstemp creates private files with unpredictable names; fixed names in
		# shared /tmp could be pre-planted (e.g. as symlinks) by another local
		# user, and this script presumably runs as root -- TODO confirm
		tmp_paths = list()
		try:
			for tag, essence in (('old', old_essence), ('new', new_essence)):
				fd, path = tempfile.mkstemp(prefix='trilobite_', suffix='_'+tag)
				tmp_paths.append(path)
				with os.fdopen(fd, 'w') as tmp: tmp.write(essence+'\n')
			Popen( cfg['fs']['bin']['diff'].split(' ') + tmp_paths,
				stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr ).wait()
		finally:
			for path in tmp_paths: os.unlink(path)
		sys.stdout.write('\n\n')

	# Schedule table revert if no commit action will be issued (to ensure that tables are in the sane state)
	if not optz.no_revert:
		at = Popen(
			[cfg['fs']['bin']['at'], 'now', '+', str(bakz['delay']), 'minutes'],
			stdin=PIPE, stdout=sys.stdout, stderr=sys.stderr, universal_newlines=True )
		at.communicate('%s < %s\n'%(cfg['fs']['bin'][v+'_push'], backup)) # restore from latest backup
