# -*- coding: utf-8 -*-
from __future__ import unicode_literals, print_function

import argparse
# Command-line options, parsed right away: the rest of this script
#  is module-level code that reads its settings from optz.
parser = argparse.ArgumentParser(
	description='Detailed process memory usage accounting tool.' )
parser.add_argument( 'name', nargs='?',
	help='String to look for in process cmd/binary.' )
parser.add_argument( '-p', '--private', action='store_true',
	help='Show only private memory leaks.' )
parser.add_argument( '-s', '--swap', action='store_true',
	help='Show only swapped-out stuff.' )
parser.add_argument( '-n', '--min-val', type=int, default=0,
	help='Minimal (non-inclusive) value for tracked parameter'
		' (KiB, see --swap, --private, default: %(default)s).' )
parser.add_argument( '-f', '--flat', action='store_true',
	help='Flat output.' )
parser.add_argument( '--debug', action='store_true',
	help='Verbose operation mode.' )
optz = parser.parse_args()

import logging
# Root logger: DEBUG with --debug, INFO otherwise.
if optz.debug:
	logging.basicConfig(level=logging.DEBUG)
else:
	logging.basicConfig(level=logging.INFO)
# NOTE(review): this name gets rebound by "from math import log" further down,
#  so code after that import sees math.log here, not this logger.
log = logging.getLogger()

import itertools as it, operator as op, functools as ft
from io import open
try: map, filter, zip, range = it.imap, it.ifilter, it.izip, xrange # py2.X
except AttributeError: unicode = str # py3.X


import os, sys, re
from collections import defaultdict
try: from collections.abc import Iterable # py3.3+, alias in collections removed in 3.10
except ImportError: from collections import Iterable # py2.X
from math import log # NOTE: shadows the "log" logger defined above

def proc_path(pid, kind):
	'Return "/proc/<pid>/<kind>" path as a text string.'
	return os.path.join(*map(unicode, ['/proc', pid, kind]))

def proc_data(pid, kind):
	'''Open "/proc/<pid>/<kind>" for reading as text and return the file object.
	Caller is responsible for closing it (e.g. by using it in a "with" statement).
	Bytes that can't be decoded (e.g. odd filenames in smaps, arbitrary argv in cmdline)
		are replaced with U+FFFD instead of raising UnicodeDecodeError,
		which is not an OSError/IOError and would not be caught by the callers here.'''
	return open(proc_path(pid, kind), errors='replace')

def mem_stats( pid,
		page_size = os.sysconf(str('SC_PAGE_SIZE'))/1024, # KiB
		count_re = re.compile(r'^(?P<key>\w+):\s+(?P<val>\d+)\s+(?P<unit>\w+)$'),
		span_re = re.compile(r'^[0-9a-fA-F]+-[0-9a-fA-F]+\s') ):
	'''Return per-mapping memory counters for pid, parsed from /proc/<pid>/smaps.

	Result is {span_id: {counter: value}}, where span_id is the check_path()-normalized
		path of the mapped object (.so, binary, [stack], [heap], ...) or "[anon]"
		for mappings without one, and counter is the lowercased smaps key
		(e.g. "swap", "shared_clean"). Values are summed over all mappings with the same span_id.
	Only "<Key>: <number> kB" lines are counted, other units are warned about and skipped.

	page_size is unused (only referenced by the commented-out statm RSS calculation).
	span_re matches mapping header lines ("<start>-<end> <perms> ...").
	Raises OSError/IOError if smaps can't be opened or read.'''
	# Module-level "log" is shadowed by "from math import log", so fetch the logger directly.
	logger = logging.getLogger()
	# rss = int(proc_data(pid, 'statm').readline().split()[1]) * page_size
	span_id, stats = None, dict()
	with proc_data(pid, 'smaps') as src:
		for line in filter(None, map(op.methodcaller('strip'), src)):
			opt = count_re.search(line)
			if opt:
				if opt.group('unit') != 'kB':
					logger.warning('Unknown unit spec: {}, skipping line'.format(opt.group('unit')))
					continue
				k, v = opt.group('key').lower(), int(opt.group('val'))
				span_dict = stats.setdefault(span_id, dict())
				span_dict[k] = span_dict.get(k, 0) + v
			elif span_re.search(line):
				# b6a35000-b6a46000 r-xp 00000000 fd:01 131449   /lib/libresolv-2.12.1.so (deleted)
				span = line.split(None, 5)
				if len(span) == 6: span_id = check_path(span[-1]) # proper .so, binary, [stack] or [heap]
				elif len(span) == 5: span_id = '[anon]' # anonymous piece
				else: logger.warning('Unable to process line: {!r}'.format(line))
			else:
				# Neither a "<Key>: <n> kB" counter nor a mapping header (e.g. a flags list
				#  or a unit-less field) - skipped, instead of being split as a header,
				#  which could overwrite span_id with garbage.
				logger.debug('Skipping smaps line: %r', line)
	return stats

