#! /usr/bin/python3
# SPDX-License-Identifier: GPL-2.0-only
# Copyright (C) 2026, Advanced Micro Devices, Inc.

import argparse
import subprocess
import os
import re
import json
import sys
from pathlib import Path

# Utility version shown in the banner, and where the firmware manifest lives.
VERSION = "0.8.1"
MANIFEST_PATH = "/usr/share/doc/amdsolarflare-firmware/manifest.json"

# Process exit codes.
EXIT_SUCCESS = 0  # normal completion (including --version / --manifest / --list)
EXIT_NEEDS_UPDATE = 1  # a newer firmware image is available but was not written
EXIT_SFC_UNSUITABLE = 2  # sfc driver missing, in-tree, or too old
EXIT_INVALID_COMMAND = 3  # bad adapter name, or --write without root/--sudo
EXIT_RUNTIME_ERROR = 4  # an external command failed

# PCI address <-> interface name lookup tables, filled by init_pci_name_mappings().
pci_to_name = dict()
name_to_pci = dict()


def print_verbose(*args):
    """Print *args* (like print()) only when --verbose was requested."""
    if not options.verbose:
        return
    print(*args)


def print_normal(*args):
    """Print *args* (like print()) unless --silent was requested."""
    if options.silent:
        return
    print(*args)


def version_gt(version_a, version_b):
    """Return True if dotted version *version_a* is strictly newer than *version_b*.

    Components are compared numerically from left to right. A version with
    fewer components is padded with zeros, so "1.2.1" > "1.2" while
    "1.2.0" == "1.2". Components are converted with int() only as far as
    the comparison needs to go.

    Raises:
        ValueError: if a component that has to be compared is not an integer.
    """
    parts_a = version_a.split(".")
    parts_b = version_b.split(".")
    # Pad the shorter version so trailing components are not silently ignored
    # (plain zip() would stop at the shorter one).
    width = max(len(parts_a), len(parts_b))
    parts_a += ["0"] * (width - len(parts_a))
    parts_b += ["0"] * (width - len(parts_b))
    for component_a, component_b in zip(parts_a, parts_b):
        num_a, num_b = int(component_a), int(component_b)
        if num_a != num_b:
            return num_a > num_b
    return False


def init_pci_name_mappings(interfaces):
    """Fill the module-level pci_to_name / name_to_pci maps for *interfaces*.

    For each interface, the /sys/class/net/<interface> symlink is resolved.
    Its final component is taken as the interface name and the name two
    levels up as the device address (expected layout:
    .../<pci-address>/net/<interface>). Entries whose "address" does not look
    like a PCI address are skipped with a verbose message.
    """
    for interface in interfaces:
        # os.readlink() + Path() works on every Python 3 version, unlike
        # Path.readlink() (3.9+) or the private Path._from_parts() used
        # previously as a <=3.8 workaround.
        target = Path(os.readlink("/sys/class/net/" + interface))
        name = target.name
        pci = target.parent.parent.name
        # fullmatch: re.match only anchors at the start, so a name with a
        # hex-looking prefix (e.g. "fec.ethernet") would wrongly be accepted.
        if re.fullmatch(r"[0-9a-f:.]+", pci, re.IGNORECASE):
            pci_to_name[pci] = name
            name_to_pci[name] = pci
        else:
            print_verbose(f"interface {name} is not PCI?")


