#!/usr/bin/env python3
#
# SPDX-License-Identifier: GPL-2.0-or-later
"""
Decode an LPDDR5/LPDDR5X SPD text hex file into JSON.

Default output format is compatible with coreboot's spd_tools `spd_gen`:
it emits a `memory_parts.json` containing a single part entry.

This tool is intentionally named `spd-decode` even though it currently only
supports LPDDR5/LPDDR5X SPDs, so it can be extended to other SPD types later.

Usage:
  spd-decode <path/to/spd.hex> [--pretty] [--name NAME]
            [--include-nondefault-timings]
  spd-decode <path/to/spd.hex> --format raw [--pretty]

Options:
  --format spd_gen      Output a `memory_parts.json`-style document for
                        coreboot's spd_tools `spd_gen`.
  --format raw          Output the fully decoded SPD fields as JSON.
  --name NAME           Override the part name in spd_gen output (defaults to
                        the SPD's part number, or else the input file name).
  --include-nondefault-timings
                        Include timing fields in spd_gen output only when they
                        differ from spd_gen defaults.
  --pretty              Pretty-print JSON (tabs) and include a command header.

Only LPDDR5 (0x13) and LPDDR5X (0x15) SPDs are supported.
"""

import argparse
import json
import os
import sys
from typing import Dict, List, Optional


def _json_dumps_tabs(obj) -> str:
	# Dump JSON with one-space indent and convert leading spaces to tabs.
	s = json.dumps(obj, indent=1, separators=(",", ": "), ensure_ascii=False)
	lines = []
	for line in s.splitlines():
		n = 0
		while n < len(line) and line[n] == " ":
			n += 1
		if n:
			line = ("\t" * n) + line[n:]
		lines.append(line)
	return "\n".join(lines) + "\n"


def _to_signed_byte(b: int) -> int:
	return b - 256 if b >= 128 else b


def _parse_hex_file(path: str) -> List[int]:
	data: List[int] = []
	with open(path, "r", encoding="utf-8") as f:
		for line in f:
			line = line.strip()
			if not line:
				continue
			for p in line.split():
				if len(p) != 2:
					raise ValueError(f"Invalid hex byte '{p}' in {path}")
				data.append(int(p, 16))
	return data


def _decode_time_ps(mtb_byte: int, ftb_byte: int) -> int:
	"""Combine a medium-timebase byte and a fine-timebase byte into picoseconds.

	The MTB byte counts units of 125 ps; the FTB byte is a signed 8-bit
	correction in ps that is added on top.
	"""
	coarse_ps = mtb_byte * 125
	fine_ps = _to_signed_byte(ftb_byte)
	return coarse_ps + fine_ps


def _decode_time_ns_from_16bit_mtb(lsb: int, msb: int) -> float:
	# 16-bit MTB value in units of 125ps -> convert to ns (float)
	mtb = (msb << 8) | lsb
	ps = mtb * 125
	return ps / 1000.0


def _round_int(x: float) -> int:
	return int(round(x))


def _closest_speed_from_tck_ps(tck_ps: int, set_rev: int) -> int:
	# Allowed LP5 speed bins in Mbps
	bins = [5500, 6400, 7500, 8533]
	if set_rev == 0x10:
		# ADL set encodes CK tCKmin = 1 / (WCK rate / 4) = 8e6 / speed
		guess = 8_000_000.0 / max(tck_ps, 1)
	else:
		# Default SPD encodes WCK tCKmin = 2e6 / speed
		guess = 2_000_000.0 / max(tck_ps, 1)
	return int(min(bins, key=lambda b: abs(b - guess)))


def _default_part_name_from_path(path: str) -> str:
	base = os.path.basename(path)
	for suffix in (".spd.hex", ".hex"):
		if base.endswith(suffix):
			base = base[: -len(suffix)]
			break
	return base


def _lp5_defaults_for_optional_attribs(density_gb: Optional[int]) -> Dict[str, int]:
	# Match util/spd_tools/src/spd_gen/lp5.go defaults.
	trfc = {
		4: (180, 90),
		6: (210, 120),
		8: (210, 120),
		12: (280, 140),
		16: (280, 140),
		24: (380, 190),
		32: (380, 190),
	}
	trfcab, trfcpb = trfc.get(density_gb, (0, 0))
	return {
		"trfcabNs": trfcab,
		"trfcpbNs": trfcpb,
		"trcdMinNs": 18,
		"trpabMinNs": 21,
		"trppbMinNs": 18,
	}


