#!/usr/libexec/platform-python
# SPDX-License-Identifier: GPL-2.0-only
# Copyright (C) 2026, Advanced Micro Devices, Inc.

import argparse
import json
import os
import re
import shlex
import subprocess
import sys
from pathlib import Path

VERSION = "0.8.1"
# Manifest listing the firmware/configuration files and the PCI IDs they apply to.
MANIFEST_PATH = "/usr/share/doc/amdsolarflare-firmware/manifest.json"

# Process exit statuses returned by the script.
(
    EXIT_SUCCESS,  # 0: completed, no power cycle required
    EXIT_NEEDS_UPDATE,  # 1: stored config differs from active; power cycle needed
    EXIT_SFC_UNSUITABLE,  # 2: sfc driver missing, in-tree or too old
    EXIT_INVALID_COMMAND,  # 3: bad --adapter or insufficient privileges
    EXIT_RUNTIME_ERROR,  # 4: an external command failed
) = range(5)

# Lookup tables filled by init_pci_name_mappings():
# PCI address -> interface name, and interface name -> PCI address.
pci_to_name, name_to_pci = {}, {}

def print_verbose(*args):
    """Print *args* (as print() would) only when --verbose was given."""
    if not options.verbose:
        return
    print(*args)

def print_normal(*args):
    """Print *args* (as print() would) unless --silent was given."""
    if options.silent:
        return
    print(*args)

def init_pci_name_mappings(interfaces, sysfs_net="/sys/class/net"):
    """Populate the module-level pci_to_name / name_to_pci lookup tables.

    Each interface's sysfs entry (``<sysfs_net>/<interface>``) is a symlink.
    The last component of the link target is taken as the interface name, and
    the component two levels above it is taken as the PCI address.

    Args:
        interfaces: iterable of network interface names.
        sysfs_net: directory holding the interface symlinks. Defaults to
            /sys/class/net; only needs overriding for testing.
    """
    for interface in interfaces:
        # Path(os.readlink(...)) works on every Python 3 version, so there is
        # no need for Path.readlink() (3.9+) with a private-API fallback.
        target = Path(os.readlink(str(Path(sysfs_net) / interface)))
        name = target.name
        pci = target.parent.parent.name
        # fullmatch: the whole component must consist of PCI-address
        # characters, not merely start with them.
        if re.fullmatch(r"[0-9a-f:.]+", pci, re.IGNORECASE):
            pci_to_name[pci] = name
            name_to_pci[name] = pci
        else:
            print_verbose(f"interface {name} is not PCI?")

def run_command(cmd, requires_root=False, changes_nic=False, capture_stdout=False):
    """Run *cmd* through the shell; terminate the script if it fails.

    Args:
        cmd: shell command line. It is run with ``shell=True``, so callers must
            quote any untrusted pieces themselves.
        requires_root: the command needs root. When the script is not running
            as root and --sudo was given, the command is prefixed with "sudo".
        changes_nic: the command may modify NIC state. With --dry-run it is
            only printed, not executed.
        capture_stdout: capture and return the command's stdout instead of
            letting it go to the script's own stdout.

    Returns:
        The captured stdout (bytes) when *capture_stdout* is set and the
        command actually ran; otherwise "" (including skipped dry-run commands).

    On a non-zero exit status an error is written to stderr and the script
    exits with EXIT_RUNTIME_ERROR.
    """
    # Root is identified by the effective *user* id; the group id says nothing
    # about privilege.
    if requires_root and os.geteuid() != 0 and options.sudo:
        cmd = "sudo " + cmd

    if changes_nic and options.dry_run:
        print_normal(f"Would run: `{cmd}`")
        return ""

    # stdout=None inherits the script's stdout, i.e. no capture.
    stdout = subprocess.PIPE if capture_stdout else None
    ret = subprocess.run(cmd, shell=True, check=False, stdout=stdout)
    if ret.returncode != 0:
        # This is an error, so it must be reported even in --silent mode.
        print(f"`{cmd}` failed with rc={ret.returncode:d}.", file=sys.stderr)
        sys.exit(EXIT_RUNTIME_ERROR)
    return ret.stdout if capture_stdout else ""

