From fe437c1c806b9f6753b0b26167d20872afd5b89d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 3 Feb 2026 08:54:26 +0000 Subject: [PATCH 1/3] code drop Signed-off-by: Pawel Gadzinski --- .../debug/features/dump_tensors.py | 327 ++++++++++++++++++ 1 file changed, 327 insertions(+) create mode 100644 transformer_engine/debug/features/dump_tensors.py diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py new file mode 100644 index 0000000000..e13983af67 --- /dev/null +++ b/transformer_engine/debug/features/dump_tensors.py @@ -0,0 +1,327 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""DumpTensors Feature support for nvidia-dlframework-inspect.""" + +from typing import Dict, Optional + +import torch + +import nvdlfw_inspect.api as debug_api +from nvdlfw_inspect.logging import get_tensor_logger +from nvdlfw_inspect.registry import Registry, api_method + +from transformer_engine.debug.features.api import TEConfigAPIMapper +from transformer_engine.debug.features.utils import next_enabled_iter +from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor + + +@Registry.register_feature(namespace="transformer_engine") +class DumpTensors(TEConfigAPIMapper): + """ + Dump tensors to files for debugging purposes. + + This feature saves tensors to disk using torch.save(). It supports dumping + both high-precision tensors (before quantization) and quantized tensors. + + Each tensor is saved to a separate file with the iteration number, layer name, + and tensor name in the filename. Files are organized per-rank in distributed settings. + + Parameters + ---------- + high_precision_tensor : bool + If True, dump the high-precision tensor (before quantization). + quantized_tensor : bool + If True, dump the quantized tensor (after quantization). + extended_quantized_tensor_log : bool, default = False + If True, dump additional files with raw data and scales for quantized tensors: + - For Float8Tensor: raw_data (uint8), scale_inv (FP32) + - For MXFP8Tensor: rowwise_raw_data, columnwise_raw_data (uint8), + rowwise_scale_inv, columnwise_scale_inv (decoded to FP32) + - For NVFP4Tensor: rowwise_raw_data, columnwise_raw_data (uint8), + rowwise_scale_inv, columnwise_scale_inv (decoded to FP32), + rowwise_amax, columnwise_amax (FP32) + tensors/tensors_struct : List[str] + list of tensors to dump: + - activation + - gradient + - weight + - output + - wgrad + - dgrad + freq : Optional[int], default = 1 + frequency of dumping tensors, tensors will be dumped every `freq` steps + start_step : Optional[int], default = 0 + start step of dumping tensors + end_step : Optional[int], default = -1 + end step of dumping tensors (-1 means no end) + start_end_list : Optional[list([int, int])], default = None + non-overlapping list of (start, end) pairs in incremental order. + If not None, will ignore start_step and end_step + + Example + ------- + .. 
code-block:: yaml + + dump_tensors_example: + enabled: True + layers: + layer_name_regex_pattern: .*(fc1|self_attention).* + transformer_engine: + DumpTensors: + enabled: True + tensors_struct: + - tensor: activation + high_precision_tensor: True + quantized_tensor: True + extended_quantized_tensor_log: True + freq: 100 + - tensor: weight + high_precision_tensor: True + quantized_tensor: False + freq: 500 + + Output Structure + ---------------- + Files are saved to: ``nvdlfw_inspect_tensor_dumps/rank_{rank}/`` + + Basic files: + - ``{layer}_{tensor}_iter_{iter}_high_precision.pt`` + - ``{layer}_{tensor}_iter_{iter}_quantized.pt`` + + Extended files (when extended_quantized_tensor_log=True): + - ``{layer}_{tensor}_iter_{iter}_raw_data.pt`` + - ``{layer}_{tensor}_iter_{iter}_scale_inv.pt`` + - (MXFP8/NVFP4) ``{layer}_{tensor}_iter_{iter}_rowwise_scale_inv.pt`` + - (NVFP4) ``{layer}_{tensor}_iter_{iter}_rowwise_amax.pt`` + """ + + @api_method + def inspect_tensor_enabled( + self, config: Dict, layer_name: str, tensor_name: str, iteration: int + ): # pylint: disable=unused-argument + """API call used to determine whether to run inspect_tensor() in the forward.""" + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + return run_current, next_iter + + @api_method + def inspect_tensor( + self, + config: Dict, + layer_name: str, + tensor_name: str, + iteration: int, + tp_group: torch.distributed.ProcessGroup, + tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + quantizer: Optional[Quantizer] = None, + ): # pylint: disable=unused-argument + """ + API call used to dump tensors to files. + + Supports dumping both high-precision tensors and quantized tensors based on config. + """ + # Assert that rowwise and columnwise are the same (or one is None) + assert rowwise_quantized_tensor is columnwise_quantized_tensor, ( + "[NVTORCH INSPECT ERROR] DumpTensors expects rowwise_quantized_tensor and " + "columnwise_quantized_tensor to be the same object or both None." + ) + + quantized_tensor = rowwise_quantized_tensor + + dump_hp = config.get("high_precision_tensor", False) + dump_quant = config.get("quantized_tensor", False) + + if not dump_hp and not dump_quant: + debug_api.log_message( + f"Feature={self.__class__.__name__}: Neither high_precision_tensor nor " + "quantized_tensor is enabled. 
Nothing to dump.", + layer_name, + ) + return + + tensor_logger = get_tensor_logger() + + # Dump high-precision tensor + if dump_hp and tensor is not None: + tensor_logger.save_tensor( + tensor=tensor, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_high_precision", + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}, API=inspect_tensor: " + f"Dumped high-precision {tensor_name} at iteration {iteration}", + layer_name, + ) + + # Dump quantized tensor + if dump_quant and quantized_tensor is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_quantized", + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}, API=inspect_tensor: " + f"Dumped quantized {tensor_name} at iteration {iteration}", + layer_name, + ) + + # Extended logging for quantized tensors + if config.get("extended_quantized_tensor_log", False): + self._dump_extended_quantized_info( + tensor_logger, quantized_tensor, layer_name, tensor_name, iteration + ) + + elif dump_quant and quantized_tensor is None: + debug_api.log_message( + f"Feature={self.__class__.__name__}: quantized_tensor is True but " + f"no quantized tensor available for {tensor_name}. Skipping.", + layer_name, + ) + + def _dump_extended_quantized_info( + self, + tensor_logger, + quantized_tensor: QuantizedTensor, + layer_name: str, + tensor_name: str, + iteration: int, + ): + """Dump extended debug info for quantized tensors (raw data and scales).""" + + if isinstance(quantized_tensor, Float8Tensor): + # Float8Tensor: raw_data (uint8), scale_inv (FP32) + tensor_logger.save_tensor( + tensor=quantized_tensor._data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_raw_data", + ) + tensor_logger.save_tensor( + tensor=quantized_tensor._scale_inv, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_scale_inv", + ) + + elif isinstance(quantized_tensor, MXFP8Tensor): + # MXFP8Tensor: raw data and scales (decoded from E8M0) + if quantized_tensor._rowwise_data is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._rowwise_data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_raw_data", + ) + if quantized_tensor._columnwise_data is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._columnwise_data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_raw_data", + ) + # Decode E8M0 scales to FP32 + if quantized_tensor._rowwise_scale_inv is not None: + decoded = torch.pow( + torch.tensor(2.0, device=quantized_tensor._rowwise_scale_inv.device), + quantized_tensor._rowwise_scale_inv.to(torch.float32) - 127.0, + ) + tensor_logger.save_tensor( + tensor=decoded, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_scale_inv", + ) + if quantized_tensor._columnwise_scale_inv is not None: + decoded = torch.pow( + torch.tensor(2.0, device=quantized_tensor._columnwise_scale_inv.device), + quantized_tensor._columnwise_scale_inv.to(torch.float32) - 127.0, + ) + tensor_logger.save_tensor( + tensor=decoded, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_scale_inv", + ) + + elif isinstance(quantized_tensor, NVFP4Tensor): + # NVFP4Tensor: raw data, scales (decoded from E4M3), and amax + if quantized_tensor._rowwise_data is not None: + 
tensor_logger.save_tensor( + tensor=quantized_tensor._rowwise_data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_raw_data", + ) + if quantized_tensor._columnwise_data is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._columnwise_data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_raw_data", + ) + # Decode E4M3 scales to FP32 + if quantized_tensor._rowwise_scale_inv is not None: + decoded = quantized_tensor._rowwise_scale_inv.view(torch.float8_e4m3fn).to( + torch.float32 + ) + tensor_logger.save_tensor( + tensor=decoded, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_scale_inv", + ) + if quantized_tensor._columnwise_scale_inv is not None: + decoded = quantized_tensor._columnwise_scale_inv.view(torch.float8_e4m3fn).to( + torch.float32 + ) + tensor_logger.save_tensor( + tensor=decoded, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_scale_inv", + ) + # Amax values (already FP32) + if quantized_tensor._amax_rowwise is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._amax_rowwise, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_amax", + ) + if quantized_tensor._amax_columnwise is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._amax_columnwise, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_amax", + ) From a54a743af93c198535636dad47c0c59fee9a1ba8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 3 Feb 2026 10:42:56 +0000 Subject: [PATCH 2/3] code drop Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 72 ++++ .../debug/features/dump_tensors.py | 369 ++++++++++-------- 2 files changed, 270 insertions(+), 171 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 5d6fc41ac7..5f7adc0a41 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -592,3 +592,75 @@ def test_compute_max_blockwise_dynamic_range_direct(): ) print("All direct tests for compute_max_blockwise_dynamic_range passed!") + + +# DumpTensors tests +DUMP_TENSORS_CONFIG = """ +dump: + layers: + layer_name_regex_pattern: .* + enabled: True + transformer_engine: + DumpTensors: + enabled: True + tensors: [activation] + high_precision_tensor: True + quantized_tensor: True + dump_quantized_internals: True + freq: 1 +""" + + +def test_dump_tensors_sanity(feature_dirs): + """Sanity test for DumpTensors feature - verify files are created with correct structure.""" + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + with debug_session(DUMP_TENSORS_CONFIG, feature_dirs) as log_dir: + from transformer_engine.pytorch.quantization import RecipeState + + recipe_state = RecipeState.create( + recipe.DelayedScaling(), + mode="forward", + num_quantizers=3, + ) + + tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + quantizer = recipe_state.make_quantizers()[0] + quantized_tensor = quantizer(tensor) + + debug_api.transformer_engine.inspect_tensor( + layer_name="test_layer", + tensor_name="activation", + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=quantized_tensor, + columnwise_quantized_tensor=quantized_tensor, + ) + debug_api.step() + + # Check that dump file was created + dump_dir = os.path.join(log_dir, "tensor_dumps", "rank_0") + assert 
os.path.exists(dump_dir), f"Dump directory not created: {dump_dir}" + + dump_files = os.listdir(dump_dir) + assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}" + + # Load and verify structure + dump_file = os.path.join(dump_dir, dump_files[0]) + data = torch.load(dump_file, weights_only=False) + + assert isinstance(data, dict), "Dump should be a dictionary" + assert "high_precision" in data, "Missing high_precision tensor" + assert "quantized" in data, "Missing quantized tensor" + + # Check internals are present (dump_quantized_internals=True) + assert "data" in data, "Missing data (raw FP8 data)" + assert "scale_inv" in data, "Missing scale_inv" + + # Verify tensor shapes match + assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" + + print("DumpTensors sanity test passed!") diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index e13983af67..136771b7ae 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -4,22 +4,95 @@ """DumpTensors Feature support for nvidia-dlframework-inspect.""" +import os from typing import Dict, Optional import torch +import torch.distributed as dist import nvdlfw_inspect.api as debug_api -from nvdlfw_inspect.logging import get_tensor_logger +from nvdlfw_inspect.logging import get_logger from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.api import TEConfigAPIMapper from transformer_engine.debug.features.utils import next_enabled_iter +from transformer_engine.pytorch.constants import TE_DType_To_Torch from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor +class TensorLogger: + """Logger for saving tensors to files. Each rank saves to its own directory.""" + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if TensorLogger._initialized: + return + self.root_dir = None + self.rank = 0 + TensorLogger._initialized = True + + def initialize(self, root_log_dir: str): + """Initialize the TensorLogger with the root directory for tensor dumps.""" + self.rank = 0 + if dist.is_initialized(): + self.rank = dist.get_rank() + + self.root_dir = os.path.join( + root_log_dir, "tensor_dumps", f"rank_{self.rank}" + ) + os.makedirs(self.root_dir, exist_ok=True) + + debug_api.log_message( + f"TensorLogger initialized. Saving tensors to: {self.root_dir}", + log_level="info", + ) + + @staticmethod + def _sanitize_name(name: str) -> str: + """Sanitize layer/tensor names for use in file paths.""" + for char in ["/", "\\", ":", "*", "?", '"', "<", ">", "|", " "]: + name = name.replace(char, "_") + return name + + def save_tensor( + self, + tensor, + layer_name: str, + tensor_name: str, + iteration: int, + ): + """Save a tensor (or dict of tensors) to a file.""" + if self.root_dir is None: + raise RuntimeError( + "[TE DumpTensors] TensorLogger not initialized. " + "Call initialize() first." 
+ ) + + safe_layer_name = self._sanitize_name(layer_name) + safe_tensor_name = self._sanitize_name(tensor_name) + + filename = f"{safe_layer_name}_{safe_tensor_name}_iter_{iteration:06d}.pt" + filepath = os.path.join(self.root_dir, filename) + + torch.save(tensor, filepath) + + +def _get_tensor_logger() -> TensorLogger: + """Get the singleton TensorLogger instance.""" + return TensorLogger() + + @Registry.register_feature(namespace="transformer_engine") class DumpTensors(TEConfigAPIMapper): """ @@ -37,14 +110,10 @@ class DumpTensors(TEConfigAPIMapper): If True, dump the high-precision tensor (before quantization). quantized_tensor : bool If True, dump the quantized tensor (after quantization). - extended_quantized_tensor_log : bool, default = False - If True, dump additional files with raw data and scales for quantized tensors: - - For Float8Tensor: raw_data (uint8), scale_inv (FP32) - - For MXFP8Tensor: rowwise_raw_data, columnwise_raw_data (uint8), - rowwise_scale_inv, columnwise_scale_inv (decoded to FP32) - - For NVFP4Tensor: rowwise_raw_data, columnwise_raw_data (uint8), - rowwise_scale_inv, columnwise_scale_inv (decoded to FP32), - rowwise_amax, columnwise_amax (FP32) + dump_quantized_internals : bool, default = False + If True, include extracted internal data from quantized tensors + (raw data, scales, etc.) in the output dictionary. + Useful for offline analysis. Output format may change between versions. tensors/tensors_struct : List[str] list of tensors to dump: - activation @@ -78,7 +147,7 @@ class DumpTensors(TEConfigAPIMapper): - tensor: activation high_precision_tensor: True quantized_tensor: True - extended_quantized_tensor_log: True + dump_quantized_internals: True freq: 100 - tensor: weight high_precision_tensor: True @@ -87,17 +156,16 @@ class DumpTensors(TEConfigAPIMapper): Output Structure ---------------- - Files are saved to: ``nvdlfw_inspect_tensor_dumps/rank_{rank}/`` + Files are saved to: ``{nvdlfw_inspect_log_dir}/tensor_dumps/rank_{rank}/`` - Basic files: - - ``{layer}_{tensor}_iter_{iter}_high_precision.pt`` - - ``{layer}_{tensor}_iter_{iter}_quantized.pt`` + Each tensor is saved as a dictionary in a single file: + ``{layer}_{tensor}_iter_{iter:06d}.pt`` - Extended files (when extended_quantized_tensor_log=True): - - ``{layer}_{tensor}_iter_{iter}_raw_data.pt`` - - ``{layer}_{tensor}_iter_{iter}_scale_inv.pt`` - - (MXFP8/NVFP4) ``{layer}_{tensor}_iter_{iter}_rowwise_scale_inv.pt`` - - (NVFP4) ``{layer}_{tensor}_iter_{iter}_rowwise_amax.pt`` + Dictionary keys: + - ``high_precision``: pre-quantization tensor (if high_precision_tensor=True) + - ``quantized``: quantized tensor object (if quantized_tensor=True) + - Additional internal components when dump_quantized_internals=True + (raw data, scales, etc. 
- format may change between versions) """ @api_method @@ -151,43 +219,23 @@ def inspect_tensor( ) return - tensor_logger = get_tensor_logger() + tensor_logger = _get_tensor_logger() + if tensor_logger.root_dir is None: + tensor_logger.initialize(get_logger().root_log_dir) + + # Build dictionary with all tensors to dump + dump_dict: Dict[str, torch.Tensor] = {} - # Dump high-precision tensor if dump_hp and tensor is not None: - tensor_logger.save_tensor( - tensor=tensor, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_high_precision", - ) - debug_api.log_message( - f"Feature={self.__class__.__name__}, API=inspect_tensor: " - f"Dumped high-precision {tensor_name} at iteration {iteration}", - layer_name, - ) + dump_dict["high_precision"] = tensor - # Dump quantized tensor if dump_quant and quantized_tensor is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_quantized", - ) - debug_api.log_message( - f"Feature={self.__class__.__name__}, API=inspect_tensor: " - f"Dumped quantized {tensor_name} at iteration {iteration}", - layer_name, - ) + dump_dict["quantized"] = quantized_tensor - # Extended logging for quantized tensors - if config.get("extended_quantized_tensor_log", False): - self._dump_extended_quantized_info( - tensor_logger, quantized_tensor, layer_name, tensor_name, iteration - ) + # Add internals for quantized tensors + if config.get("dump_quantized_internals", False): + internals = self._get_quantized_internals(quantized_tensor) + dump_dict.update(internals) elif dump_quant and quantized_tensor is None: debug_api.log_message( @@ -196,132 +244,111 @@ def inspect_tensor( layer_name, ) - def _dump_extended_quantized_info( - self, - tensor_logger, - quantized_tensor: QuantizedTensor, - layer_name: str, - tensor_name: str, - iteration: int, - ): - """Dump extended debug info for quantized tensors (raw data and scales).""" - - if isinstance(quantized_tensor, Float8Tensor): - # Float8Tensor: raw_data (uint8), scale_inv (FP32) + if dump_dict: tensor_logger.save_tensor( - tensor=quantized_tensor._data, + tensor=dump_dict, layer_name=layer_name, tensor_name=tensor_name, iteration=iteration, - suffix="_raw_data", ) - tensor_logger.save_tensor( - tensor=quantized_tensor._scale_inv, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_scale_inv", + debug_api.log_message( + f"Feature={self.__class__.__name__}, API=inspect_tensor: " + f"Dumped {tensor_name} at iteration {iteration} (keys: {list(dump_dict.keys())})", + layer_name, ) + def _get_quantized_internals( + self, + quantized_tensor: QuantizedTensor, + ) -> Dict[str, torch.Tensor]: + """Get internal components of quantized tensors (raw data, scales, etc.).""" + if isinstance(quantized_tensor, Float8Tensor): + tensors = _get_extended_tensors_fp8(quantized_tensor) + elif isinstance(quantized_tensor, Float8BlockwiseQTensor): + tensors = _get_extended_tensors_fp8_blockwise(quantized_tensor) elif isinstance(quantized_tensor, MXFP8Tensor): - # MXFP8Tensor: raw data and scales (decoded from E8M0) - if quantized_tensor._rowwise_data is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._rowwise_data, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_raw_data", - ) - if quantized_tensor._columnwise_data is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._columnwise_data, - layer_name=layer_name, - 
tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_raw_data", - ) - # Decode E8M0 scales to FP32 - if quantized_tensor._rowwise_scale_inv is not None: - decoded = torch.pow( - torch.tensor(2.0, device=quantized_tensor._rowwise_scale_inv.device), - quantized_tensor._rowwise_scale_inv.to(torch.float32) - 127.0, - ) - tensor_logger.save_tensor( - tensor=decoded, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_scale_inv", - ) - if quantized_tensor._columnwise_scale_inv is not None: - decoded = torch.pow( - torch.tensor(2.0, device=quantized_tensor._columnwise_scale_inv.device), - quantized_tensor._columnwise_scale_inv.to(torch.float32) - 127.0, - ) - tensor_logger.save_tensor( - tensor=decoded, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_scale_inv", - ) - + tensors = _get_extended_tensors_mxfp8(quantized_tensor) elif isinstance(quantized_tensor, NVFP4Tensor): - # NVFP4Tensor: raw data, scales (decoded from E4M3), and amax - if quantized_tensor._rowwise_data is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._rowwise_data, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_raw_data", - ) - if quantized_tensor._columnwise_data is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._columnwise_data, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_raw_data", - ) - # Decode E4M3 scales to FP32 - if quantized_tensor._rowwise_scale_inv is not None: - decoded = quantized_tensor._rowwise_scale_inv.view(torch.float8_e4m3fn).to( - torch.float32 - ) - tensor_logger.save_tensor( - tensor=decoded, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_scale_inv", - ) - if quantized_tensor._columnwise_scale_inv is not None: - decoded = quantized_tensor._columnwise_scale_inv.view(torch.float8_e4m3fn).to( - torch.float32 - ) - tensor_logger.save_tensor( - tensor=decoded, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_scale_inv", - ) - # Amax values (already FP32) - if quantized_tensor._amax_rowwise is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._amax_rowwise, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_amax", - ) - if quantized_tensor._amax_columnwise is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._amax_columnwise, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_amax", - ) + tensors = _get_extended_tensors_nvfp4(quantized_tensor) + else: + return {} + + # Filter out None values + return {k: v for k, v in tensors.items() if v is not None} + + +def _get_extended_tensors_fp8(tensor: Float8Tensor) -> Dict[str, torch.Tensor]: + """Get extended tensors for Float8Tensor: raw FP8 data, transpose, and scale.""" + torch_fp8_dtype = TE_DType_To_Torch[tensor._fp8_dtype] + result = { + "data": tensor._data.view(torch_fp8_dtype), + "scale_inv": tensor._scale_inv, + } + if tensor._transpose is not None and not tensor._transpose_invalid: + result["transpose"] = tensor._transpose.view(torch_fp8_dtype) + return result + + +def _get_extended_tensors_fp8_blockwise( + tensor: Float8BlockwiseQTensor, +) -> Dict[str, Optional[torch.Tensor]]: + """Get extended tensors for Float8BlockwiseQTensor: raw FP8 data and block scales.""" + torch_fp8_dtype = 
TE_DType_To_Torch[tensor._fp8_dtype] + result: Dict[str, Optional[torch.Tensor]] = {} + + if tensor._rowwise_data is not None: + result["rowwise_data"] = tensor._rowwise_data.view(torch_fp8_dtype) + if tensor._columnwise_data is not None: + result["columnwise_data"] = tensor._columnwise_data.view(torch_fp8_dtype) + + # Block scaling factors (FP32) + if tensor._rowwise_scale_inv is not None: + result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv + if tensor._columnwise_scale_inv is not None: + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv + + return result + + +def _get_extended_tensors_mxfp8(tensor: MXFP8Tensor) -> Dict[str, Optional[torch.Tensor]]: + """Get extended tensors for MXFP8Tensor: raw FP8 data and block scales (E8M0).""" + torch_fp8_dtype = TE_DType_To_Torch[tensor._fp8_dtype] + result: Dict[str, Optional[torch.Tensor]] = {} + + if tensor._rowwise_data is not None: + result["rowwise_data"] = tensor._rowwise_data.view(torch_fp8_dtype) + if tensor._columnwise_data is not None: + result["columnwise_data"] = tensor._columnwise_data.view(torch_fp8_dtype) + + # Block scaling factors (E8M0 format) + if tensor._rowwise_scale_inv is not None: + result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e8m0fnu) + if tensor._columnwise_scale_inv is not None: + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e8m0fnu) + + return result + + +def _get_extended_tensors_nvfp4(tensor: NVFP4Tensor) -> Dict[str, Optional[torch.Tensor]]: + """Get extended tensors for NVFP4Tensor: raw packed FP4 data, block scales, and amax.""" + result: Dict[str, Optional[torch.Tensor]] = {} + + # Raw data (packed FP4, 2 values per byte) + if tensor._rowwise_data is not None: + result["rowwise_data"] = tensor._rowwise_data + if tensor._columnwise_data is not None: + result["columnwise_data"] = tensor._columnwise_data + + # Block scaling factors (E4M3 format) + if tensor._rowwise_scale_inv is not None: + result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e4m3fn) + if tensor._columnwise_scale_inv is not None: + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e4m3fn) + + # Input absolute maximum value (used to compute tensor scale) + if tensor._amax_rowwise is not None: + result["amax_rowwise"] = tensor._amax_rowwise + if tensor._amax_columnwise is not None: + result["amax_columnwise"] = tensor._amax_columnwise + + return result From b6e076734dc791e178bb68681a3c74fa7c119a77 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 10:45:41 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/debug/features/dump_tensors.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 136771b7ae..b453b5e273 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -48,9 +48,7 @@ def initialize(self, root_log_dir: str): if dist.is_initialized(): self.rank = dist.get_rank() - self.root_dir = os.path.join( - root_log_dir, "tensor_dumps", f"rank_{self.rank}" - ) + self.root_dir = os.path.join(root_log_dir, "tensor_dumps", f"rank_{self.rank}") os.makedirs(self.root_dir, exist_ok=True) debug_api.log_message( @@ -75,8 
+73,7 @@ def save_tensor( """Save a tensor (or dict of tensors) to a file.""" if self.root_dir is None: raise RuntimeError( - "[TE DumpTensors] TensorLogger not initialized. " - "Call initialize() first." + "[TE DumpTensors] TensorLogger not initialized. Call initialize() first." ) safe_layer_name = self._sanitize_name(layer_name) @@ -324,7 +321,9 @@ def _get_extended_tensors_mxfp8(tensor: MXFP8Tensor) -> Dict[str, Optional[torch if tensor._rowwise_scale_inv is not None: result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e8m0fnu) if tensor._columnwise_scale_inv is not None: - result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e8m0fnu) + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view( + torch.float8_e8m0fnu + ) return result @@ -343,7 +342,9 @@ def _get_extended_tensors_nvfp4(tensor: NVFP4Tensor) -> Dict[str, Optional[torch if tensor._rowwise_scale_inv is not None: result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e4m3fn) if tensor._columnwise_scale_inv is not None: - result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e4m3fn) + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view( + torch.float8_e4m3fn + ) # Input absolute maximum value (used to compute tensor scale) if tensor._amax_rowwise is not None:
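
A note on consuming the dumps offline: each file written after PATCH 2 is a
single torch.save()'d dictionary, so it can be loaded and inspected in a plain
script (when quantized_tensor=True the "quantized" entry unpickles to a TE
tensor class, so Transformer Engine must still be importable). The sketch below
is illustrative only: the file path is hypothetical, the dictionary keys are
the ones introduced in PATCH 2, the E8M0 decode mirrors the 2**(b - 127)
formula PATCH 1 applied eagerly, and the NVFP4 nibble order is an assumption,
not a layout TE guarantees.

    import torch

    # Hypothetical path; real files follow
    # {log_dir}/tensor_dumps/rank_{rank}/{layer}_{tensor}_iter_{iteration:06d}.pt
    path = "tensor_dumps/rank_0/test_layer_activation_iter_000000.pt"
    data = torch.load(path, weights_only=False, map_location="cpu")

    hp = data["high_precision"]  # pre-quantization tensor

    # Float8Tensor (delayed scaling) dump: "data" holds the raw FP8 payload and
    # "scale_inv" the FP32 per-tensor inverse scale, so dequantization is a
    # single elementwise multiply.
    if "data" in data:
        dequant = data["data"].to(torch.float32) * data["scale_inv"].to(torch.float32)
        print("max abs quantization error:", (hp.float() - dequant).abs().max().item())

    # MXFP8 dump: block scales are stored as E8M0 bytes; a byte value b decodes
    # to 2.0 ** (b - 127), so b == 127 is a scale of exactly 1.0. (Note that a
    # Float8BlockwiseQTensor dump reuses the same key but stores FP32 scales.)
    if "rowwise_block_scale_inv" in data and "amax_rowwise" not in data:
        e8m0 = data["rowwise_block_scale_inv"].view(torch.uint8).to(torch.float32)
        block_scales = torch.pow(2.0, e8m0 - 127.0)

    # NVFP4 dump: "rowwise_data" packs two 4-bit E2M1 codes per uint8 byte
    # (sign bit on top; magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6). Applying the
    # E4M3 block scales additionally requires the block layout, which this
    # sketch does not model.
    if "amax_rowwise" in data:
        packed = data["rowwise_data"].view(torch.uint8).long()
        codes = torch.stack([packed & 0xF, packed >> 4], dim=-1)  # assumed nibble order
        magnitudes = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
        values = magnitudes[codes & 0x7] * (1.0 - 2.0 * (codes >= 8).float())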