def run_command(cmd, requires_root=False, changes_nic=False, capture_stdout=False):
    """Run shell command *cmd*; exit the script if it fails.

    Args:
        cmd: Command line, executed through the shell.
        requires_root: Prefix the command with "sudo" when not running as
            root and --sudo was given.
        changes_nic: The command alters adapter state; under --dry-run it is
            only printed, not executed.
        capture_stdout: Capture the command's stdout and return it instead
            of letting it reach the terminal.

    Returns:
        The captured stdout (bytes) when *capture_stdout* is set and the
        command ran; otherwise "".

    Exits with EXIT_RUNTIME_ERROR if the command returns a non-zero status.
    """
    # Root-ness is a property of the effective *user* id; membership of the
    # root group (egid 0) does not grant the privileges these commands need.
    if requires_root and os.geteuid() != 0 and options.sudo:
        cmd = "sudo " + cmd

    if changes_nic and options.dry_run:
        print_normal(f"Would run: `{cmd}`")
        return ""

    # stdout=None is subprocess.run's default, i.e. inherit our stdout.
    ret = subprocess.run(
        cmd,
        shell=True,
        check=False,
        stdout=subprocess.PIPE if capture_stdout else None,
    )
    if ret.returncode != 0:
        # --silent still outputs errors, so do not route this through
        # print_normal(); errors go to stderr.
        print(f"`{cmd}` failed with rc={ret.returncode:d}.", file=sys.stderr)
        sys.exit(EXIT_RUNTIME_ERROR)
    return ret.stdout if capture_stdout else ""


def get_sfc_interfaces():
    """Return a sorted list of network interface names bound to the sfc driver.

    Interfaces are discovered as <device>/net/<interface> entries beneath the
    sfc PCI driver directory in sysfs; an absent directory yields [].
    """
    driver_dir = Path("/sys/module/sfc/drivers/pci:sfc/")
    return sorted(entry.name for entry in driver_dir.glob("*/net/*"))


def get_nic_pci_details(interface):
    """Return the PCI identifiers of *interface* as read from sysfs.

    The result maps "device", "vendor", "subsystem_device" and
    "subsystem_vendor" to the stripped, lower-cased contents of the file of
    that name under /sys/class/net/<interface>/device/. A missing file
    propagates the error raised by Path.read_text().
    """
    device_dir = Path("/sys/class/net/" + interface + "/device/")
    wanted = ("device", "vendor", "subsystem_device", "subsystem_vendor")
    return {
        key: (device_dir / key).read_text(encoding="utf-8").strip().lower()
        for key in wanted
    }


def get_files_matching_interface(interface):
    """Return the manifest file entries that apply to *interface*.

    An entry applies when its "pci_match" dict equals the interface's PCI
    identifiers (from get_nic_pci_details), or when those identifiers appear
    in its "pci_match_any" list. Entries with neither key are skipped, with a
    verbose message. Reads the module-level ``manifest``, which must already
    be loaded.
    """
    nic_pci_details = get_nic_pci_details(interface)
    matching = []
    for entry in manifest["files"]:
        if "pci_match" in entry:
            if nic_pci_details == entry["pci_match"]:
                matching.append(entry)
        elif "pci_match_any" in entry:
            if nic_pci_details in entry["pci_match_any"]:
                matching.append(entry)
        else:
            print_verbose(
                f"Skipping {entry['path']} as no supported match criteria"
            )
    return matching


def sfc_driver_suitable():
    """Return True if the loaded sfc driver is a suitable out-of-tree build.

    /sys/module/sfc/version must exist and hold a four-component dotted
    version (6.x.y.z or newer) whose first component is an integer >= 6.
    Otherwise the reason is written to stderr (shown even with --silent,
    since it is an error) and False is returned.
    """
    version_file = Path("/sys/module/sfc/version")
    if not version_file.exists():
        print("sfc driver is in-tree or too old", file=sys.stderr)
        return False
    vers = version_file.read_text(encoding="utf-8").strip().split(".")
    if len(vers) != 4:
        print(
            "sfc driver version is wrong format for out-of-tree driver",
            file=sys.stderr,
        )
        return False
    try:
        major = int(vers[0])
    except ValueError:
        # A non-numeric major component is just another malformed version;
        # report it instead of crashing with a traceback.
        print(
            "sfc driver version is wrong format for out-of-tree driver",
            file=sys.stderr,
        )
        return False
    if major < 6:
        print("sfc driver is too old", file=sys.stderr)
        return False
    # Should be sfc driver version 6.x.y.z or newer
    return True


def pci_match_to_str(match):
    """Render a PCI match dict as "vendor:device-subsystem_vendor:subsystem_device"."""
    primary = ":".join((match["vendor"], match["device"]))
    subsystem = ":".join((match["subsystem_vendor"], match["subsystem_device"]))
    return "-".join((primary, subsystem))


