#! /usr/bin/python2
# SPDX-License-Identifier: BSD-2-Clause
# X-SPDX-Copyright-Text: (c) Solarflare Communications Inc

import os
import sys
import socket
import select
import subprocess


def run_cmd(cmd):
    """Run *cmd* to completion and return what it wrote to stdout.

    *cmd* is passed unchanged to ``subprocess.Popen`` (no shell is
    requested), with the child's stdout captured through a pipe.

    Raises AssertionError, carrying the exit status, if the child exits
    with a non-zero status.
    """
    child = subprocess.Popen(cmd, stdout=subprocess.PIPE)
    stdout, _ = child.communicate()
    if child.returncode != 0:
        # Raised explicitly rather than via an ``assert`` statement so the
        # check is still performed when Python runs with -O.  The exception
        # type and payload are the same as the assert produced.
        raise AssertionError(child.returncode)
    return stdout


def run_orm_json():
    """Run the orm_json tool and return its stdout.

    The binary is looked up relative to the directory of ``sys.argv[0]``,
    under ``../../../build/gnu_x86_64/tools/onload_remote_monitor``.  An
    AssertionError carrying the computed path is raised if it is missing.
    """
    script_dir = os.path.dirname(sys.argv[0])
    relative_parts = (
        '../../..', 'build', 'gnu_x86_64',
        'tools', 'onload_remote_monitor', 'orm_json',
    )
    orm_json = os.path.join(script_dir, *relative_parts)
    assert os.path.exists(orm_json), orm_json
    return run_cmd(orm_json)


def check_orm_json():
    """Exit with status 1 if the orm_json binary has not been built.

    Uses the same location as run_orm_json(): relative to the directory
    of ``sys.argv[0]``.  Returns None when the binary is present.
    """
    base_dir = os.path.dirname(sys.argv[0])
    orm_json = os.path.join(
        base_dir, '../../..', 'build', 'gnu_x86_64',
        'tools', 'onload_remote_monitor', 'orm_json')
    if os.path.exists(orm_json):
        return
    print('orm_json not built.')
    sys.exit(1)


class Server(object):
    """Single-threaded, select()-driven TCP server for stack state queries.

    Clients send newline-terminated request names.  The only request
    understood is ``stack_state_get``, which is answered with the output
    of run_orm_json() followed by a newline; any other line is answered
    with a "Bad request" message.

    Wire data is handled as byte strings (``b''`` literals), which on
    Python 2 are ordinary ``str`` objects.
    """

    def __init__(self, port):
        """Create the listening socket on *port* (wildcard address)."""
        listen_backlog = 10
        self.lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.lsock.bind(('', port))
        self.lsock.listen(listen_backlog)
        print('Server: Listening on %d' % port)
        self.clients = {}  # sock -> data received but not yet processed

    def handle_bad_req(self, sock, req):
        """Log an unrecognised request *req* and tell the client about it."""
        print('Client %d: bad request %r' % (sock.fileno(), req))
        # sendall() keeps writing until the whole reply is out; a bare
        # send() may stop after a partial write.
        sock.sendall(b'%r: Bad request\n' % (req,))

    def stack_state_get_req(self, sock):
        """Reply to *sock* with the orm_json output plus a trailing newline."""
        # The reply can be large, so a partial write from send() is a
        # realistic possibility; sendall() loops until everything is sent.
        sock.sendall(run_orm_json() + b'\n')

    def accept_connection(self):
        """Accept one pending connection and start buffering for it."""
        new_sock, _ = self.lsock.accept()
        self.clients[new_sock] = b''

    def remove_client(self, sock):
        """Forget *sock* and close it.

        Raises KeyError if *sock* is not a known client.
        """
        del self.clients[sock]
        # Close explicitly so the descriptor is released immediately
        # instead of whenever the socket object is garbage collected.
        sock.close()

    def handle_request(self, sock):
        """Dispatch every complete (newline-terminated) request buffered
        for *sock*; a trailing partial line stays in the buffer.

        Errors raised while replying propagate to the caller.
        """
        handlers = {
            b'stack_state_get': self.stack_state_get_req,
        }
        # A single recv() may deliver several requests at once; drain them
        # all rather than leaving later ones stuck until more data arrives.
        while b'\n' in self.clients[sock]:
            req, self.clients[sock] = self.clients[sock].split(b'\n', 1)
            # Look the handler up first so that a KeyError raised *inside*
            # a handler is not mistaken for an unknown request.
            handler = handlers.get(req)
            if handler is None:
                self.handle_bad_req(sock, req)
            else:
                handler(sock)

    def do_recv(self, sock):
        """Read available data from client *sock* and process it.

        The client is removed when the peer has closed the connection or
        when a socket error occurs while receiving or replying.
        """
        assert sock in self.clients, 'unknown socket %d' % sock.fileno()
        try:
            data = sock.recv(1024)
        except socket.error:
            self.remove_client(sock)
            return
        if not data:
            # An empty read means the peer closed the connection.
            self.remove_client(sock)
            return
        self.clients[sock] += data
        try:
            self.handle_request(sock)
        except socket.error:
            # A failure writing to one client (e.g. it disconnected before
            # the reply was sent) must not take the whole server down.
            self.remove_client(sock)

    def loop(self):
        """Serve until interrupted with Ctrl-C, then exit the process."""
        try:
            while True:
                socks = [self.lsock] + list(self.clients)
                readable, _, _ = select.select(socks, [], [])
                for sock in readable:
                    if sock is self.lsock:
                        self.accept_connection()
                    else:
                        self.do_recv(sock)
        except KeyboardInterrupt:
            # Emit a newline so the shell prompt starts on a fresh line.
            print('')
            sys.exit()


def usage():
    """Print the command-line synopsis and exit with status 1."""
    prog = sys.argv[0]
    print('Usage: %s port' % prog)
    sys.exit(1)


def main():
    """Script entry point: ``<prog> port``.

    Checks the argument count, checks that orm_json has been built,
    parses the port and then runs the server loop.  A wrong argument
    count or a non-integer port prints the usage message and exits
    with status 1.
    """
    if len(sys.argv) != 2:
        usage()
    check_orm_json()
    try:
        port = int(sys.argv[1])
    except ValueError:
        # A non-numeric port is a usage error; report it as such rather
        # than dying with a traceback.  usage() does not return.
        usage()
    Server(port).loop()


# Run the server only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
