#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0
# X-SPDX-Copyright-Text: (c) Copyright 2013-2019 Xilinx, Inc.

#****************************************************************************
# Copyright (c) 2013, Solarflare Communications Inc,
#
# Maintained by Solarflare Communications
#  <onload-dev@solarflare.com>
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 2 as published
# by the Free Software Foundation, incorporated herein by reference.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
#****************************************************************************

import os, shutil, sys, optparse, socket, select, stat, signal, pwd, errno

import platform
# Locate the helper modules.  If a build directory for this machine type
# exists relative to the script location, add it to sys.path and import the
# modules from there; otherwise import them from the solar_clusterd package.
onload_build = 'gnu_%s' % platform.machine()
path = os.path.join(os.path.dirname(sys.argv[0]), '../../..', 'build',
                    onload_build, 'tools', 'solar_clusterd')
if os.path.exists(path):
    sys.path.append(path)
    import daemonize, parse_config, cluster_protocol as cp
else:
    import solar_clusterd.daemonize as daemonize, \
        solar_clusterd.parse_config as parse_config, \
        solar_clusterd.cluster_protocol as cp

# The protocol version may be overridden with a non-empty
# CLUSTERD_PROTOCOL_VERSION environment variable (parsed as an integer);
# otherwise the value published by cluster_protocol is used.
if os.environ.get('CLUSTERD_PROTOCOL_VERSION'):
    CLUSTERD_PROTOCOL_VERSION = int( os.environ['CLUSTERD_PROTOCOL_VERSION'] )
else:
    CLUSTERD_PROTOCOL_VERSION = cp.CLUSTERD_PROTOCOL_VERSION

PR_SET_NAME = 15 # from linux/prctl.h


# TODO: Use standard python for sendfd


class Cluster(object):
    """Plain record holding the state of one allocated cluster.

    Attributes:
        driver_fd: file descriptor of the driver handle owning the VI set.
        pd_id: protection-domain id of the VI set.
        vi_id: id of the VI set.
        protectionmode: protection-domain flags the VI set was created with.
        intf_name: name of the interface the cluster captures from.
    """

    def __init__(self, driver_fd, pd_id, vi_id, protectionmode, intf_name):
        # Handles identifying the allocated VI set.
        self.driver_fd, self.pd_id, self.vi_id = driver_fd, pd_id, vi_id
        # Settings the VI set was allocated with.
        self.protectionmode, self.intf_name = protectionmode, intf_name