def check_path(path, alt_path=None, _cache=dict()):
	'''Return path with a trailing kernel " (deleted)" marker made explicit.

	Paths ending in " (deleted)" (as seen in smaps mapping lines or the exe link)
		are returned as "<path> [updated]" if a file exists at <path> again
		or at alt_path (when given), or as "<path> [deleted]" otherwise.
	Any other path is returned unchanged.

	_cache is intentionally a mutable default, acting as a memo shared across calls.
	It is keyed on (path, alt_path), as alt_path can change the result.'''
	key = path, alt_path
	if key not in _cache:
		marker, result = ' (deleted)', path
		if path.endswith(marker):
			result = path[:-len(marker)]
			suff = 'updated' if os.path.exists(result)\
				or (alt_path and os.path.exists(alt_path)) else 'deleted'
			result = '{} [{}]'.format(result, suff)
		_cache[key] = result
	return _cache[key]

def cmd_name(pid):
	'''Return (name, cmdline) for pid.

	name is whichever is less truncated of: basename of the /proc/<pid>/exe target
		(passed through check_path, so it may carry a "[deleted]"/"[updated]" suffix)
		and the process name from the first line of /proc/<pid>/status.
	cmdline is the space-joined argument list from /proc/<pid>/cmdline.

	Will raise OSError/IOError if pid is gone or its files are unreadable,
		and OSError if its cmdline is empty (presumably a kernel thread),
		so that callers can skip such pids.'''
	# Note because non-truncated name gets returned,
	#  one can have separated programs as follows:
	# 584.0 KiB +   1.0 MiB =   1.6 MiB	mozilla-thunder (exe -> bash)
	# 56.0 MiB +  22.2 MiB =  78.2 MiB	mozilla-thunderbird-bin

	with proc_data(pid, 'cmdline') as src: argv = src.read().split('\0')
	if argv[-1] == '' and len(argv) > 1: argv = argv[:-1]
	if argv == ['']:
		# Don't rely on realpath() below raising for these - it may not, depending on python version
		raise OSError('Empty cmdline for pid {} (kernel thread?)'.format(pid))
	cmdline = ' '.join(argv)
	exe = os.path.realpath(proc_path(pid, 'exe'))

	# argv[0] is the alternative location to check for a "(deleted)" exe
	#  (indexing joined cmdline string would only give its first character)
	cmd = os.path.basename(check_path(exe, argv[0]))
	with proc_data(pid, 'status') as src: name = src.readline().split(None, 1)[-1].strip()
	return (name if not cmd.startswith(name) else cmd), cmdline # choose least truncated version of exe

def pid_info(pid):
	'''Return ((name, cmdline), smaps_stats) for pid - see cmd_name() and mem_stats().
	Lets their OSError/IOError propagate if pid is gone or unreadable.'''
	name_and_cmdline = cmd_name(pid)
	return name_and_cmdline, mem_stats(pid)