def _build_parser():
    """Create the command-line parser for this utility."""
    cli = argparse.ArgumentParser()

    # --version, --verbose and --silent may not be combined.
    verbosity = cli.add_mutually_exclusive_group()
    verbosity.add_argument(
        "-V", "--version", help="Display version and exit.", action="store_true"
    )
    verbosity.add_argument(
        "-v", "--verbose", help="Verbose mode.", action="store_true"
    )
    verbosity.add_argument(
        "-s", "--silent", help="Silent mode, output errors only.", action="store_true"
    )

    # The only option that takes a value.
    cli.add_argument(
        "-i", "--adapter", help="Operate only on the specified adapter."
    )

    # Every remaining option is a boolean switch; order here is help order.
    switches = (
        (("--write",), "Write updated firmware to the adapter(s)."),
        (
            ("--force",),
            "Force update of all firmware, even if the"
            " installed firmware version is the same or more recent.",
        ),
        (
            ("-y", "--yes"),
            "Update without prompting for a final"
            " confirmation. This option may be used with the --write"
            " and --force options.",
        ),
        (
            ("--sudo",),
            "When running script as a non-root user, use 'sudo' for any"
            " commands which require it.",
        ),
        (("--list",), "Show the adapter list."),
        (("-M", "--manifest"), "Show details of available firmware images and exit."),
        (
            ("--dry-run",),
            "Don't run commands which could change NIC state - just print them",
        ),
    )
    for flags, help_text in switches:
        cli.add_argument(*flags, help=help_text, action="store_true")
    return cli


parser = _build_parser()
options = parser.parse_args()

print_normal(f"AMD Solarflare firmware update utility: v{VERSION}\n")
if options.version:
    sys.exit(EXIT_SUCCESS)

if not sfc_driver_suitable():
    sys.exit(EXIT_SFC_UNSUITABLE)

interfaces = get_sfc_interfaces()
# Map every sfc interface (not just a selected one) so PF0 names can be
# looked up later.
init_pci_name_mappings(interfaces)

# Restrict the run to a single adapter when -i/--adapter is given.
# Error messages below go to stderr and are not hidden by --silent, which
# still outputs errors.
if options.adapter:
    if options.adapter in interfaces:
        interfaces = [options.adapter]
    else:
        print(
            f"Interface {options.adapter} is not an AMD Solarflare adapter",
            file=sys.stderr,
        )
        sys.exit(EXIT_INVALID_COMMAND)

# Writing firmware needs root: run as uid 0 or allow sudo. Check the effective
# *user* id - being in the root group (egid 0) does not grant root privileges.
if options.write and (os.geteuid() != 0 and not options.sudo):
    print("Firmware update requires root permission,", file=sys.stderr)
    print("but currently running as non-root user.\n", file=sys.stderr)
    print("Use '--sudo' option, or run script as root.", file=sys.stderr)
    sys.exit(EXIT_INVALID_COMMAND)

if options.list or options.manifest:
    print(f"Loading manifest {MANIFEST_PATH}")

try:
    with open(MANIFEST_PATH, encoding="utf-8") as fh:
        manifest = json.load(fh)
except (OSError, ValueError) as err:
    # Missing/unreadable file (OSError) or undecodable/malformed JSON
    # (UnicodeDecodeError and json.JSONDecodeError are ValueErrors).
    print(f"Unable to load manifest {MANIFEST_PATH}: {err}", file=sys.stderr)
    sys.exit(EXIT_RUNTIME_ERROR)

if options.manifest:
    for file in manifest["files"]:
        print(f"type: {file['type']:20} version: {file['version']:20}")
        print(f"  path: {file['path']:20}")
        if "pci_match" in file:
            print(f"  pci_match: {pci_match_to_str(file['pci_match'])}")
        elif "pci_match_any" in file:
            output = ", ".join(
                pci_match_to_str(pci_match) for pci_match in file["pci_match_any"]
            )
            print(f"  pci_match_any: {output}")
    sys.exit(EXIT_SUCCESS)

