Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions tests/pytorch/debug/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,75 @@ def test_compute_max_blockwise_dynamic_range_direct():
)

print("All direct tests for compute_max_blockwise_dynamic_range passed!")


# DumpTensors tests
DUMP_TENSORS_CONFIG = """
dump:
layers:
layer_name_regex_pattern: .*
enabled: True
transformer_engine:
DumpTensors:
enabled: True
tensors: [activation]
high_precision_tensor: True
quantized_tensor: True
dump_quantized_internals: True
freq: 1
"""


def test_dump_tensors_sanity(feature_dirs):
"""Sanity test for DumpTensors feature - verify files are created with correct structure."""
if not fp8_available:
pytest.skip(reason_for_no_fp8)

with debug_session(DUMP_TENSORS_CONFIG, feature_dirs) as log_dir:
from transformer_engine.pytorch.quantization import RecipeState

recipe_state = RecipeState.create(
recipe.DelayedScaling(),
mode="forward",
num_quantizers=3,
)

tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)

debug_api.transformer_engine.inspect_tensor(
layer_name="test_layer",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=quantized_tensor,
columnwise_quantized_tensor=quantized_tensor,
)
debug_api.step()

# Check that dump file was created
dump_dir = os.path.join(log_dir, "tensor_dumps", "rank_0")
assert os.path.exists(dump_dir), f"Dump directory not created: {dump_dir}"

dump_files = os.listdir(dump_dir)
assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}"

# Load and verify structure
dump_file = os.path.join(dump_dir, dump_files[0])
data = torch.load(dump_file, weights_only=False)

assert isinstance(data, dict), "Dump should be a dictionary"
assert "high_precision" in data, "Missing high_precision tensor"
assert "quantized" in data, "Missing quantized tensor"

# Check internals are present (dump_quantized_internals=True)
assert "data" in data, "Missing data (raw FP8 data)"
assert "scale_inv" in data, "Missing scale_inv"

# Verify tensor shapes match
assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch"

print("DumpTensors sanity test passed!")
Loading