def stats_aggregate(*stats, **kwz):
	'''Combine stat dicts into a single {"shared", "private", "swap"} dict.

	Each input may hold raw smaps counters ("shared_clean", "private_dirty", ...)
		or already-aggregated "shared"/"private" values; all variants are added up per input.
	Private values are always summed across inputs.
	Shared and swap values are maxed across inputs, or summed if
		sum_shared / sum_swap keywords are true.
	Raises TypeError on any other keyword argument.'''
	sum_shared, sum_swap = kwz.pop('sum_shared', False), kwz.pop('sum_swap', False)
	if kwz: raise TypeError(kwz)
	fold_shared = op.add if sum_shared else max
	# Not sure whether swap is shared (thus should be maxed) or private (summed)
	fold_swap = op.add if sum_swap else max
	shared = private = swap = 0
	for stat in stats:
		shared = fold_shared( shared,
			sum(stat.get(k, 0) for k in ('shared', 'shared_clean', 'shared_dirty')) )
		private += sum(stat.get(k, 0) for k in ('private', 'private_clean', 'private_dirty'))
		swap = fold_swap(swap, stat.get('swap', 0))
	return dict(shared=shared, private=private, swap=swap)

def mods_aggregate(*mods, **kwz):
	'''Merge several {mod: stats} mappings (dicts or iterables of (mod, stats) pairs).

	Stats for the same mod are combined with stats_aggregate() defaults
		(shared/swap maxed, private summed), returning {mod: combined_stats}.
	With sum_after=True, the per-mod results are further folded into
		a single stats dict, with shared and swap summed as well.
	Raises TypeError on any other keyword argument.'''
	sum_mods = kwz.pop('sum_after', False)
	if kwz: raise TypeError(kwz)
	merged = dict()
	for mapping in mods:
		pairs = mapping.items() if isinstance(mapping, dict) else mapping
		for mod, stats in pairs:
			merged[mod] = stats_aggregate(stats, merged.get(mod, dict()))
	if not sum_mods: return merged
	return stats_aggregate(*merged.values(), sum_shared=True, sum_swap=True)


### Collect/index data

# cmds: {cmd: {pid: cmdline}}
# modules is keyed by both mapping names and (int) pids:
#  modules[mod] = {cmd: stats combined via stats_aggregate, pid: stats}
#  modules[pid] = {mod: stats}
modules, cmds = dict(), defaultdict(dict)
for pid in map(int, filter(op.methodcaller('isdigit'), os.listdir('/proc'))):
	try:
		(cmd, cmdline), cmd_stats = cmd_name(pid), mem_stats(pid)
	except (OSError, IOError):
		continue # kernel thread, old pid, etc
	cmds[cmd][pid] = cmdline

	for mod, stats in cmd_stats.items():
		stats = stats_aggregate(stats)
		modules.setdefault(mod, dict())
		modules[mod][cmd] = stats_aggregate(stats, modules[mod].get(cmd, dict()))
		modules[mod][pid] = stats
		modules.setdefault(pid, dict())[mod] = stats


### Build tree

cmd_tree = dict()
# Drops {mod: stats} entries where stats[val] <= --min-val,
#  "-"-prefixed keys ("-stats", "-cmdline") are metadata and always kept.
val_filter = lambda stats, val: dict(
	(k,v) for k,v in stats.items()
	if k.startswith('-') or v.get(val, 0) > optz.min_val )

# cmd_tree: {cmd: {pid: {mod: stats}}}, where shared non-"[...]" mods
#  get "-shared-with" set of command names also mapping them.
for cmd, pids in cmds.items():
	if optz.name and optz.name not in cmd: continue
	if cmd not in cmd_tree: cmd_tree[cmd] = dict()
	for pid,cmdline in pids.items():
		# pid is missing from modules if no counters were parsed from its smaps
		mods = cmd_tree[cmd][pid] = modules.get(pid, dict()).copy()
		for mod,vals in list(mods.items()):
			vals = mods[mod] = vals.copy()
			if vals['shared']:
				if mod[0] == '[':
					# vals['-shared-with'] = ['forks']
					continue
				vals['-shared-with'] = set(filter(
					lambda x: isinstance(x, unicode), modules[mod] ))
				# Own cmd is only omitted for single-pid commands
				if len(cmds[cmd]) == 1: vals['-shared-with'].discard(cmd)