status = EXIT_SUCCESS
for interface in interfaces:
    pci = "pci/" + name_to_pci[interface]

    files = get_files_matching_interface(interface)
    if not files:
        # Informational message: honour --silent like other status output.
        print_normal(
            f"Skipping interface {interface} - no matching files in manifest"
        )
        continue

    if options.list:
        print(f"{interface:20} {pci}")
        continue

    devinfo = json.loads(
        run_command("devlink -j dev info " + pci, capture_stdout=True)
    )
    versions = devinfo["info"][pci]["versions"]
    # Treat a missing "running"/"stored" group as reporting no versions
    # instead of failing with KeyError.
    running_versions = versions.get("running", {})
    stored_versions = versions.get("stored", {})

    # Resolve each candidate once: (manifest entry, display name, version the
    # NIC reports - running preferred over stored - or None if unreported).
    components = []
    for file in files:
        name = file.get("display_type") or file["type"]
        nic_ver = running_versions.get(file["type"]) or stored_versions.get(
            file["type"]
        )
        components.append((file, name, nic_ver))

    print_normal(f"{interface} {pci}")
    # Print versions of components we might update
    for file, name, nic_ver in components:
        shown = f"v{nic_ver}" if nic_ver else "unknown"
        print_normal(f"{name:20}:       {shown}")

    needs_reset = False
    # --yes and --silent both suppress the confirmation prompt.
    prompt_user = not (options.yes or options.silent)

    for file, name, nic_ver in components:
        filename = file["path"]
        file_ver = file["version"]
        print_verbose(
            f"Comparing {name} version {file_ver} filename {filename} with"
            f" {nic_ver}"
        )
        # An unreported NIC version cannot be compared (version_gt would fail
        # on it); such a component is only written when --force is given.
        newer = bool(nic_ver) and version_gt(file_ver, nic_ver)
        if options.write and (newer or options.force):
            # Unless interface explicitly specified, only apply updates to PF0
            if not options.adapter and not re.search(r"\.0$", pci):
                pf0 = re.sub(
                    r"\.[0-9a-f]+$", ".0", name_to_pci[interface], flags=re.IGNORECASE
                )
                # PF0 may have no sfc netdev of its own; fall back to showing
                # its PCI address rather than raising KeyError.
                print_normal(
                    f"{interface}: skipping, {name} shared with"
                    f" {pci_to_name.get(pf0, pf0)}"
                )
                continue

            if prompt_user:
                print_normal(f"{interface}: will be reset during {name} update")
                try:
                    response = input(
                        "To continue, press Y then Enter\n"
                        "To skip this update, press Enter\n"
                    )
                except EOFError:
                    # No interactive input (stdin closed): play safe and skip.
                    response = ""
                # Accept any non-empty prefix of "yes", case-insensitively.
                if response and response.lower() == "yes"[: len(response)]:
                    prompt_user = False
                else:
                    print_normal(f"Skipping updates for {interface}")
                    break  # exit loop as user doesn't want to update this NIC
            print_normal(f"{interface}: writing {name}")
            # user has agreed we can update
            cmd = "devlink dev flash " + pci + " file " + filename
            run_command(
                cmd,
                requires_root=True,
                changes_nic=True,
                capture_stdout=options.silent,
            )
            needs_reset = True
        elif newer:
            print_normal(f"More recent {name} found [{file_ver}]")
            print_normal("   - run `amdsfupdate --write` to perform an update")
            status = EXIT_NEEDS_UPDATE
        elif not nic_ver:
            print_normal(f"Installed {name} version is unknown [{file_ver} available]")
            print_normal("   - run `amdsfupdate --write --force` to write it anyway")
        else:
            print_normal(f"The {name} is up to date (>={file_ver})")

    if needs_reset:
        print_normal("Resetting adapter")
        cmd = "ethtool --reset " + interface + " all"
        run_command(cmd, requires_root=True, changes_nic=True, capture_stdout=True)
        print_verbose("Complete")

sys.exit(status)
