#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0
# X-SPDX-Copyright-Text: (c) Copyright 2010-2019 Xilinx, Inc.

# Set affinity for interrupts.

import os, sys
import sfcaffinity as affinity
import sfcmask as mask


# Name this script was invoked as; log() prefixes every message with it.
me = os.path.basename(sys.argv[0])

# Both are filled in by main() once the command line has been parsed.
opt_parser, options = None, None

######################################################################

def top_and_tail(x, pre="", post=""):
    """Wrap x in a prefix and a suffix.

    If x is exactly a list, a new list is returned with every element
    converted to str and wrapped individually.  Anything else (including
    tuples and list subclasses) is converted to str as a whole and a single
    wrapped string is returned.
    """
    def wrap(item):
        return pre + str(item) + post

    if type(x) is not list:
        return wrap(x)
    return list(map(wrap, x))


def out(x, file=sys.stdout):
    """Write x to file, one line per item.

    x may be a single value or a list of values; each one is terminated
    with a newline before being written in a single write() call.  If the
    write raises, the text that could not be written is reported on stderr
    and the exception is re-raised.
    """
    lines = top_and_tail(x, post="\n")
    text = ''.join(lines) if type(lines) is list else lines
    try:
        file.write(text)
    except BaseException:
        # Report what we were trying to emit, then let the error propagate.
        sys.stderr.write("ERROR: repr(x) = %s\n" % repr(text))
        raise


def log(x, file=sys.stdout, level=1):
    """Emit x, prefixed with the script name, if the log level allows it.

    Nothing is written when options.loglevel is below level.  x may be a
    single value or a list of values (one output line each).
    """
    if options.loglevel < level:
        return
    out(top_and_tail(x, "%s: " % me), file=file)


def err(x):
    """Log x on stderr as an error.

    Errors are logged at level 0, so they are still shown with --quiet.
    """
    log(top_and_tail(x, "ERROR: "), file=sys.stderr, level=0)


# Every warning already passed to warn(), in its prefixed form; used to
# drop repeats.
warnings_emitted = []

def warn(x):
    """Log x on stderr as a warning, at most once per distinct message.

    The message is remembered even if the current log level hides it.
    """
    message = top_and_tail(x, "WARNING: ")
    if message not in warnings_emitted:
        warnings_emitted.append(message)
        log(message, file=sys.stderr, level=1)


def fail(x, exit_code=1):
    """Report x as an error and terminate the script with exit_code."""
    err(x)
    raise SystemExit(exit_code)

######################################################################

def spread_interrupts_over_cores(irq_names, cores):
    """Assign each named IRQ to one core, cycling through cores in order.

    The i-th IRQ goes to cores[i % len(cores)], so cores are reused when
    there are more IRQs than cores.  Each IRQ's affinity is set to a mask
    containing only the bit of its chosen core.
    """
    vector_of = affinity.irq_get_name2vec_map()
    for position, name in enumerate(irq_names):
        chosen = cores[position % len(cores)]
        irq = vector_of[name]
        log("%s(irq=%d) => core(%d)" % (name, irq, int(chosen)))
        single_core_mask = 1 << int(chosen)
        affinity.irq_set_affinity(irq, single_core_mask)


def cores_spread_over(topology, xs):
    """Return the core ids of xs interleaved round-robin.

    The result takes the first core id of every element of xs (in order),
    then the second core id of every element that has one, and so on until
    all core ids have been used.  Elements with fewer cores simply drop out
    of later rounds.

    topology is unused; it is kept so existing callers need not change.
    xs is any iterable of objects exposing a 'core_ids' sequence (this
    script passes shared caches or packages).  An empty xs yields [].
    The objects in xs are not modified.
    """
    # Snapshot each group's ids so the caller's objects are left untouched.
    groups = [list(x.core_ids) for x in xs]
    # default=0 keeps an empty xs from raising ValueError in max().
    rounds = max((len(group) for group in groups), default=0)
    return [group[r]
            for r in range(rounds)
            for group in groups
            if r < len(group)]