def decode_lp5(spd: List[int]) -> Dict:
	"""Decode the LPDDR5/LPDDR5X fields of an SPD image.

	Args:
		spd: SPD contents as a list of byte values. At least 349 entries are
			required because the part number field ends at byte 348.

	Returns:
		A dict with "memoryType", "revision", "channelsPerPackage",
		"manufacturerPartNumber" (None when the field is blank, padding
		only, or not ASCII) and an "attributes" dict holding the
		organization, speed and timing values.

	Raises:
		ValueError: If fewer than 349 bytes are supplied.
	"""
	if len(spd) < 349:
		raise ValueError("LPDDR5 SPD must contain at least 349 bytes")

	# Byte indices (match util/spd_tools/src/spd_gen/lp5.go)
	IDX = {
		"REV": 1,
		"MEM_TYPE": 2,
		"DENSITY_BANKS": 4,
		"ADDR": 5,
		"PKG_TYPE": 6,
		"OPT_FEATURES": 7,
		"OTHER_OPT_FEATURES": 9,
		"MODULE_ORG": 12,
		"BUS_WIDTH": 13,
		"TIMEBASES": 17,
		"TCK_MIN": 18,
		"TAA_MIN": 24,
		"TRCD_MIN": 26,
		"TRPAB_MIN": 27,
		"TRPPB_MIN": 28,
		"TRFCAB_LSB": 29,
		"TRFCAB_MSB": 30,
		"TRFCPB_LSB": 31,
		"TRFCPB_MSB": 32,
		"TRPPB_FINE": 120,
		"TRPAB_FINE": 121,
		"TRCD_FINE": 122,
		"TAA_FINE": 123,
		"TCK_FINE": 125,
		"MPN_START": 329,
		"MPN_END": 348,
	}

	mem_type = spd[IDX["MEM_TYPE"]]
	rev = spd[IDX["REV"]]
	if mem_type == 0x13:
		mem_type_str = "LPDDR5"
		lp5x = False
	elif mem_type == 0x15:
		mem_type_str = "LPDDR5X"
		lp5x = True
	else:
		mem_type_str = f"Unknown(0x{mem_type:02x})"
		lp5x = False

	# Low nibble of the density byte -> per-die density in Gb.
	# Unlisted codes decode to None.
	density_code = spd[IDX["DENSITY_BANKS"]] & 0x0F
	density_map = {
		0x4: 4,
		0xB: 6,
		0x5: 8,
		0x8: 12,
		0x6: 16,
		0x9: 24,
		0x7: 32,
	}
	density_per_die_gb = density_map.get(density_code)

	# Package type byte: bits 6:4 hold (dies - 1), bits 3:2 hold
	# log2(channels per package).
	dies_per_package = ((spd[IDX["PKG_TYPE"]] >> 4) & 0x7) + 1
	# Module organization byte: bits 5:3 hold (ranks - 1); bits 2:0 are
	# scaled by 8, so codes 1 and 2 decode to 8 and 16 bits.
	ranks_per_channel = ((spd[IDX["MODULE_ORG"]] >> 3) & 0x7) + 1
	bit_width_per_channel = (spd[IDX["MODULE_ORG"]] & 0x7) * 8
	channels_per_pkg = 1 << ((spd[IDX["PKG_TYPE"]] >> 2) & 0x3)

	tck_ps = _decode_time_ps(spd[IDX["TCK_MIN"]], spd[IDX["TCK_FINE"]])
	speed_mbps = _closest_speed_from_tck_ps(tck_ps, rev)
	# A part whose type byte says LPDDR5 is still flagged lp5x when its
	# speed bin is 7500 Mbps or faster.
	if not lp5x and speed_mbps >= 7500:
		lp5x = True

	taa_ps = _decode_time_ps(spd[IDX["TAA_MIN"]], spd[IDX["TAA_FINE"]])
	trcd_ns = _round_int(_decode_time_ps(spd[IDX["TRCD_MIN"]], spd[IDX["TRCD_FINE"]]) / 1000.0)
	trpab_ns = _round_int(_decode_time_ps(spd[IDX["TRPAB_MIN"]], spd[IDX["TRPAB_FINE"]]) / 1000.0)
	trppb_ns = _round_int(_decode_time_ps(spd[IDX["TRPPB_MIN"]], spd[IDX["TRPPB_FINE"]]) / 1000.0)
	trfcab_ns = _round_int(_decode_time_ns_from_16bit_mtb(spd[IDX["TRFCAB_LSB"]], spd[IDX["TRFCAB_MSB"]]))
	trfcpb_ns = _round_int(_decode_time_ns_from_16bit_mtb(spd[IDX["TRFCPB_LSB"]], spd[IDX["TRFCPB_MSB"]]))

	# Best effort: the part number is optional, so a field that cannot be
	# decoded as ASCII is reported as absent instead of failing the decode.
	# Padding is stripped from both ends. Besides ASCII whitespace this
	# includes NUL bytes, which would otherwise make an all-zero field look
	# like a 20-character part number.
	mpn_padding = "\x00 \t\n\r\x0b\x0c\x1c\x1d\x1e\x1f"
	mpn_bytes = spd[IDX["MPN_START"] : IDX["MPN_END"] + 1]
	try:
		mpn = bytes(mpn_bytes).decode("ascii").strip(mpn_padding)
	except (ValueError, TypeError):
		# ValueError covers UnicodeDecodeError and out-of-range byte values;
		# TypeError covers non-integer entries.
		mpn = ""

	return {
		"memoryType": mem_type_str,
		"revision": f"0x{rev:02x}",
		"channelsPerPackage": channels_per_pkg,
		"manufacturerPartNumber": mpn or None,
		"attributes": {
			"densityPerDieGb": density_per_die_gb,
			"diesPerPackage": dies_per_package,
			"bitWidthPerChannel": bit_width_per_channel,
			"ranksPerChannel": ranks_per_channel,
			"speedMbps": speed_mbps,
			"lp5x": lp5x,
			"tckMinPs": tck_ps,
			"taaMinPs": taa_ps,
			"trcdMinNs": trcd_ns,
			"trpabMinNs": trpab_ns,
			"trppbMinNs": trppb_ns,
			"trfcabNs": trfcab_ns,
			"trfcpbNs": trfcpb_ns,
		},
	}


