[WIP] add debug functionality for per chip sizes and bytes #625

Draft · wants to merge 1 commit into base: main
Changes from all commits
34 changes: 33 additions & 1 deletion MaxText/inference_microbenchmark.py
@@ -17,6 +17,7 @@
"""Inference microbenchmark for prefill and autoregressive steps."""
import datetime
import jax
import flax
import json
import sys

@@ -27,10 +28,40 @@
import maxtext_utils
import pyconfig


_WARMUP_ITERS = 2


def _debug_pytree(pytree):
  """Debug pytree sizing and sharding across chips."""
  if isinstance(pytree, flax.linen.spmd.LogicallyPartitioned):
    pytree = pytree.value
  print(f"\tshape: {pytree.shape}")
  print(f"\tsharding: {pytree.sharding}")
  total_logical_sizes, total_logical_bytes, _ = max_utils.summarize_size_from_pytree(pytree)
  total_physical_sizes_across_chips, n_chips = max_utils.calculate_total_params_across_chips(pytree)
  total_physical_bytes_across_chips, _ = max_utils.calculate_total_bytes_across_chips(pytree)
  print(f"\ttotal_logical_sizes: {total_logical_sizes}")
  print(f"\ttotal_logical_bytes: {total_logical_bytes}")
  print(f"\tn_chips: {n_chips}")
  print(f"\ttotal_physical_sizes_across_chips: {total_physical_sizes_across_chips}")
  print(f"\ttotal_physical_bytes_across_chips: {total_physical_bytes_across_chips}")


def debug_result(result):
  """Debug result pytrees' sizing and sharding across chips."""
  for result_key in result.keys():
    result_element = result[result_key]
    if result_key == "cache":
      single_layer_kv_cache = result["cache"]["decoder"]["layers_0"]["self_attention"]["AttentionOp_0"]
      for cache_key in single_layer_kv_cache.keys():
        cache_element = single_layer_kv_cache[cache_key]
        print(f"{cache_key}:")
        _debug_pytree(cache_element)
    else:
      print(f"{result_key}:")
      _debug_pytree(result_element)
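
The "cache" branch assumes the MaxText prefill-result layout down to AttentionOp_0; roughly the shape below, where the leaf names are illustrative assumptions and the values are sharded jax.Arrays or LogicallyPartitioned wrappers:

example_result = {
    "logits": ...,  # jax.Array
    "cache": {
        "decoder": {
            "layers_0": {
                "self_attention": {
                    "AttentionOp_0": {
                        "cached_prefill_key": ...,    # illustrative leaf name
                        "cached_prefill_value": ...,  # illustrative leaf name
                    },
                },
            },
        },
    },
}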


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
  """Inner loop for benchmarking prefill step."""
  start = datetime.datetime.now()
@@ -204,6 +235,7 @@ def summarize_prefill_result(engine, params, tokens, true_length):
  num_prefill_cache_params, total_prefill_cache_size, avg_prefill_cache_param_size = (
      max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
  )
  debug_result(prefill_result)
  max_utils.delete_pytree(prefill_result)
  return {
      "num_prefill_logits_params": num_prefill_logits_params,
39 changes: 34 additions & 5 deletions MaxText/max_utils.py
@@ -58,35 +58,64 @@ def l2norm_pytree(x):
  return jnp.sqrt(jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(jnp.square(y)), x, initializer=0.0))


-def calculate_num_params_from_pytree(params):
+def calculate_num_params_from_pytree(params) -> int:
  """Calculate params' logical size."""
  params_sizes = jax.tree_util.tree_map(jax.numpy.size, params)
  total_parameters = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes)
  assert total_parameters >= 0
  return total_parameters
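
A tiny worked example (hedged: it assumes a pytree of arrays and uses np, which max_utils already imports):

params = {"w": np.zeros((2, 3)), "b": np.zeros((3,))}
assert calculate_num_params_from_pytree(params) == 9  # 2*3 + 3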


-def calculate_total_params_per_chip(params):
+def calculate_total_params_per_chip(params) -> int:
  """Calculate params' physical size on an addressable chip."""
  def calculate_leaf_params_per_chip(arr):
    shard = arr.addressable_shards[0]
-    return np.prod(shard.data.shape)
+    return shard.data.size

  params_sizes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_per_chip, params)
  total_parameters_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes_per_chip)
  return total_parameters_per_chip
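
To illustrate the logical-versus-physical distinction this measures: a (1024, 1024) array sharded four ways along its first axis with no replication leaves a (256, 1024) shard on each chip, so as arithmetic:

logical_params = 1024 * 1024          # calculate_num_params_from_pytree
params_per_chip = (1024 // 4) * 1024  # shard.data.size on one addressable chip
assert logical_params == 4 * params_per_chip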


-def calculate_bytes_from_pytree(params):
+def calculate_bytes_from_pytree(params) -> int:
  """Calculate params' logical bytes."""
  params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params)
  total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes)
  return total_bytes


-def summarize_size_from_pytree(params):
+def summarize_size_from_pytree(params) -> tuple[int, int, float]:
  """Calculate params' logical size and bytes."""
  num_params = calculate_num_params_from_pytree(params)
  num_bytes = calculate_bytes_from_pytree(params)
  return num_params, num_bytes, num_bytes / num_params
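
The third element of the returned tuple is the average bytes per parameter, which recovers the dtype width. A small self-checking sketch (array contents are arbitrary):

params = {"w": jnp.zeros((1024, 1024), dtype=jnp.bfloat16)}
num_params, num_bytes, avg_bytes = summarize_size_from_pytree(params)
assert (num_params, num_bytes, avg_bytes) == (1024 * 1024, 2 * 1024 * 1024, 2.0)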


def calculate_total_bytes_per_chip(params) -> int:
  """Calculate params' physical bytes on an addressable chip."""
  def calculate_leaf_params_bytes_per_chip(arr):
    shard = arr.addressable_shards[0]
    return shard.data.nbytes

  params_bytes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_bytes_per_chip, params)
  total_bytes_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes_per_chip)
  return total_bytes_per_chip


def calculate_total_params_across_chips(params) -> tuple[int, int]:
  """Calculate params' physical sizes across all chips."""
  total_params_per_chip = calculate_total_params_per_chip(params)
  n_chips = len(params.addressable_shards)
  return total_params_per_chip * n_chips, n_chips


def calculate_total_bytes_across_chips(params) -> tuple[int, int]:
  """Calculate params' physical bytes across all chips."""
  total_bytes_per_chip = calculate_total_bytes_per_chip(params)
  n_chips = len(params.addressable_shards)
  return total_bytes_per_chip * n_chips, n_chips
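
Both across-chips helpers read params.addressable_shards directly, so they expect params to be a single sharded jax.Array rather than a nested pytree, and the per-chip-times-n-chips product assumes uniform sharding with no replication (a replicated array would be counted once per chip). A hedged sketch for a fully sharded array in a single-process run:

from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
arr = jax.device_put(
    jnp.zeros((len(devices) * 4, 16), dtype=jnp.float32),
    NamedSharding(mesh, PartitionSpec("data", None)),
)
total_params, n_chips = calculate_total_params_across_chips(arr)
total_bytes, _ = calculate_total_bytes_across_chips(arr)
# Fully sharded with no replication, the physical totals equal the logical ones.
assert total_params == arr.size and total_bytes == arr.nbytes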


def activate_profiler(config, optional_postfix=""):
  if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0):
    output_path = os.path.join(config.tensorboard_dir, optional_postfix)