#!/usr/bin/python2
'''
(c) Copyright 2012-2019 Xilinx, Inc. All rights reserved.

This file contains confidential and proprietary information
of Xilinx, Inc. and is protected under U.S. and
international copyright and other intellectual property
laws.

DISCLAIMER
This disclaimer is not a license and does not grant any
rights to the materials distributed herewith. Except as
otherwise provided in a valid license issued to you by
Xilinx, and to the maximum extent permitted by applicable
law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
(2) Xilinx shall not be liable (whether in contract or tort,
including negligence, or under any other theory of
liability) for any loss or damage of any kind or nature
related to, arising under or in connection with these
materials, including for any direct, or any indirect,
special, incidental, or consequential loss or damage
(including loss of data, profits, goodwill, or any type of
loss or damage suffered as a result of any action brought
by a third party) even if such damage or loss was
reasonably foreseeable or Xilinx had been advised of the
possibility of the same.

CRITICAL APPLICATIONS
Xilinx products are not designed or intended to be fail-
safe, or for use in any application requiring fail-safe
performance, such as life-support or safety devices or
systems, Class III medical devices, nuclear facilities,
applications related to the deployment of airbags, or any
other applications that could lead to death, personal
injury, or severe property or environmental damage
(individually and collectively, "Critical
Applications"). Customer assumes the sole risk and
liability of any use of Xilinx products in Critical
Applications, subject only to applicable laws and
regulations governing limitations on product liability.

THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
PART OF THIS FILE AT ALL TIMES
'''

import sys, os, pwd, errno, copy, time, re, operator, optparse, datetime

# If a src/python/solar_capture directory exists one level above the
# directory containing this script, put that src/python directory at the
# front of sys.path so that copy of the package is the one imported below.
top = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), '..'))
if os.path.exists(os.path.join(top, 'src', 'python', 'solar_capture')):
    sys.path.insert(0, os.path.join(top, 'src', 'python'))
import solar_capture as sc
import solar_capture.stats as stats
import solar_capture.tabulate as tab


# Usage/help text handed to optparse.OptionParser(usage=...) in main().
usage_text = '''
  solar_capture_monitor [options] [sessions] [commands]

Examples:
  solar_capture_monitor                - List running sessions
  solar_capture_monitor dump           - Dump running sessions
  solar_capture_monitor 24351 dump     - Dump session(s) for given pid

Commands:
  dump                 - Dump complete state of session
  list                 - List pid and user-id of instance
  nodes                - Dump table of nodes with packet counts
  nodes_rate           - Continuously updated table of nodes with packet rates
  line_rate            - Line-by-line output with packet rate and bandwidth
  line_total           - Line-by-line output with packet and byte counts
  poke obj.attr=val    - Overwrite an object attribute
  dot [options]        - Output topology graph in graphviz format

Sessions:
  pid                  - All sessions for the given process
  pid/session_id       - Specific session from the given process
  directory            - Log directory for a session

  If no sessions are specified, then all running sessions belonging to the
  user are selected.\
'''


# Module-wide state filled in by main(): the parsed command-line options
# and the optparse parser (used by usage_err() to report usage errors).
options = None
opt_parser = None


def usage_err(msg):
    """Report a command-line usage error through the option parser."""
    parser = opt_parser
    parser.error(msg)


def out(msg):
    """Write msg to standard output exactly as given (no newline added)."""
    stream = sys.stdout
    stream.write(msg)

def err(msg):
    """Write msg to standard error exactly as given (no newline added)."""
    stream = sys.stderr
    stream.write(msg)

def fail(rc, msg):
    """Write msg to stderr via err(), then exit the process with status rc."""
    err(msg)
    raise SystemExit(rc)


def get_user(uid):
    """Return the login name for uid, or str(uid) if it cannot be resolved.

    uid may be an int or anything int() accepts.  The fallback covers a
    value that is not a number and a uid with no password-database entry.
    """
    try:
        return pwd.getpwuid(int(uid)).pw_name
    except (KeyError, ValueError, TypeError, OverflowError):
        # KeyError: no such uid; ValueError/TypeError: uid is not numeric;
        # OverflowError: uid is outside the platform's uid range.  Anything
        # else (e.g. KeyboardInterrupt) is deliberately left to propagate.
        return str(uid)


def get_uid(user):
    """Translate a user name into a numeric uid.

    The special names 'any' and 'all' return None.  An unknown user name
    is reported on stderr and the process exits with status 3 (via fail()).
    """
    if user in ('any', 'all'):
        return None
    try:
        return pwd.getpwnam(user).pw_uid
    except (KeyError, TypeError, ValueError):
        # Only lookup failures are turned into the "unknown user" error;
        # interrupts and other unexpected errors propagate unchanged.
        fail(3, "ERROR: Unknown user '%s'\n" % user)

######################################################################

def fmt_n_pkts(n_pkts):
    """Render a packet count as text, right-aligned in a 12-character field."""
    return '%12s' % (n_pkts,)


def fmt_n_bytes(n_bytes):
    """Render a byte count as text, right-aligned in a 15-character field."""
    return '%15s' % (n_bytes,)