def get_sfc_interfaces():
    """Return the sorted names of network interfaces bound to the sfc driver.

    Names are taken from /sys/module/sfc/drivers/pci:sfc/<device>/net/<ifname>.
    """
    driver_dir = Path("/sys/module/sfc/drivers/pci:sfc/")
    return sorted(entry.name for entry in driver_dir.glob("*/net/*"))

def get_nic_pci_details(interface):
    """Read *interface*'s PCI identifiers from sysfs.

    Returns a dict with keys "device", "vendor", "subsystem_device" and
    "subsystem_vendor". Each value is the corresponding file under
    /sys/class/net/<interface>/device/, with whitespace stripped and
    lower-cased.
    """
    device_dir = "/sys/class/net/" + interface + "/device/"
    return {
        attr: Path(device_dir + attr).read_text(encoding="utf-8").strip().lower()
        for attr in ("device", "vendor", "subsystem_device", "subsystem_vendor")
    }

def get_files_matching_interface(interface):
    """Return the manifest file entries that apply to *interface*'s adapter.

    An entry matches when its "pci_match" dict equals the adapter's PCI IDs
    (as returned by get_nic_pci_details()), or when those IDs appear in its
    "pci_match_any" list. Entries with neither key are skipped, with a
    verbose-mode message.

    Relies on the module-level ``manifest`` having been loaded.
    """
    nic_pci_details = get_nic_pci_details(interface)
    files = []
    for entry in manifest["files"]:
        if "pci_match" in entry:
            matched = nic_pci_details == entry["pci_match"]
        elif "pci_match_any" in entry:
            matched = nic_pci_details in entry["pci_match_any"]
        else:
            print_verbose(f"Skipping {entry['path']} as no supported match criteria")
            continue
        if matched:
            files.append(entry)
    return files

def sfc_driver_suitable(version_path="/sys/module/sfc/version"):
    """Check that the loaded sfc driver is an out-of-tree release, version 6+.

    The version file must exist and hold four dot-separated components, the
    first being an integer >= 6.

    Args:
        version_path: file holding the sfc driver version. Defaults to the
            sysfs module attribute; only needs overriding for testing.

    Returns:
        True if the driver is suitable. Otherwise the reason is written to
        stderr (so it is visible even with --silent) and False is returned.
    """
    p = Path(version_path)
    if not p.exists():
        print("sfc driver is in-tree or too old", file=sys.stderr)
        return False
    vers = p.read_text(encoding="utf-8").strip().split(".")
    bad_format = "sfc driver version is wrong format for out-of-tree driver"
    if len(vers) != 4:
        print(bad_format, file=sys.stderr)
        return False
    try:
        major = int(vers[0])
    except ValueError:
        # A non-numeric leading component: reject it rather than crash.
        print(bad_format, file=sys.stderr)
        return False
    if major < 6:
        print("sfc driver is too old", file=sys.stderr)
        return False
    # Should be sfc driver version 6.x.y.z or newer
    return True

def pci_match_to_str(match):
    """Format a PCI-ID dict as "vendor:device-subsystem_vendor:subsystem_device"."""
    primary = ":".join((match["vendor"], match["device"]))
    subsystem = ":".join((match["subsystem_vendor"], match["subsystem_device"]))
    return "-".join((primary, subsystem))

# Command-line interface definition.
parser = argparse.ArgumentParser()

# --version, --verbose and --silent cannot be combined with one another.
output_mode = parser.add_mutually_exclusive_group()
output_mode.add_argument("-V", "--version", action="store_true", help="Display version and exit.")
output_mode.add_argument("-v", "--verbose", action="store_true", help="Verbose mode.")
output_mode.add_argument("-s", "--silent", action="store_true", help="Silent mode, output errors only.")

parser.add_argument(
    "-i",
    "--adapter",
    metavar="INTERFACE",
    help=(
        "Operate only on the adapter with the specified interface name. "
        "Note that the configuration is for the entire adapter, not just one interface."
    ),
)
parser.add_argument(
    "--config",
    metavar="FILE",
    help="Apply the specified configuration file (from /lib/firmware/) to the adapters.",
)