class Server(object):
    def __init__(self, options, config):
        super(Server, self).__init__()
        self.options = options
        self.config = config
        self.listen_sock = None
        self.clients = {} # sock -> buffered_data
        self.clusters = {} # cluster_name -> Cluster

    def log_info(self, msg):
        if self.logger:
            self.logger.warn(msg)
        else:
            sys.stdout.write(msg)
            sys.stdout.flush()

    def log_warn(self, msg):
        if self.logger:
            self.logger.warn(msg)
        else:
            sys.stdout.write(msg)
            sys.stdout.flush()

    def log_error(self, msg):
        if self.logger:
            self.logger.error(msg)
        else:
            sys.stderr.write(msg)
            sys.stderr.flush()

    def init_vis(self):
        for name, cluster in self.config.clusters.items():
            n_vis = cluster['numchannels']
            intf = cluster['captureinterface']
            protectionmode = cluster['protectionmode']
            sys.stdout.write('Cluster %s: %s, %d channels, %s\n' % (
                    name, intf, n_vis, protectionmode))

            if not protectionmode.startswith('EF_PD_'):
                raise SyntaxError("Cluster '%s': invalid protectionmode: '%s'" %
                                  (name, protectionmode))
            try:
                protectionmode = getattr(cp, cluster['protectionmode'])
            except AttributeError:
                raise AttributeError(
                    "Cluster '%s': invalid protectionmode: '%s'" % (
                        name, protectionmode))

            driver_fd = cp.open_driver()
            cp_vi_index, pd_id, vi_id = cp.vi_set_alloc(driver_fd, intf, n_vis,
                                                  protectionmode)
            for streams in cluster['streams'].values():
                for stream in streams.capturestream:
                    cp.vi_set_add_stream(cp_vi_index, stream)
            self.clusters[name] = Cluster(driver_fd, pd_id, vi_id,
                                          protectionmode, intf)


    def run(self):
        for signum in [signal.SIGINT, signal.SIGTERM, signal.SIGHUP,
                       signal.SIGUSR1, signal.SIGUSR2]:
            signal.signal(signum, self.on_exit)
        self.listen_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.listen_sock.bind(options.socket)
        os.chmod(options.socket, stat.S_IRUSR | stat.S_IWUSR)
        self.listen_sock.listen(5)
        self.log_info('Listening on %s\n' % options.socket)
        self.main_loop()


    def main_loop(self):
        while True:
            select_fds = [self.listen_sock] + list(self.clients.keys())
            readable, _, _ = select.select(select_fds, [], [])
            for fd in readable:
                if fd is self.listen_sock:
                    self.accept(fd)
                elif fd in self.clients:
                    self.do_recv(fd)
                else:
                    assert 0, 'unknown fd %d' % fd.fileno()


    def accept(self, sock):
        new_sock, _ = sock.accept()
        self.log_info('Client %d: connected\n' % new_sock.fileno())
        self.clients[new_sock] = ''


    def do_recv(self, sock):
        try:
            data = sock.recv(1024)
            if not data:
                raise socket.error('EOF')
        except socket.error:
            self.log_info('Client %d: disconnected\n' % sock.fileno())
            del self.clients[sock]
        else:
            self.clients[sock] += data
            self.handle_request(sock)


    def handle_request(self, sock):
        if '\n' in self.clients[sock]:
            req, self.clients[sock] = self.clients[sock].split('\n', 1)
            if ' ' in req:
                req_id, payload = req.split(' ', 1)
            else:
                req_id = req
                payload = ''

            if req_id.isdigit():
                req_id = int(req_id)

            handlers = {cp.CLUSTERD_VERSION_REQ: self.handle_version_req,
                        cp.CLUSTERD_ALLOC_CLUSTER_REQ:
                            self.handle_alloc_cluster_req,}
            try:
                handlers[req_id](sock, req_id, payload)
            except KeyError:
                self.handle_bad_req(sock, req_id, req)


    def handle_bad_req(self, sock, req_id, req):
        self.log_warn('Client %d: bad request %r\n' % (sock.fileno(), req))
        sock.send('%d\n' % (cp.CLUSTERD_ERR_BAD_REQUEST))


    def handle_version_req(self, sock, req_id, payload):
        ver = payload
        self.log_info('Client %d: version request %r ' % (
                sock.fileno(), ver))
        if ver.isdigit() and \
                int(ver) == CLUSTERD_PROTOCOL_VERSION:
            result = cp.CLUSTERD_ERR_SUCCESS
            self.log_info('(version match)\n')
        else:
            result = cp.CLUSTERD_ERR_FAIL
            self.log_info('(version mismatch)\n')
        sock.send('%d %d\n' % (cp.CLUSTERD_VERSION_RESP, result))


    def handle_alloc_cluster_req(self, sock, req_id, payload):
        self.log_info('Client %d: cluster request %r ' % (
                sock.fileno(), payload))
        (name, requested_pd_flags) = payload.split()
        requested_pd_flags = int(requested_pd_flags)

        if name not in self.clusters.keys():
            self.log_warn('(%s: no such cluster)\n' % name)
            sock.send('%d %d %d %d %d\n' % (
                    cp.CLUSTERD_ALLOC_CLUSTER_RESP, cp.CLUSTERD_ERR_FAIL,
                    errno.ENOENT, 0, 0))
            return

        cluster = self.clusters[name]
        if requested_pd_flags != cluster.protectionmode:
            self.log_warn('(PD flags match fail req=%d cluster=%d)\n' % (
                        requested_pd_flags, cluster.protectionmode))
            sock.send('%d %d %d %d %d\n' % (
                    cp.CLUSTERD_ALLOC_CLUSTER_RESP, cp.CLUSTERD_ERR_FAIL,
                    errno.EINVAL, 0, 0))
            return

        self.log_info('Pass: fd=%d, pd=%d, vi_set=%d, pd_flags=%d\n' % (
                    cluster.driver_fd, cluster.pd_id, cluster.vi_id,
                    cluster.protectionmode))
        cp.sendfd(sock.fileno(), cluster.driver_fd, '%d %d %d %d %s\n' % (
                cp.CLUSTERD_ALLOC_CLUSTER_RESP, cp.CLUSTERD_ERR_SUCCESS,
                cluster.pd_id, cluster.vi_id, cluster.intf_name))


    def on_exit(self, signum, frame):
        self.log_warn('Received signal %s.  Exiting\n' % signum)
        shutil.rmtree(self.options.directory)
        sys.exit(0)