def fmt_pkts_per_sec(pps):
    """Format a packets-per-second value in a field at least 7 wide.

    Values below one million are printed as-is; larger values are printed
    as '<millions>.<thousands, 3 digits>e6' (truncated, not rounded).
    """
    million = 1000000
    if pps < million:
        text = str(pps)
    else:
        whole, remainder = divmod(pps, million)
        text = '%d.%.3de%d' % (whole, remainder // 1000, 6)
    return text.rjust(7)


def fmt_bw_mbps(bytes_per_sec):
    """Format a byte rate as megabits per second in a field at least 7 wide.

    Fewer decimal places are shown as the value grows; zero and negative
    rates are shown as '0'.
    """
    mbits = bytes_per_sec * 8.0 / 1e6
    # (lower bound, format) pairs, checked from the largest bound down.
    bands = (
        (10000.0, '%d'),
        (1000.0, '%.2f'),
        (100.0, '%.3f'),
        (10.0, '%.4f'),
    )
    for lower_bound, fmt in bands:
        if mbits >= lower_bound:
            return (fmt % mbits).rjust(7)
    if mbits > 0.0:
        return ('%.5f' % mbits).rjust(7)
    return '0'.rjust(7)


def fmt_usec(sec):
    """Convert a duration in seconds to whole microseconds (truncated)."""
    usec = sec * 1e6
    return int(usec)


def fmt_secs_as_date(secs):
    """Format a seconds-since-epoch timestamp according to the options.

    options.localtime selects local time instead of UTC.  If
    options.strftime is set it is used as the format; otherwise the result
    is 'YYYYMMDD-HH:MM:SS.mmm' with a millisecond suffix.
    """
    to_struct_time = time.localtime if options.localtime else time.gmtime
    st = to_struct_time(secs)
    if options.strftime:
        return time.strftime(options.strftime, st)
    millis = (secs - int(secs)) * 1000
    return "%s.%03d" % (time.strftime("%Y%m%d-%H:%M:%S", st), millis)


class GetterBase(object):
    """Common base for getters.

    A sub-class is expected to be callable (returning its current value)
    and to provide get_updater() and get_label() methods.
    """


class StaticGetter(GetterBase):
    """Getter that always yields the fixed value it was constructed with."""

    def __init__(self, val, label=''):
        self.label = label
        self.val = val

    def get_label(self):
        return self.label

    def get_updater(self):
        # A constant has nothing to refresh.
        return None

    def __call__(self):
        return self.val


class ValGetter(GetterBase):
    """Getter that reads attribute `field` of `obj` each time it is called.

    The label defaults to the field name.  `context` is accepted but not
    used by this class.
    """

    def __init__(self, obj, field, context=None, label=None):
        self.obj = obj
        self.field = field
        self.label = field if label is None else label

    def get_label(self):
        return self.label

    def get_updater(self):
        # Only objects exposing a callable update_fields can be refreshed.
        updater = getattr(self.obj, 'update_fields', None)
        return updater if callable(updater) else None

    def __call__(self):
        return getattr(self.obj, self.field)


class RateGetter(GetterBase):
    """Getter that turns a counter getter into a per-second rate.

    Each call returns (current - previous) / context['time_delta'], cast
    back to the type of the wrapped value, and remembers the current value
    for the next call.  The first call therefore yields a zero rate.

    `context` is required: it must be a mapping whose 'time_delta' entry is
    kept up to date by whoever drives the getter.
    """

    def __init__(self, getter, context=None, label=None):
        assert context is not None
        self.ctx = context
        self.getter = getter
        if label is None:
            self.label = getter.get_label()
        else:
            self.label = label

    def __call__(self):
        v_now = self.getter()
        if not hasattr(self, 'v_prev'):
            # No earlier sample yet: measure against the current value.
            self.v_prev = v_now
        time_delta = self.ctx['time_delta']
        if time_delta:
            rate = (v_now - self.v_prev) / time_delta
        else:
            # Two samples with the same timestamp: report no rate rather
            # than raising ZeroDivisionError.
            rate = 0
        rate = type(v_now)(rate)  # preserve type
        self.v_prev = v_now
        return rate

    def get_updater(self):
        return self.getter.get_updater()

    def get_label(self):
        # Honour an explicit label given to the constructor; previously it
        # was stored but the wrapped getter's label was always returned.
        return self.label


class FormattedGetter(GetterBase):
    """Getter that passes another getter's value through `formatter`.

    The label defaults to the wrapped getter's label.  `context` is
    accepted but not used by this class.
    """

    def __init__(self, getter, formatter, context=None, label=None):
        self.getter = getter
        self.formatter = formatter
        self.label = getter.get_label() if label is None else label

    def get_label(self):
        return self.label

    def get_updater(self):
        # Refreshing is delegated to the wrapped getter.
        return self.getter.get_updater()

    def __call__(self):
        raw_value = self.getter()
        return self.formatter(raw_value)


class TimeGetter(GetterBase):
    """Getter returning the 'time_now' entry of the shared context."""

    def __init__(self, context=None, label='time'):
        self.label = label
        self.ctx = context

    def get_label(self):
        return self.label

    def get_updater(self):
        # The context is maintained by the caller; nothing to refresh here.
        return None

    def __call__(self):
        return self.ctx['time_now']


class ElapsedGetter(GetterBase):
    """Getter returning context['time_now'] - context['time_start']."""

    def __init__(self, context=None, label='elapsed'):
        assert context is not None
        self.label = label
        self.ctx = context

    def get_label(self):
        return self.label

    def get_updater(self):
        # The context is maintained by the caller; nothing to refresh here.
        return None

    def __call__(self):
        ctx = self.ctx
        return ctx['time_now'] - ctx['time_start']


class IntervalGetter(GetterBase):
    """Getter returning the 'time_delta' entry of the shared context."""

    def __init__(self, context=None, label='interval'):
        assert context is not None
        self.label = label
        self.ctx = context

    def get_label(self):
        return self.label

    def get_updater(self):
        # The context is maintained by the caller; nothing to refresh here.
        return None

    def __call__(self):
        return self.ctx['time_delta']


# TODO: is there any advantage to having a mk_static_getter() wrapper?


def mk_static_getter(val, label=''):
    """Wrap a constant value in a StaticGetter carrying the given label."""
    getter = StaticGetter(val, label=label)
    return getter


def mk_objid(obj, label='id'):
    """Return a StaticGetter holding obj.obj_id as read at creation time."""
    object_id = obj.obj_id
    return StaticGetter(object_id, label=label)


def mk_getter(obj, field, label=None):
    """Return a ValGetter that reads `field` from `obj` on every call."""
    getter = ValGetter(obj, field, label=label)
    return getter


def mk_rate_getter(obj, field, ctx):
    """Return a getter giving the per-second rate of change of obj.field."""
    raw = ValGetter(obj, field)
    return RateGetter(raw, context=ctx)


def mk_n_pkts_getter(obj, field):
    """Return a getter that formats obj.field with fmt_n_pkts()."""
    raw = ValGetter(obj, field)
    return FormattedGetter(raw, fmt_n_pkts)


def mk_n_bytes_getter(obj, field):
    """Return a getter that formats obj.field with fmt_n_bytes()."""
    raw = ValGetter(obj, field)
    return FormattedGetter(raw, fmt_n_bytes)


def mk_pps_getter(ctx, obj, field):
    """Return a getter giving the rate of obj.field via fmt_pkts_per_sec()."""
    raw = ValGetter(obj, field)
    rate = RateGetter(raw, context=ctx)
    return FormattedGetter(rate, fmt_pkts_per_sec)


def mk_bw_mbps_getter(ctx, obj, field):
    """Return a getter giving the rate of obj.field via fmt_bw_mbps()."""
    raw = ValGetter(obj, field)
    rate = RateGetter(raw, context=ctx)
    return FormattedGetter(rate, fmt_bw_mbps)


def mk_latency_getter(obj, field):
    """Return a getter that formats obj.field with fmt_usec()."""
    raw = ValGetter(obj, field)
    return FormattedGetter(raw, fmt_usec)


def mk_date_getter(ctx):
    """Return a getter, labelled 'date', giving the context time as a date."""
    now = TimeGetter(ctx)
    return FormattedGetter(now, fmt_secs_as_date, label='date')


def mk_elapsed_getter(ctx, unit='ms'):
    """Return a getter for the elapsed time in the requested unit.

    unit is one of 's', 'ms', 'us' or 'ns'.  For 's' the raw elapsed value
    is returned unformatted; the other units are scaled and truncated to
    an int.  Any other unit raises AssertionError.
    """
    if unit == 's':
        return ElapsedGetter(ctx)
    for unit_name, factor in (('ms', 1e3), ('us', 1e6), ('ns', 1e9)):
        if unit == unit_name:
            # Bind the factor now so the lambda does not see later values.
            scale = lambda x, factor=factor: int(x * factor)
            return FormattedGetter(ElapsedGetter(ctx), scale)
    raise AssertionError("mk_elapsed_getter: bad unit '%s'" % unit)


def mk_grid_map(objects, fields, mk_getter=mk_getter):
    """Build a grid of getters: one row per object, one column per field.

    mk_getter(obj, field) is called for every cell.
    """
    grid = []
    for obj in objects:
        row = []
        for field in fields:
            row.append(mk_getter(obj, field))
        grid.append(row)
    return grid


def grid_snapshot(o):
    """Evaluate a (possibly nested) structure of getters into strings.

    A callable is called and its result converted with str(); an iterable
    is converted element by element into a list; anything else is
    converted with str().
    """
    if callable(o):
        return str(o())
    if isinstance(o, str):
        # Strings are leaf values.  On Python 3 they have __iter__, and
        # iterating one yields more strings, so without this check the
        # recursion below would never terminate.
        return str(o)
    if hasattr(o, '__iter__'):
        return [grid_snapshot(i) for i in o]
    return str(o)


def grid_rate(rows1, rows2, t_diff=1.0, cols=None):
    """Return a copy of rows1 with selected columns replaced by rates.

    For each row and each column index in `cols`, the cell becomes
    (rows1 - rows2) / t_diff, cast back to the cell's type when both
    inputs have the same type.  When `cols` is None, every column index of
    the first row is used (for all rows).  The inputs are not modified.
    """
    d = copy.deepcopy(rows1)
    if cols is None and rows1:
        # Column set is fixed by the first row, as before; compute it once.
        cols = range(len(rows1[0]))
    for ri, row1 in enumerate(rows1):
        for ci in cols:
            v1 = row1[ci]
            v2 = rows2[ri][ci]
            tmp = (v1 - v2) / t_diff
            if type(v1) is type(v2):
                tmp = type(v1)(tmp)
            d[ri][ci] = tmp
    return d


def __find_updaters(o):
    """Collect the updaters of every getter in a (possibly nested) structure.

    An object with a get_updater() method contributes its updater when
    that is truthy; an iterable is searched recursively; anything else
    contributes nothing.  Returns a list, which may contain duplicates.
    """
    if hasattr(o, 'get_updater'):
        u = o.get_updater()
        return [u] if u else []
    if isinstance(o, str) or not hasattr(o, '__iter__'):
        # Strings are leaves (they are iterable on Python 3, and iterating
        # one only yields more strings).
        return []
    # Accumulate in place instead of reduce(operator.add, ...): this avoids
    # repeated list copying and the Python-2-only builtin reduce().
    found = []
    for i in o:
        found.extend(__find_updaters(i))
    return found


def find_updaters(getters):
    """Return the set of distinct updaters needed by `getters`."""
    unique = set()
    unique.update(__find_updaters(getters))
    return unique


def update_fields(updaters):
    """Call every updater in `updaters`, in iteration order."""
    for refresh in updaters:
        refresh()


def table_generator(grid_map, col_headings=None, context=None):
    """Generate formatted tables from a grid of getters, one per next().

    Before each snapshot the getters' updaters are run and `context` is
    refreshed with 'time_now' and 'time_delta' ('time_start' is set once).
    Pass the same dict that the grid's rate/time getters were created with.
    If `col_headings` is given it is emitted as the first table row.
    """
    if context is None:
        # A fresh dict per generator; a dict() default would be shared by
        # every call that omits the argument.
        context = {}
    updaters = find_updaters(grid_map)
    # The same list is handed to fmt_table on every iteration.
    widths = [0] * len(grid_map[0])
    time_prev = time.time()
    context['time_start'] = time_prev
    while True:
        update_fields(updaters)
        time_now = time.time()
        context['time_now'] = time_now
        context['time_delta'] = time_now - time_prev
        time_prev = time_now
        table = grid_snapshot(grid_map)
        if col_headings:
            table = [col_headings] + table
        yield tab.fmt_table(table, col_widths=widths)


def line_generator(headers, grid_map, context=None):
    """Generate one right-aligned output line per next() from a getter list.

    Before each snapshot the getters' updaters are run and `context` is
    refreshed with 'time_now' and 'time_delta' ('time_start' is set once).
    The first item yielded is preceded by a '#'-prefixed header line when
    `headers` is non-empty.  Column widths are fixed from the first
    snapshot (minimum 10 characters, plus 2 of padding).
    """
    if context is None:
        # A fresh dict per generator; a dict() default would be shared by
        # every call that omits the argument.
        context = {}
    updaters = find_updaters(grid_map)
    time_prev = time.time()
    context['time_start'] = time_prev
    done_headers = False
    while True:
        update_fields(updaters)
        time_now = time.time()
        context['time_now'] = time_now
        context['time_delta'] = time_now - time_prev
        time_prev = time_now
        fields = grid_snapshot(grid_map)

        output = ''
        if not done_headers:
            done_headers = True
            if headers:
                widths = [max(10, len(headers[i]), len(fields[i])) + 2
                          for i in range(len(headers))]
                output += '#' + ''.join([h.strip().rjust(widths[i])
                                         for i, h in enumerate(headers)]) + '\n'
            else:
                widths = [max(10, len(f)) + 2 for f in fields]

        # The leading space lines data up under the '#' of the header.
        output += ' ' + ''.join([f.strip().rjust(widths[i])
                                 for i, f in enumerate(fields)]) + '\n'
        yield output


def grid_gen_rate(grid_map, objects, col_headings=None, rate_cols=None):
    """Generate formatted tables of per-second differences between snapshots.

    On each iteration every object in `objects` has update_fields() called,
    a new snapshot of `grid_map` is taken, and grid_rate() is applied to the
    new and previous snapshots over the elapsed time.  `rate_cols` is passed
    to grid_rate() as its column selection.

    NOTE(review): grid_snapshot() returns strings, while grid_rate()
    subtracts the cell values; confirm whether this generator is still
    used and whether it works as intended.
    """
    widths = [0] * len(grid_map[0])
    data_prev = grid_snapshot(grid_map)
    t_prev = time.time()
    while True:
        for o in objects:  # fixme: use find_updaters()
            o.update_fields()
        data_now = grid_snapshot(grid_map)
        t_now = time.time()
        t_diff = t_now - t_prev
        diff = grid_rate(data_now, data_prev, t_diff=t_diff, cols=rate_cols)
        if col_headings:
            table = [col_headings] + diff
        else:
            table = diff
        yield tab.fmt_table(table, col_widths=widths)
        # Current snapshot becomes the baseline for the next iteration.
        data_prev = data_now
        t_prev = t_now


def periodic_writer(content_generator, stream):
    """Write successive items from content_generator to stream, paced.

    Items are written and flushed once per options.interval seconds, on a
    fixed schedule measured from the start (ticks that were missed because
    output took too long are skipped).  A non-positive interval means no
    pacing.  Ctrl-C prints a newline and exits the process.
    """
    interval = options.interval
    try:
        t_next_wake = time.time()
        while True:
            stream.write(next(content_generator))
            stream.flush()
            if interval <= 0:
                # Nothing to wait for.  The old catch-up loop never
                # terminated for a zero or negative interval.
                continue
            time_now = time.time()
            if t_next_wake <= time_now:
                # Advance to the first scheduled tick after "now".
                missed = int((time_now - t_next_wake) / interval) + 1
                t_next_wake += missed * interval
            # Clamp: the tick may already have passed by the time we get
            # here, and time.sleep() rejects negative values.
            time.sleep(max(0.0, t_next_wake - time.time()))
    except KeyboardInterrupt:
        out('\n')
        sys.exit()


def curses_app(stdscr, curses, content_generator, refresh=1.0):
    """Repeatedly draw items from content_generator until a key is pressed.

    stdscr is the curses window to draw on and `curses` is the curses
    module (passed in by run_curses_app).  The screen is redrawn roughly
    every `refresh` seconds; key presses are polled every 0.1 seconds and
    any key other than a resize event ends the loop.
    """
    curses.use_default_colors()
    win = stdscr
    win.nodelay(True)
    poll_secs = 0.1
    # Always poll at least once per redraw: with zero polls (refresh < 0.1)
    # the loop would spin without sleeping and never read a key.  round()
    # avoids losing a poll to float error (0.3 / 0.1 is just under 3).
    n_polls = max(1, int(round(refresh / poll_secs)))
    while True:
        win.erase()
        win.addstr(0, 0, next(content_generator))
        win.refresh()
        for _ in range(n_polls):
            time.sleep(poll_secs)
            key = win.getch()
            if key >= 0 and key != curses.KEY_RESIZE:
                return

def run_curses_app(*args, **kwargs):
    """Run curses_app() under curses.wrapper, forwarding all arguments.

    A curses.error is reported on stderr and the process exits with
    status 1.
    """
    import curses
    try:
        curses.wrapper(curses_app, curses, *args, **kwargs)
    except curses.error:
        message = ("ERROR: Failed to print output table. "
                   "Terminal window may be too small\n")
        sys.stderr.write(message)
        raise SystemExit(1)

######################################################################

def cmp_objects(a, b):
    """Old-style comparison function ordering session objects.

    Objects are ordered by thread_id, then type_order; objects equal on
    both must be of the same type and are ordered by dispatch_order
    (stats.Node) or by id (everything else).  Returns a negative, zero or
    positive number.
    """
    for attr in ('thread_id', 'type_order'):
        left, right = getattr(a, attr), getattr(b, attr)
        if left != right:
            return left - right
    assert type(a) == type(b)
    if not isinstance(a, stats.Node):
        return a.id - b.id
    return a.dispatch_order - b.dispatch_order


def infos_get_is_running(infos):
    """Store stats.is_running(info) under 'is_running' in each info dict."""
    for info in infos:
        state = stats.is_running(info)
        info['is_running'] = state


def infos_filter_by_running(infos, running, stopped, all):
    """Select session infos by their 'is_running' state.

    `all` keeps everything; `running` keeps infos whose state is 1 and
    `stopped` keeps those whose state is 0.  With neither `running` nor
    `stopped` set, running sessions are selected.
    """
    # Running sessions are the default selection.
    want_running = running or not stopped

    def wanted(info):
        state = info['is_running']
        if all:
            return True
        if want_running and state == 1:
            return True
        return bool(stopped and state == 0)

    return filter(wanted, infos)

######################################################################

def is_running_str(i):
    """Map an is_running value to text: >0 'yes', <0 'unknown', else 'no'."""
    if i > 0:
        return 'yes'
    if i < 0:
        return 'unknown'
    return 'no'

def do_list(session_infos, strm):
    """Write a table of sessions (user, pid/id, running, directory) to strm.

    A '#'-prefixed header line is always written, followed by one line
    per session info.
    """
    header = '#%-10s %-10s %-7s %s\n' % \
        ('user', 'pid/id', 'running', 'log-directory')
    strm.write(header)
    # One column wider than the header's first field to allow for its '#'.
    row_fmt = '%-11s %-10s %-7s %s\n'
    for info in session_infos:
        user = get_user(info['uid'])
        pid_id = "%s/%s" % (info['pid'], info['id'])
        running = is_running_str(info['is_running'])
        strm.write(row_fmt % (user, pid_id, running, info['dir']))


def action_dump(sessions):
    """Print every field of every object in each session to stdout.

    Each session's dump starts with the current date and time.  Objects
    are listed in cmp_objects() order, each with a 'name' line (added in
    front if the object does not list it) followed by its fields.
    """
    # cmp_to_key works on Python 2.7 and 3, unlike sorted(cmp=...).
    from functools import cmp_to_key
    sort_key = cmp_to_key(cmp_objects)
    for session in sessions:
        objs = sorted(session.object_list, key=sort_key)
        out('Dump:\n')
        dt = datetime.datetime.now()
        out("  %-30s %s\n" % ('date', dt.date()))
        out("  %-30s %s\n" % ('time', dt.time()))
        out('\n')
        for o in objs:
            out("%s:\n" % (o.type_name))
            fields = o.field_names()
            if 'name' not in fields:
                fields = ['name'] + fields
            for f in fields:
                out("  %-30s %s\n" % (f, getattr(o, f)))
            out('\n')


def action_nodes(sessions):
    """Print, per session, a table of its nodes with packet counts."""
    for session in sessions:
        fields = ['obj_id', 'node_type_name', 'name', 'pkts_in', 'pkts_out']
        objs = [o for o in session.object_list if isinstance(o, stats.Node)]
        grid_map = mk_grid_map(objs, fields)
        data = grid_snapshot(grid_map)
        # The heading row shows '#id' in place of the 'obj_id' field name.
        fields[0] = '#id'
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(tab.fmt_table([fields] + data))


def action_grid_test(sessions):
    """Print, per session, a table of node ids, names and packet counts."""
    for session in sessions:
        fields = ['obj_id', 'name', 'pkts_in', 'pkts_out']
        objs = [o for o in session.object_list if isinstance(o, stats.Node)]
        grid_map = mk_grid_map(objs, fields)
        # The heading row shows '#id' in place of the 'obj_id' field name.
        fields[0] = '#id'
        data = grid_snapshot(grid_map)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(tab.fmt_table([fields] + data))


class TableColumn(object):
    """Describes one kind of column in a periodic table.

    filter_fn decides which objects get such a column, field_name is the
    object field shown, label is the suffix of the column heading and
    getter_fn(obj, field_name) builds the getter for a cell.
    """

    def __init__(self, filter_fn, field_name, label, getter_fn):
        self.filter_fn, self.field_name = filter_fn, field_name
        self.label_suffix, self.getter_fn = label, getter_fn

    def accept(self, obj):
        """Return whether this column applies to obj."""
        return self.filter_fn(obj)

    def getter(self, obj):
        """Build the getter for this column's field on obj."""
        return self.getter_fn(obj, self.field_name)

    def label(self, obj):
        """Return the heading: interface (for a Vi) or name, plus suffix."""
        prefix = obj.interface if type(obj) is stats.Vi else obj.name
        return '%s-%s' % (prefix, self.label_suffix)


def periodic_table(sessions, columns, context=None):
    """Periodically print one line of values for the selected columns.

    The first column is always the date.  Within each session, objects
    that have a group_name are grouped by it; groups are visited in sorted
    order and, per group, every column is added for each object it accepts.
    Output goes to stdout via periodic_writer().
    """
    ctx = {} if context is None else context
    headers = ['time']
    getters = [mk_date_getter(ctx)]

    for session in sessions:
        by_group = {}  # group_name -> [obj, ...]
        for obj in stats.find_objs(session.object_list,
                                   fields=['group_name']):
            by_group.setdefault(obj.group_name, []).append(obj)
        for _group, members in sorted(by_group.items()):
            for column in columns:
                for obj in members:
                    if column.accept(obj):
                        headers.append(column.label(obj))
                        getters.append(column.getter(obj))

    periodic_writer(line_generator(headers, getters, ctx), sys.stdout)


def is_vi(obj):
    """Return True if obj is exactly a stats.Vi (sub-classes excluded)."""
    kind = type(obj)
    return kind is stats.Vi


def is_rx_vi(obj):
    """Return True if obj is exactly a stats.Vi with recv_node_id >= 0."""
    if type(obj) is not stats.Vi:
        return False
    return obj.recv_node_id >= 0


def is_node_type(node_type):
    """Return a predicate accepting stats.Node objects of the given type.

    The predicate is true for objects that are exactly stats.Node and
    whose node_type_name equals node_type.
    """
    def matches(obj):
        if type(obj) is not stats.Node:
            return False
        return obj.node_type_name == node_type
    return matches


def is_writer(obj):
    """Return True if obj is a node whose type name is 'sc_writer'."""
    accept = is_node_type('sc_writer')
    return accept(obj)


def action_line_total(sessions):
    """Periodically print captured packet/byte and written byte totals."""
    # (object filter, field, heading suffix, getter factory)
    column_specs = (
        (is_rx_vi,  'n_rx_pkts',   'cap-pkts',    mk_n_pkts_getter),
        (is_rx_vi,  'n_rx_bytes',  'cap-bytes',   mk_n_bytes_getter),
        (is_writer, 'write_bytes', 'write-bytes', mk_n_bytes_getter),
    )
    columns = [TableColumn(*spec) for spec in column_specs]
    periodic_table(sessions, columns)


def action_line_rate(sessions):
    """Periodically print capture packet rate and capture/write bandwidth."""
    context = {}

    # Getter factories bound to this action's timing context.
    def pkts_getter(o, f):
        return mk_pps_getter(context, o, f)

    def bytes_getter(o, f):
        return mk_bw_mbps_getter(context, o, f)

    # (object filter, field, heading suffix, getter factory)
    column_specs = (
        (is_rx_vi,  'n_rx_pkts',   'cap-rate',   pkts_getter),
        (is_rx_vi,  'n_rx_bytes',  'cap-mbps',   bytes_getter),
        (is_writer, 'write_bytes', 'write-mbps', bytes_getter),
    )
    columns = [TableColumn(*spec) for spec in column_specs]
    periodic_table(sessions, columns, context)


def action_nodes_rate(sessions):
    """Show a continuously updated curses table of per-node packet rates.

    Covers every stats.Node in all the given sessions; the table is
    redrawn every options.interval seconds.  If there are no nodes an
    error is written to stderr and the process exits with status 5.
    """
    objs = []
    for session in sessions:
        objs += [o for o in session.object_list if isinstance(o, stats.Node)]
    if not objs:
        # Without this, grid_map[0] below fails with an IndexError.
        err("ERROR: no nodes found\n")
        sys.exit(5)
    ctx = dict()
    grid_map = [[mk_objid(o),
                 mk_static_getter(o.name, label='name'),
                 mk_static_getter(o.node_type_name, label='type'),
                 mk_pps_getter(ctx, o, 'pkts_in'),
                 mk_pps_getter(ctx, o, 'pkts_out'),
                 mk_getter(o, 'eos_left')]
                for o in objs]

    # Column headings come from the labels of the first row's getters.
    col_headings = [g.get_label() for g in grid_map[0]]
    content = table_generator(grid_map, col_headings=col_headings,
                              context=ctx)
    run_curses_app(content, refresh=options.interval)


# Field names for which action_custom_line_rate() uses a dedicated getter
# factory instead of its default packets-per-second getter.
_special_case_getters = {
    'latency': mk_latency_getter
    }
def action_custom_line_rate(sessions, node_type_name, *fields):
    """Periodically print the given fields of all nodes of one node type.

    Each field is shown as a packet rate unless it has an entry in
    _special_case_getters, in which case that getter factory is used.
    """
    context = {}
    accept = is_node_type(node_type_name)

    def rate_getter(o, f):
        return mk_pps_getter(context, o, f)

    columns = [
        TableColumn(accept, name, name,
                    _special_case_getters.get(name, rate_getter))
        for name in fields
    ]
    periodic_table(sessions, columns, context)


def action_idle_test(sessions):
    """Show a curses table of idle_loops and pkts_in rates by object id.

    For each session, objects having an idle_loops field are listed first,
    followed by objects having a pkts_in field.
    """
    ctx = dict()
    grid_map = []
    for session in sessions:
        for field in ('idle_loops', 'pkts_in'):
            matches = stats.find_objs(session.object_list, fields=[field])
            for o in matches:
                grid_map.append([mk_objid(o),
                                 mk_rate_getter(o, field, ctx)])

    # Column headings come from the labels of the first row's getters.
    col_headings = [g.get_label() for g in grid_map[0]]
    content = table_generator(grid_map, col_headings=col_headings,
                              context=ctx)
    run_curses_app(content)


def action_hack(sessions):
    """Show a heading-less curses table of idle_loops and pkts_in rates.

    Exactly one session must be given.  Each row holds the raw object id
    and a rate getter.
    """
    assert len(sessions) == 1
    session = sessions[0]
    ctx = dict()
    grid_map = []
    for field in ('idle_loops', 'pkts_in'):
        matches = stats.find_objs(session.object_list, fields=[field])
        for o in matches:
            grid_map.append([o.obj_id, mk_rate_getter(o, field, ctx)])
    content = table_generator(grid_map, context=ctx)
    run_curses_app(content)


def obj_desc_to_getters(sessions, context, obj_desc):
    """Take a string describing a value (or values) that we'd like to
    print.  Return a list of getters that will deliver those values.

    obj_desc is a dot-separated path.  The first component is either one
    of the built-ins 'date', 'time', 'elapsed', 'interval', or an object
    type / node type / object name that is looked up in `sessions` with
    stats.sessions_find_objs().  Each further component is applied to
    every current result:

    - on an object with field_names(): a listed field name yields a
      ValGetter labelled "<obj.name>.<field>"; any other attribute that
      exists is followed; otherwise the object is dropped;
    - on anything else (i.e. a getter): 'rate', 'pps', 'mbps' wrap it in a
      RateGetter (the latter two with formatting) and 's', 'ms', 'us',
      'ns' wrap it in an int-converting FormattedGetter; any other
      component drops it.

    `context` is the timing dict handed to the time and rate getters.
    """

    # IDEAS:
    # - handle wildcards in object and field names (eg. "*.pkts_in")
    # - support arithmetic (eg. "Vi.n_rx_bytes / Vi.n_rx_pkts")

    components = obj_desc.split('.')
    # "otnton": object type, or node type, or name.
    otnton = components.pop(0)
    if otnton == 'date':
        objs = [mk_date_getter(context)]
    elif otnton == 'time':
        objs = [TimeGetter(context)]
    elif otnton == 'elapsed':
        objs = [ElapsedGetter(context)]
    elif otnton == 'interval':
        objs = [IntervalGetter(context)]
    else:
        objs = stats.sessions_find_objs(sessions,
                                        obj_type_or_node_type_or_name=otnton)

    # Apply the remaining components one at a time to the current results.
    while len(components) > 0:
        c = components.pop(0)
        new_objs = []
        for o in objs:
            if hasattr(o, 'field_names'):
                if c in o.field_names():
                    label = "%s.%s" % (o.name, c)
                    new_objs.append(ValGetter(o, c, label=label))
                else:
                    # Not a listed field: follow it as a plain attribute if
                    # it exists, otherwise silently drop this object.
                    try:
                        v = getattr(o, c)
                        new_objs.append(v)
                    except AttributeError:
                        pass
            elif c == 'rate':
                new_objs.append(RateGetter(o, context=context))
            elif c == 'pps':
                new_objs.append(FormattedGetter(RateGetter(o, context=context),
                                                fmt_pkts_per_sec))
            elif c == 'mbps':
                new_objs.append(FormattedGetter(RateGetter(o, context=context),
                                                fmt_bw_mbps))
            elif c == 's':
                new_objs.append(FormattedGetter(o, int))
            elif c == 'ms':
                new_objs.append(FormattedGetter(o, lambda x: int(x * 1e3)))
            elif c == 'us':
                new_objs.append(FormattedGetter(o, lambda x: int(x * 1e6)))
            elif c == 'ns':
                new_objs.append(FormattedGetter(o, lambda x: int(x * 1e9)))
            # Otherwise o is dropped.
        objs = new_objs

    # Filter out instances of stats.InstanceBase as these are not getters.
    # (This happens if eg. you just give 'Vi' without a field name).
    new_objs = []
    for o in objs:
        if not isinstance(o, stats.InstanceBase):
            new_objs.append(o)

    return new_objs


def action_custom(sessions, *args):
    """Print the current value of each described object field, once.

    Every argument is an object description understood by
    obj_desc_to_getters(); a description matching nothing produces a
    warning on stderr.  Output is one "label value" line per getter.
    """
    context = dict()
    all_getters = []
    for desc in args:
        matched = obj_desc_to_getters(sessions, context, desc)
        if matched:
            all_getters.extend(matched)
        else:
            err("WARNING: no objects matching '%s'\n" % desc)

    updaters = find_updaters(all_getters)
    context['time_now'] = time.time()
    update_fields(updaters)

    for getter in all_getters:
        line = "%-30s %s\n" % (getter.get_label(), str(getter()))
        out(line)


def action_custom_lines(sessions, *args):
    """Periodically print a line with the value of each described field.

    Every argument is an object description understood by
    obj_desc_to_getters(); a description matching nothing produces a
    warning.  If nothing matches at all the process exits with status 5.
    """
    context = dict()
    all_getters = []
    for desc in args:
        matched = obj_desc_to_getters(sessions, context, desc)
        if matched:
            all_getters.extend(matched)
        else:
            err("WARNING: no objects matching '%s'\n" % desc)
    if not all_getters:
        err("ERROR: no objects matched\n")
        sys.exit(5)
    headers = [getter.get_label() for getter in all_getters]
    lines = line_generator(headers, all_getters, context)
    periodic_writer(lines, sys.stdout)


def action_dot(sessions, *args):
    """Write each session's topology to stdout in graphviz dot format.

    Optional arguments 'mailboxes' and 'free_path' switch on the matching
    show_* option; any other argument is reported as a usage error.
    """
    for arg in args:
        if arg not in ('mailboxes', 'free_path'):
            usage_err("expected: solar_capture_monitor dot [mailboxes] "
                      "[free_path]")
    show_mailboxes = 'mailboxes' in args
    show_free_path = 'free_path' in args

    def emit(line):
        sys.stdout.write(line + '\n')

    from solar_capture_tools import dot
    for session in sessions:
        dot.session_to_dot(session, emit,
                           show_mailboxes=show_mailboxes,
                           show_free_path=show_free_path)


def action_poke(sessions, *args):
    """Overwrite object attributes in each session.

    Every argument has the form 'obj.attr=val'; the part before '=' is an
    object description understood by obj_desc_to_getters().  Each
    successful write is echoed to stdout; failures are reported on stderr.
    An argument without '=' is a usage error.  If nothing was written the
    process exits with status 6.
    """
    # Validate all arguments up front so a malformed one is reported as a
    # usage error instead of an unhandled ValueError from unpacking.
    pokes = []
    for arg in args:
        field_spec, sep, val = arg.partition('=')
        if not sep:
            usage_err("expected: solar_capture_monitor poke obj.attr=val")
        pokes.append((field_spec, val))

    did_something = False
    for s in sessions:
        for field_spec, val in pokes:
            getters = obj_desc_to_getters([s], None, field_spec)
            if not getters:
                err("WARNING: no objects match '%s'\n" % field_spec)
            for g in getters:
                if not isinstance(g, ValGetter):
                    # Only a plain field getter carries the obj/field pair
                    # needed for a write (eg. 'date' yields something else).
                    err("ERROR: '%s' is not a writable attribute\n" %
                        field_spec)
                    continue
                try:
                    g.obj.poke(g.field, val)
                except Exception:
                    err("ERROR: cannot write attribute '%s.%s'\n" %
                        (g.obj, g.field))
                else:
                    out("%s.%s = %s\n" % (g.obj, g.field, val))
                    did_something = True

    if not did_something:
        sys.exit(6)

######################################################################

def is_int(s):
    """Return True if int(s) succeeds, False otherwise."""
    try:
        int(s)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type (eg. None); ValueError: not a number
        # (also NaN); OverflowError: float infinity.
        return False
    return True


def main():
    """Parse the command line, select sessions and run the chosen action.

    Positional arguments are scanned in order: session selectors (pid,
    pid/session_id or a session directory) are collected until the first
    command word, which ends the scan; whatever follows the command is
    passed to the action function as extra arguments.  With no command the
    action is 'list'.

    Exit statuses used here: 0 on normal completion, 1 for an unrecognised
    argument, 2 when a pid or pid/session_id matches no session, 4 for a
    directory that is not a session, 5 when a requested session is not
    accessible to the selected user.
    """
    global options
    parser = optparse.OptionParser(usage=usage_text)
    global opt_parser
    opt_parser = parser
    parser.add_option('--running', dest='running', action='store_true',
                      help='Select running sessions (default)', default=False)
    parser.add_option('--stopped', dest='stopped', action='store_true',
                      help='Select stopped sessions', default=False)
    parser.add_option('--all', dest='all', action='store_true',
                      help='Select running and stopped sessions', default=False)
    parser.add_option('--user', dest='user',
                      help='Select sessions owned by this (trusted) user')
    parser.add_option('--interval', dest='interval', type='float', default=1.0,
                      help='Time interval in between output updates')
    parser.add_option('--localtime', dest='localtime', action='store_true',
                      default=False, help='Use local time (default is UTC)')
    parser.add_option('--strftime', dest='strftime', action='store',
                      default=None, help='Specify format string for timestamps')
    parser.add_option('--base-dir', default=None,
                      help='Location of stats directory')
    parser.add_option('--debug', dest='debug', action='store_true',
                      help='Show source of errors', default=False)
    (options, args) = parser.parse_args()

    # Without --user, sessions are selected for the effective uid.
    # get_uid() returns None for 'any'/'all', which skips the access check
    # on explicitly named sessions below.
    if options.user:
        uid = get_uid(options.user)
    else:
        uid = os.geteuid()

    # Parse other args.
    session_dirs = []
    action = 'list'

    while args:
        a = args.pop(0)
        if re.match(r'[0-9]+$', a):
            # Bare number: all sessions of that pid.
            sds = stats.find_session_dirs(base_dir=options.base_dir, pid=a)
            if not sds:
                fail(2, "ERROR: Cannot find session for pid=%s\n" % a)
            session_dirs += sds
        elif re.match(r'[0-9]+/[0-9]+$', a):
            # pid/session_id: one specific session.
            pid, sid = a.split('/')
            sds = stats.find_session_dirs(base_dir=options.base_dir,
                                          pid=pid, session_id=sid)
            if not sds:
                msg = "ERROR: Cannot find session for pid=%s session_id=%s\n"
                fail(2, msg % (pid, sid))
            session_dirs += sds
        elif stats.looks_like_session_dir(a):
            session_dirs.append(a)
        elif a == 'list':
            action = 'list'
            break
        elif ('action_%s' % a) in globals():
            # Any module-level action_<name> function is a valid command;
            # the remaining args are left for it.
            action = a
            break
        elif os.path.isdir(a):
            fail(4, "ERROR: Directory '%s' is not a SolarCapture session\n" % a)
        else:
            fail(1, "ERROR: I do not understand '%s'\n" % a)

    if not session_dirs:
        # No sessions named: use the selected user's sessions, filtered by
        # the --running/--stopped/--all options.
        session_dirs = stats.find_session_dirs(base_dir=options.base_dir,
                                               uid=uid)
        infos, bad_sds = stats.get_session_infos(session_dirs)
        # ?? do we want to log bad_sds?
        infos_get_is_running(infos)
        infos = infos_filter_by_running(infos, running=options.running,
                                        stopped=options.stopped,
                                        all=options.all)
    else:
        # Sessions named explicitly: no running/stopped filtering, but
        # refuse any that the selected user cannot access.
        session_dirs = list(set(session_dirs))  # kill any dups
        if uid is not None:
            requested_dirs = session_dirs
            session_dirs = [sd for sd in session_dirs
                            if stats.uid_can_access_session_dir(sd, uid)]
            removed = set(requested_dirs) - set(session_dirs)
            if removed:
                err("ERROR: Not loading these sessions owned by other user:\n")
                for sd in removed:
                    err("ERROR:   %s\n" % sd)
                sys.exit(5)
        infos, bad_sds = stats.get_session_infos(session_dirs)
        infos_get_is_running(infos)

    # 'list' only needs the session infos, not the loaded sessions.
    if action == 'list':
        do_list(infos, strm=sys.stdout)
        sys.exit(0)

    sessions = []
    for inf in infos:
        try:
            sessions.append(stats.Session(inf['dir']))
        except:
            # A session that fails to load is reported and skipped; with
            # --debug the original exception is re-raised instead.
            err("ERROR: Failed to load session '%s'\n" % inf['dir'])
            if options.debug:
                raise

    if sessions:
        action_fn = globals()['action_%s' % action]
        action_fn(sessions, *args)
    else:
        sys.stderr.write("No solar_capture sessions found\n")
    sys.exit(0)


if __name__ == '__main__':
    try:
        main()
    except IOError as exc:
        # 'watch'ing solar_capture_monitor can throw IOErrors as it closes
        # its end of the pipe after getting a page worth of output.  Only
        # that broken-pipe case is ignored; anything else is re-raised.
        if exc.errno != errno.EPIPE:
            raise