# The remaining options are plain on/off switches, registered in this order.
for flag, help_text in (
    ("--sudo", "When running script as a non-root user, use 'sudo' for any commands which require it."),
    ("--list", "Show the adapter list."),
    ("--dry-run", "Don't run commands which could change NIC state - just print them"),
):
    parser.add_argument(flag, action="store_true", help=help_text)

options = parser.parse_args()

print_normal(f"AMD Solarflare configuration utility: v{VERSION}\n")
if options.version:
    sys.exit(EXIT_SUCCESS)

if not sfc_driver_suitable():
    sys.exit(EXIT_SFC_UNSUITABLE)

interfaces = get_sfc_interfaces()
init_pci_name_mappings(interfaces)

# --adapter narrows the work to a single sfc interface.
if options.adapter:
    if options.adapter not in interfaces:
        # An error, so it is reported even in --silent mode.
        print(f"Interface {options.adapter} is not an AMD Solarflare adapter", file=sys.stderr)
        sys.exit(EXIT_INVALID_COMMAND)
    interfaces = [options.adapter]

if options.list:
    print_normal(f"Loading manifest {MANIFEST_PATH}")
elif os.geteuid() != 0 and not options.sudo:
    # Anything other than --list will require root for some commands.
    # Root is identified by the effective *user* id, not the group id.
    print("Configuration read/write requires root permission,", file=sys.stderr)
    print("but currently running as non-root user.\n", file=sys.stderr)
    print("Use '--sudo' option, or run script as root.", file=sys.stderr)
    sys.exit(EXIT_INVALID_COMMAND)

# A missing or corrupt manifest is a runtime error, not a traceback.
try:
    with open(MANIFEST_PATH, encoding="utf-8") as fh:
        manifest = json.load(fh)
except (OSError, ValueError) as err:
    print(f"Failed to load manifest {MANIFEST_PATH}: {err}", file=sys.stderr)
    sys.exit(EXIT_RUNTIME_ERROR)

# For each adapter: list it, or (optionally) write a new stored configuration
# and show the stored one. A power cycle is requested when any adapter's
# stored (nvcfg-next) configuration differs from its active (nvcfg-active) one.
needs_power_cycle = False
for interface in interfaces:
    pci = f"pci/{name_to_pci[interface]}"

    files = get_files_matching_interface(interface)
    if not files:
        print_verbose(f"Skipping interface {interface} - no matching files in manifest")
        continue

    if options.list:
        print(f"{interface:20} {pci}")
        continue

    # Only want to use PF0 for config operations
    if not pci.endswith(".0"):
        pf0 = re.sub(r"\.[0-9a-f]+$", ".0", name_to_pci[interface], flags=re.IGNORECASE)
        if not options.adapter:
            # PF0 may have no sfc netdev of its own; fall back to showing its
            # PCI address instead of raising KeyError.
            print_normal(
                f"{interface}: skipping, configuration shared with"
                f" {pci_to_name.get(pf0, pf0)}"
            )
            continue
        pci = f"pci/{pf0}"

    if options.config:
        print_verbose(f"{interface}: old stored config was:")
        if options.verbose:
            run_command(
                f"devlink health diagnose {pci} reporter nvcfg-stored",
                requires_root=True,
            )
        print_normal(f"{interface}: writing new config")
        # The command runs through a shell, so quote the user-supplied file
        # name to pass spaces or shell metacharacters to devlink literally.
        cmd = f"devlink dev flash {pci} file {shlex.quote(options.config)}"
        run_command(
            cmd,
            requires_root=True,
            changes_nic=True,
            capture_stdout=options.silent,
        )

    nvcfg_next = run_command(
        f"devlink health diagnose {pci} reporter nvcfg-next",
        requires_root=True,
        capture_stdout=True,
    )
    print_normal(f"Stored configuration for {interface}:")
    print_normal(nvcfg_next.decode("utf-8"))
    nvcfg_active = run_command(
        f"devlink health diagnose {pci} reporter nvcfg-active",
        requires_root=True,
        capture_stdout=True,
    )
    if nvcfg_active != nvcfg_next:
        needs_power_cycle = True

if needs_power_cycle:
    print_normal("Power cycle the system to apply configuration changes")
    sys.exit(EXIT_NEEDS_UPDATE)
sys.exit(EXIT_SUCCESS)