def main(config_file, options):
    """Load the configuration, allocate the clusters and run the server.

    Args:
        config_file: path of the configuration file to parse.
        options: parsed command-line options; reads ``foreground``,
            ``directory``, ``user`` and ``group`` (``socket`` is used by the
            server itself).

    In daemon mode the logger returned by daemonize.daemonize() is used.  In
    foreground mode the directory is created here and logging goes to
    stdout/stderr; the process exits with status 1 if the directory already
    exists.
    """
    config = parse_config.Config(config_file)
    sys.stdout.write('solar_clusterd version: %s\n' % cp.onload_version)
    server = Server(options, config)
    server.init_vis()
    if not options.foreground:
        server.logger = daemonize.daemonize(options.directory,
                                            user=options.user,
                                            group=options.group, verbose=True)
    else:
        # Create the directory in a single step instead of checking for it
        # first: two instances started together could otherwise both pass
        # the check.
        try:
            os.makedirs(options.directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
            sys.stderr.write('ERROR: directory %s already exists.  Either '
                             'another instance running or previous instance '
                             'did not clean up properly.  If no other '
                             'instance is running, please manually remove the '
                             'directory\n' % options.directory)
            sys.exit(1)
        server.logger = None
    server.run()


def parse_cmdline():
    """Parse the command line.

    Returns:
        (config_file, options): the single positional argument and the
        optparse values object, extended with ``group`` (from USER[:GROUP],
        else None) and ``directory`` (the directory holding the socket).
        When --socket is not given, ``socket`` and ``directory`` are derived
        from the cluster_protocol defaults and the user name.

    Exits with status 1 on a wrong argument count, or when --user is
    combined with --foreground or used without running as root.
    """
    default_sock = os.path.join(cp.DEFAULT_CLUSTERD_DIR + '<username>',
                                cp.DEFAULT_CLUSTERD_SOCK_NAME)

    logfile_help = ('This option is deprecated and '
                    'only present for backwards compatibility.  We log via '
                    'syslogd(8) now.  Setting this is a NOP.')
    pidfile_help = ('This option is deprecated and '
                    'only present for backwards compatibility.  We use the '
                    'directory for the socket file as a lock now.  Setting '
                    'this is a NOP.')
    socket_help = ('Specify path to socket file '
                   'Default is ' + default_sock + '.  If this is changed from '
                   'the default, then solar_capture must be invoked with '
                   'environment variable EF_VI_CLUSTER_SOCKET set to the '
                   'new path.')
    user_help = ('Drop privileges to this '
                 'user after daemonizing')

    parser = optparse.OptionParser(
        version=cp.onload_version,
        usage='usage: %prog [options] config-file')
    parser.add_option('-l', '--logfile', help=logfile_help)
    parser.add_option('-p', '--pidfile', help=pidfile_help)
    parser.add_option('-s', '--socket', help=socket_help)
    parser.add_option('-f', '--foreground', help='Do not daemonize',
                      action='store_true')
    parser.add_option('-u', '--user', help=user_help,
                      metavar='USER[:GROUP]')
    options, positional = parser.parse_args()

    # Exactly one positional argument: the configuration file.
    if len(positional) != 1:
        parser.print_usage()
        sys.exit(1)

    options.group = None
    if options.user:
        # Dropping privileges needs daemon mode and root.
        complaint = None
        if options.foreground:
            complaint = 'ERROR: --user is not supported in foreground mode\n'
        elif os.getuid() != 0:
            complaint = 'ERROR: --user only supported when running as root\n'
        if complaint is not None:
            sys.stderr.write(complaint)
            sys.exit(1)

        # Split USER:GROUP at the first colon, if there is one.
        user, colon, group = options.user.partition(':')
        if colon:
            options.user = user
            options.group = group

    if options.socket is not None:
        options.directory = os.path.dirname(options.socket)
    else:
        # Default location is per user: the target user if given, else the
        # user running the process.
        username = options.user or pwd.getpwuid(os.getuid()).pw_name
        options.directory = cp.DEFAULT_CLUSTERD_DIR + username
        options.socket = os.path.join(options.directory,
                                      cp.DEFAULT_CLUSTERD_SOCK_NAME)
    return positional[0], options


# Script entry point: parse the command line, then start the daemon.
if __name__ == '__main__':
    config_file, options = parse_cmdline()
    main(config_file, options)