for cmd,pids in list(cmd_tree.items()):
	# Must be computed before "-stats"/"-cmdline" get added to per-pid dicts below,
	#  as mods_aggregate expects only {mod: stats} entries there.
	cmd_stats = mods_aggregate(*pids.values(), sum_after=True)

	for pid,mods in list(pids.items()):
		stats = stats_aggregate( *filter(
				lambda x: isinstance(x, dict), mods.values() ),
			sum_shared=True, sum_swap=True )
		if len(cmd_tree[cmd]) > 1: mods['-stats'] = stats
		# Set on mods before it gets replaced by filtered/re-keyed copies below,
		#  so that it actually ends up in the output tree.
		mods['-cmdline'] = cmds[cmd][pid]
		for filter_val in ['swap', 'private']:
			if getattr(optz, filter_val):
				mods = val_filter(mods, filter_val)
				if stats[filter_val] <= optz.min_val:
					del cmd_tree[cmd][pid]
					break
		else:
			cmd_tree[cmd][pid] = dict( (unicode(k), v)
				for k,v in mods.items() )

	cmd_tree[cmd]['-stats'] = cmd_stats
	for filter_val in ['swap', 'private']:
		# Filter on aggregate stats of this command, not the last pid's stats
		if getattr(optz, filter_val) and cmd_stats[filter_val] <= optz.min_val:
			del cmd_tree[cmd]
			break
	else:
		cmd_tree[cmd] = dict( (unicode(k), v)
				for k,v in cmd_tree[cmd].items() )


### Dump

byteunits = 'B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB'

def bytesformat(val, offset=0):
	'''Format size val as a human-readable string with 1024-based units.

	val is measured in byteunits[offset] units (e.g. offset=1 for KiB values).
	Falsy val (0, None) is returned as-is.
	Values too large for the biggest unit are shown in that unit
		(e.g. "2048.0 PiB") instead of failing with IndexError.'''
	if not val: return val
	val, exp = float(val), 0
	max_exp = len(byteunits) - 1 - offset
	# Step-wise division (exact for powers of two) instead of int(math.log(val, 1024)),
	#  as the latter is subject to float rounding right at unit boundaries.
	while val >= 1024 and exp < max_exp:
		val /= 1024
		exp += 1
	return '{:.1f} {}'.format(val, byteunits[exp+offset])

def results_format(data, path=None, vspacing=None, flat=False):
	'''Render nested dict data as indented text, or as "a.b.key: value" lines if flat.

	Keys starting with "-" (e.g. "-stats", "-cmdline") sort before all others.
	Numeric leaves are rendered via bytesformat(v, 1) (i.e. treated as KiB),
		non-string iterables (e.g. "-shared-with" sets) as sorted ", "-joined strings.

	path: list of already-rendered parent keys (sets indent depth / flat prefix).
	vspacing: blank-line separators after dict items, per nesting level - either
		an int (that many levels, one newline each) or a list of strings or newline counts.
		Defaults to [2,1], or to 0 (none) when flat.'''
	if path is None: path = list()
	if vspacing is None: vspacing = [2,1] if not flat else 0
	if isinstance(vspacing, int): vspacing = ['\n']*vspacing
	chunks, vspace, level = list(), None, len(path)
	for k,v in sorted( data.items(),
			key=lambda x: '\0' if x[0].startswith('-') else x[0] ):
		if vspace: chunks.append(vspace)
		k_line = '{}{}:'.format('  '*level if not flat else '.'.join(path)+'.', k)
		vspace = None
		if isinstance(v, dict):
			if not flat: chunks.append(k_line)
			chunks.append('{}{}'.format('\n'*int(not flat), results_format( v,
				path=path + [k if '.' not in k and ' ' not in k else "'{}'".format(k)],
				vspacing=vspacing, flat=flat )))
			if vspacing and len(vspacing) > level:
				vspace = vspacing[level]
				vspace = vspace if not isinstance(vspace, int) else '\n'*vspace
		else:
			chunks.append(k_line)
			if isinstance(v, (int, float)): v = bytesformat(v, 1)
			elif not isinstance(v, (str, unicode))\
					and isinstance(v, Iterable):
				v = ', '.join(sorted(v)) # sorted for stable output from sets
			chunks.append(' {}\n'.format(v))
	return ''.join(chunks)

try:
	print(results_format(cmd_tree, flat=optz.flat).strip())
except IOError:
	pass # writing to stdout failed (e.g. pipe closed early), nothing more to report
