
#
# Copyright (c) 2004-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited
#


import sys
import json
import subprocess
from mlxlink_multiplane_aggregator import *


class MlxlinkWrapper:
    """Run a single mlxlink command and keep the text it produced.

    Attributes:
        stdout: Captured standard output of the last successful run, or None
            before any run has completed.
        stderr: Captured standard error of the last successful run, or None
            before any run has completed.
    """

    def __init__(self):
        self.stdout, self.stderr = None, None

    def run_mlxlink(self, mlxlink_arguments):
        """Execute ``mlxlink_arguments`` (an argv list, program first) and capture its output.

        On success the decoded stdout/stderr text is stored on the instance.

        Returns:
            The process return code. Because ``check=True`` is used, this is
            always 0 when the call returns normally.

        Raises:
            subprocess.CalledProcessError: the command exited with a non-zero status.
            FileNotFoundError: the executable could not be found.
        """
        pipe = subprocess.PIPE
        completed = subprocess.run(
            mlxlink_arguments,
            stdout=pipe,
            stderr=pipe,
            universal_newlines=True,  # decode captured output as text
            check=True,  # a non-zero exit status raises CalledProcessError
        )
        self.stdout, self.stderr = completed.stdout, completed.stderr
        return completed.returncode


def find_device_arg_index(argv):
    """Return the position of the value that follows the first -d/--device flag in ``argv``.

    Returns None when neither flag is present. If the flag is the last element,
    the returned position equals ``len(argv)`` (i.e. it points past the end).
    """
    device_flags = ('-d', '--device')
    return next(
        (position + 1 for position, token in enumerate(argv) if token in device_flags),
        None,
    )


def find_amber_collect_arg_index(argv):
    """Return the position of the value that follows the first --amber_collect flag in ``argv``.

    Returns -1 when the flag is absent. If the flag is the last element, the
    returned position equals ``len(argv)`` (i.e. it points past the end).
    """
    positions = (
        position + 1
        for position, token in enumerate(argv)
        if token == '--amber_collect'
    )
    return next(positions, -1)


def get_asic_info():
    """Read the ASIC map of an aggregated port from the file given with -d/--device.

    The device argument on the command line (``sys.argv``) is a path to a JSON
    file containing a list of entries:
      * entries with both 'Device' and 'Split' keys describe one ASIC each;
      * an entry with a 'System Type' key describes the system (BM / CPO).

    Returns:
        (asics, sys_type): ``asics`` is a list of ``(device, split)`` tuples in
        file order; ``sys_type`` is the first non-None 'System Type' value
        found, or None if there is none (including an empty list).

    Raises:
        Exception: no device argument was given, the file does not exist, or
            the file is not valid JSON.
    """
    device_index = find_device_arg_index(sys.argv)
    # find_device_arg_index returns None when the flag is missing and
    # len(argv) when the flag has no value after it.
    if device_index is None or device_index >= len(sys.argv):
        raise Exception("Error - a device must be provided with -d/--device.")
    aggregated_port_fd = sys.argv[device_index]
    try:
        with open(aggregated_port_fd, encoding="utf-8") as f:
            aggregated_port_info = json.load(f)
    except FileNotFoundError as err:
        raise Exception("Unknown device. Please provide the correct device.") from err
    except json.JSONDecodeError as err:
        raise Exception(
            "Invalid aggregated port device file '{}': {}".format(aggregated_port_fd, err)
        ) from err

    # Get the system type (BM / CPO system). Initialised up front so an empty
    # list yields None instead of an unbound local.
    sys_type = None
    for info in aggregated_port_info:
        if info.get('System Type') is not None:
            sys_type = info['System Type']
            break

    # Each ASIC is a json entry in the list; store a (device, split) tuple.
    asics = []
    for info in aggregated_port_info:
        device, split = info.get('Device'), info.get('Split')
        if device is not None and split is not None:
            asics.append((device, split))
    return asics, sys_type


def main(mlxlink_arguments):
    """Run mlxlink once per ASIC of an aggregated port and aggregate the results.

    The aggregated port device named by -d/--device is resolved to a list of
    (device, split) pairs. mlxlink is then run once per pair, with the device
    argument replaced and ``--planarized_split <split>`` appended for that run
    only. If --amber_collect is given, every ASIC writes to its own csv file:
    '<name>.csv' becomes '<name>_ASIC_<n>.csv'.

    Note: ``mlxlink_arguments`` is modified in place. When the loop finishes,
    its device (and amber file) entries hold the values of the last ASIC.
    That list is what multiplane_aggregator receives.

    Args:
        mlxlink_arguments: mlxlink argv list (program first).

    Returns:
        The highest mlxlink return code, or 1 on any error.
    """
    try:
        # Read ASIC mapping and splits from the aggregated port device.
        devices, system_type = get_asic_info()
        if not devices:
            raise Exception("Error - no ASICs were found in the aggregated port device.")

        device_index = find_device_arg_index(mlxlink_arguments)
        if device_index is None or device_index >= len(mlxlink_arguments):
            raise Exception("Error - a device must be provided with -d/--device.")

        amber_index = find_amber_collect_arg_index(mlxlink_arguments)
        if amber_index >= len(mlxlink_arguments):
            raise Exception("Error - when using amber_collect flag, a csv file must be provided!")
        # Each per-ASIC file name is derived from the user's original name.
        # Repeatedly rewriting the previous ASIC's name would corrupt names
        # that already contain '_ASIC_<k>'.
        amber_template = mlxlink_arguments[amber_index] if amber_index != -1 else None

        parsed_output = []
        rc_lst, err_lst = [], []
        for asic_num, (device, split) in enumerate(devices):
            mlxlink_wrapper = MlxlinkWrapper()
            # Replace the aggregated port device with the actual lid device.
            mlxlink_arguments[device_index] = device
            if amber_template is not None:
                mlxlink_arguments[amber_index] = amber_template.replace(
                    '.csv', '_ASIC_{}.csv'.format(asic_num))
            # Build a fresh argv per ASIC. Appending to the shared list would
            # accumulate one extra --planarized_split pair per iteration.
            run_arguments = mlxlink_arguments + ["--planarized_split", str(split)]
            try:
                rc = mlxlink_wrapper.run_mlxlink(run_arguments)
            except subprocess.CalledProcessError as e:
                print(f"mlxlink returned a non-zero exit code: {e.returncode}")
                err_msg = e.stderr.strip()
                if err_msg:
                    print(f"{err_msg}")
                else:
                    print(f"{e.stdout.strip()}")
                return 1
            except FileNotFoundError:
                print("mlxlink executable not found. Please provide the correct path.")
                return 1

            output_lines = mlxlink_wrapper.stdout.strip().split('\n')
            parser = AggragetedParser(output_lines)
            parsed_output.append(parser.parse_mlxlink_output())
            rc_lst.append(rc)
            err_lst.append(mlxlink_wrapper.stderr.strip())

        multiplane_aggregator(mlxlink_arguments, devices, rc_lst, parsed_output, err_lst, system_type)
        return max(rc_lst)
    except Exception as e:
        print(e)
        return 1


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python script.py mlxlink_arguments")
        sys.exit(1)
    # Propagate main()'s return code as the process exit status so that
    # callers and shell scripts can detect failures.
    sys.exit(main(sys.argv[1:]))