def detect_and_decode(spd: List[int]) -> Dict:
	"""Pick a decoder from the SPD memory type byte and run it.

	Args:
		spd: SPD contents as a list of byte values.

	Returns:
		The decoded fields produced by decode_lp5().

	Raises:
		ValueError: If fewer than three bytes are supplied.
		NotImplementedError: If byte 2 is neither 0x13 nor 0x15.
	"""
	if len(spd) < 3:
		raise ValueError("SPD data too short")

	mem_type = spd[2]
	if mem_type not in (0x13, 0x15):
		raise NotImplementedError(f"Unsupported or unknown SPD memory type 0x{mem_type:02x}")
	return decode_lp5(spd)


def _to_spd_gen_memory_parts(decoded: Dict, part_name: str, include_nondefault_timings: bool) -> Dict:
	attribs_in = decoded.get("attributes", {})

	attribs_out = {
		"densityPerDieGb": attribs_in.get("densityPerDieGb"),
		"diesPerPackage": attribs_in.get("diesPerPackage"),
		"bitWidthPerChannel": attribs_in.get("bitWidthPerChannel"),
		"ranksPerChannel": attribs_in.get("ranksPerChannel"),
		"speedMbps": attribs_in.get("speedMbps"),
	}

	if attribs_in.get("lp5x"):
		attribs_out["lp5x"] = True

	if include_nondefault_timings:
		defaults = _lp5_defaults_for_optional_attribs(attribs_in.get("densityPerDieGb"))
		for k in ("trfcabNs", "trfcpbNs", "trcdMinNs", "trpabMinNs", "trppbMinNs"):
			v = attribs_in.get(k)
			if v is None:
				continue
			if defaults.get(k) and v != defaults[k]:
				attribs_out[k] = v

	part = {"name": part_name, "attribs": attribs_out}
	return {"parts": [part]}


def main(argv: List[str]) -> int:
	"""Run the command-line tool.

	Parses *argv*, decodes the given SPD hex file and prints the result to
	standard output as JSON, either compact or (with --pretty) tab-indented
	below a "// Generated by:" header.

	Args:
		argv: Command-line arguments without the program name.

	Returns:
		0 on success. Errors from parsing or decoding propagate as exceptions.
	"""
	parser = argparse.ArgumentParser(description="Decode LPDDR5 SPD .hex file to JSON")
	parser.add_argument("hexfile", help="Path to SPD hex file (text with space-separated bytes)")
	parser.add_argument(
		"--format",
		choices=("spd_gen", "raw"),
		default="spd_gen",
		help="Output JSON format: spd_gen (memory_parts.json) or raw",
	)
	parser.add_argument("--name", help="Part name for spd_gen output (defaults to MPN or input filename)")
	parser.add_argument(
		"--include-nondefault-timings",
		action="store_true",
		help="Include timing fields only when they differ from spd_gen defaults (spd_gen output only)",
	)
	parser.add_argument("--pretty", action="store_true", help="Pretty-print JSON output with a command header")
	opts = parser.parse_args(argv)

	decoded = detect_and_decode(_parse_hex_file(opts.hexfile))

	if opts.format == "raw":
		document = decoded
	else:
		# Name precedence: --name, then the SPD part number, then the file name.
		part_name = opts.name
		if not part_name:
			part_name = decoded.get("manufacturerPartNumber")
		if not part_name:
			part_name = _default_part_name_from_path(opts.hexfile)
		document = _to_spd_gen_memory_parts(decoded, part_name, opts.include_nondefault_timings)

	if not opts.pretty:
		print(json.dumps(document, separators=(",", ":"), ensure_ascii=False))
		return 0

	# Record the invoking command line above the pretty-printed document.
	command = " ".join([sys.argv[0]] + argv)
	print(f"// Generated by:\n// {command}\n\n" + _json_dumps_tabs(document), end="")
	return 0


# Script entry point: pass the command-line arguments (without the program
# name) to main() and use its return value as the process exit status.
if __name__ == "__main__":
	raise SystemExit(main(sys.argv[1:]))