def do_irqpattern_corespec(irq_name_pattern, cores_args):
    """Spread the IRQs matching a pattern over a set of cores.

    irq_name_pattern is passed to affinity.irq_names_matching_pattern().
    cores_args is the list of positional arguments that followed it on the
    command line and selects the cores to use:

      * empty       - choose cores automatically from the CPU topology,
                      after applying the --cores / --packages filters
      * one item    - parsed with mask.to_int_list()
      * many items  - each item is converted with int()

    Invalid input (no matching IRQs, bad masks, bad core ids, an empty set
    of cores) is reported through fail(), which terminates the script.
    """
    irq_names = affinity.irq_names_matching_pattern(irq_name_pattern)
    n_irqs = len(irq_names)
    if n_irqs == 0:
        fail("No IRQs matched pattern '%s'" % (irq_name_pattern))

    topology = affinity.get_topology()

    # Narrow the topology by --cores first, then by --packages.  'opt'
    # tracks the option being processed so a BadMask error can name it.
    opt = None
    try:
        if options.cores:
            opt = 'cores'
            topology = affinity.get_topology(cores=options.cores)
        if options.packages:
            opt = 'packages'
            packages = mask.to_int_list(options.packages)
            cores = [c.id for c in topology.get_cores()
                     if c.pkg.id in packages]
            topology = affinity.get_topology(cores=cores)
    except affinity.NoCores:
        fail("No cores left after applying filters.")
    except affinity.BadCoreId as inst:
        fail("Bad core id %d" % inst.core_id)
    except mask.BadMask as inst:
        fail("Bad mask in --%s=%s" % (opt, inst.msg))

    caches = topology.get_top_level_shared_caches()

    if len(cores_args) == 0:
        # No explicit cores: prefer an exact one-to-one layout if the number
        # of IRQs matches some level of the topology, otherwise interleave.
        if n_irqs == len(topology.real_cores):
            log("One interrupt per real core")
            cores = topology.real_core_ids
        elif n_irqs == len(topology.get_cores()):
            log("One interrupt per hyperthread")
            cores = topology.core_ids
        elif n_irqs == len(caches):
            log("One interrupt per shared cache")
            cores = cores_spread_over(topology, caches)
        elif n_irqs == len(topology.get_packages()):
            log("One interrupt per package")
            cores = [p.core_ids[0] for p in topology.get_packages()]
        elif caches:
            log("Spreading interrupts evenly over %d shared caches" %
                len(caches))
            cores = cores_spread_over(topology, caches)
        else:
            log("Spreading interrupts evenly over %d packages" %
                len(topology.get_packages()))
            cores = cores_spread_over(topology, topology.get_packages())
    elif len(cores_args) == 1:
        # Report a malformed mask the same way as for --cores/--packages
        # instead of letting the exception escape as a traceback.
        try:
            cores = mask.to_int_list(cores_args[0])
        except mask.BadMask as inst:
            fail("Bad mask in cores argument: %s" % inst.msg)
    else:
        cores = []
        for arg in cores_args:
            try:
                cores.append(int(arg))
            except ValueError:
                fail("Bad core id '%s'" % arg)

    if len(cores) == 0:
        fail("Set of cores is empty!")

    spread_interrupts_over_cores(irq_names, cores)


def do_rfc2544(args):
    """Placeholder for the 'rfc2544' sub-command.

    Not implemented yet: args is ignored and the usage text is printed,
    which terminates the script.
    """
    # TODO: implement this mode.
    usage()


def usage(msg=None, exit_code=1, file=sys.stderr):
    """Print an optional message plus the option help, then exit.

    msg, when given, is logged at level 0 with a blank line on either side.
    The help text of the parser built by main() follows, and the script
    terminates with exit_code.  All output goes to file.
    """
    if msg:
        file.write('\n')
        log(msg, file=file, level=0)
        file.write('\n')
    opt_parser.print_help(file=file)
    raise SystemExit(exit_code)


def misc_setup():
    """Install this script's logger as sfcaffinity's log_system hook.

    After this call, whatever sfcaffinity passes to its log_system hook
    (presumably each system command it runs - confirm in sfcaffinity) is
    logged here with a 'SYSTEM: ' prefix.
    """
    def report_system_command(cmd):
        log('SYSTEM: %s' % cmd)

    affinity.log_system = report_system_command


def main2(args):
    """Dispatch on the positional command-line arguments.

    With no arguments the usage text is shown (which exits).  A first
    argument of 'rfc2544' selects that mode; anything else is treated as an
    IRQ name pattern, and the remaining arguments as its core specification.
    """
    if not args:
        usage()

    misc_setup()

    first = args[0]
    if first != 'rfc2544':
        do_irqpattern_corespec(first, args[1:])
    else:
        do_rfc2544(args[1:])


def main():
    """Parse the command line and run the requested action.

    Builds the option parser, stores it and the parsed options in the
    module-level opt_parser / options globals (used by usage() and log()),
    then hands the positional arguments to main2() and returns its result.
    """
    global opt_parser, options
    import optparse

    irq_help_lines = [
        "Required:",
        "  <irq-match>           Either an interface name (eg eth4) or an RSS queue",
        "                        identifier (eg eth4-0).",
        "                        The number of RSS queues assigned to an interface",
        "                        can be found by inspection of /proc/interrupts or",
        "                        by reference to the sfc module parameter rss_cpus.",
    ]
    usage_text = "%prog <irq-match> [cores]" + "\n\n" + "\n".join(irq_help_lines)

    parser = optparse.OptionParser(usage=usage_text)

    # Topology filters.
    parser.add_option("-p", "--packages", dest='packages', action="store",
                      help="Restrict set of cores to given package(s)")
    parser.add_option("-c", "--cores", dest='cores', action="store",
                      help="Restrict set of cores")

    # TODO: understand numa node topology
    #parser.add_option("-n", "--nodes", dest='nodes',
    #                  action="store",
    #                  help="Restrict set of cores to given numa node(s)")

    # Verbosity: all three options share the 'loglevel' destination.
    parser.add_option("-l", "--loglevel", dest='loglevel', action="store",
                      type='int', default=2, help="Set log level")
    parser.add_option("-v", "--verbose", dest='loglevel', action="count",
                      help="More verbose output")
    parser.add_option("-q", "--quiet", dest='loglevel', action="store_const",
                      const=0, help="Quiet")

    opt_parser = parser
    options, positional = parser.parse_args()
    return main2(positional)


if __name__ == '__main__':
    # main() returns main2()'s result (None), so a normal run exits with 0.
    sys.exit(main())
