add debug utils for cache sharding and size across chips
morgandu committed May 7, 2024
1 parent d590328 commit 472dd94
Showing 2 changed files with 68 additions and 6 deletions.
35 changes: 34 additions & 1 deletion MaxText/inference_microbenchmark.py
@@ -17,6 +17,7 @@
"""Inference microbenchmark for prefill and autoregressive steps."""
import datetime
import jax
import flax
import json
import sys

@@ -27,10 +28,40 @@
import maxtext_utils
import pyconfig


_WARMUP_ITERS = 2


def _debug_pytree(pytree):
  """Print shape, sharding, and logical/physical size info for an array leaf."""
  if isinstance(pytree, flax.linen.spmd.LogicallyPartitioned):
    pytree = pytree.value
  print(f"\tshape: {pytree.shape}")
  print(f"\tsharding: {pytree.sharding}")
  total_logical_sizes, total_logical_bytes, _ = max_utils.summarize_size_from_pytree(pytree)
  total_physical_sizes_across_chips, n_chips = max_utils.calculate_total_params_across_chips(pytree)
  total_physical_bytes_across_chips, _ = max_utils.calculate_total_bytes_across_chips(pytree)
  print(f"\ttotal_logical_sizes: {total_logical_sizes}")
  print(f"\ttotal_logical_bytes: {total_logical_bytes}")
  print(f"\tn_chips: {n_chips}")
  print(f"\ttotal_physical_sizes_across_chips: {total_physical_sizes_across_chips}")
  print(f"\ttotal_physical_bytes_across_chips: {total_physical_bytes_across_chips}")


def debug_result(result):
  """Debug result pytrees' sizing and sharding across chips."""
  for result_key, result_element in result.items():
    if result_key == "cache":
      # Inspect only the first decoder layer's KV cache; this assumes the
      # remaining layers are shaped and sharded identically.
      single_layer_kv_cache = result["cache"]["decoder"]["layers_0"]["self_attention"]["AttentionOp_0"]
      for cache_key, cache_element in single_layer_kv_cache.items():
        print(f"{cache_key}:")
        _debug_pytree(cache_element)
    else:
      print(f"{result_key}:")
      _debug_pytree(result_element)


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
  """Inner loop for benchmarking prefill step."""
  start = datetime.datetime.now()
@@ -204,6 +235,7 @@ def summarize_prefill_result(engine, params, tokens, true_length):
  num_prefill_cache_params, total_prefill_cache_size, avg_prefill_cache_param_size = (
      max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
  )
  debug_result(prefill_result)
  max_utils.delete_pytree(prefill_result)
  return {
      "num_prefill_logits_params": num_prefill_logits_params,
@@ -227,6 +259,7 @@ def main(config):
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)

  decode_state = engine.init_decode_state()

  _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
  num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")

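As a minimal sketch of what debug_result consumes (editorial illustration, not part of the commit): it walks the top-level result keys, and for "cache" it drills down to the first decoder layer only, presumably because all layers are sharded alike. The nested keys below mirror the lookup path hard-coded in debug_result; the leaf names cached_key and cached_value are hypothetical placeholders, not confirmed MaxText cache variable names.

import jax.numpy as jnp

toy_result = {
    "logits": jnp.zeros((1, 16, 32)),
    "cache": {
        "decoder": {
            "layers_0": {
                "self_attention": {
                    "AttentionOp_0": {
                        # Hypothetical single-layer KV-cache leaves.
                        "cached_key": jnp.zeros((1, 128, 2, 4)),
                        "cached_value": jnp.zeros((1, 128, 2, 4)),
                    }
                }
            }
        }
    },
}
debug_result(toy_result)  # prints shape, sharding, and logical/physical sizes per leaf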
39 changes: 34 additions & 5 deletions MaxText/max_utils.py
@@ -58,35 +58,64 @@ def l2norm_pytree(x):
  return jnp.sqrt(jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(jnp.square(y)), x, initializer=0.0))


-def calculate_num_params_from_pytree(params):
+def calculate_num_params_from_pytree(params) -> int:
  """Calculate params' logical size."""
  params_sizes = jax.tree_util.tree_map(jax.numpy.size, params)
  total_parameters = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes)
  assert total_parameters >= 0
  return total_parameters


-def calculate_total_params_per_chip(params):
+def calculate_total_params_per_chip(params) -> int:
  """Calculate params' physical size on an addressable chip."""
  def calculate_leaf_params_per_chip(arr):
    shard = arr.addressable_shards[0]
-    return np.prod(shard.data.shape)
+    return shard.data.size

  params_sizes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_per_chip, params)
  total_parameters_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes_per_chip)
  return total_parameters_per_chip


-def calculate_bytes_from_pytree(params):
+def calculate_bytes_from_pytree(params) -> int:
  """Calculate params' logical bytes."""
  params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params)
  total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes)
  return total_bytes


-def summarize_size_from_pytree(params):
+def summarize_size_from_pytree(params) -> tuple[int, int, float]:
  """Calculate params' logical size and bytes."""
  num_params = calculate_num_params_from_pytree(params)
  num_bytes = calculate_bytes_from_pytree(params)
  return num_params, num_bytes, num_bytes / num_params


def calculate_total_bytes_per_chip(params) -> int:
  """Calculate params' physical bytes on an addressable chip."""
  def calculate_leaf_params_bytes_per_chip(arr):
    shard = arr.addressable_shards[0]
    return shard.data.nbytes

  params_bytes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_bytes_per_chip, params)
  total_bytes_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes_per_chip)
  return total_bytes_per_chip


def calculate_total_params_across_chips(params) -> tuple[int, int]:
  """Calculate params' physical size across all chips; returns (total_size, n_chips)."""
  # `params` is expected to be a single jax.Array here: a generic pytree has
  # no `addressable_shards` attribute.
  total_params_per_chip = calculate_total_params_per_chip(params)
  n_chips = len(params.addressable_shards)
  return total_params_per_chip * n_chips, n_chips


def calculate_total_bytes_across_chips(params) -> tuple[int, int]:
  """Calculate params' physical bytes across all chips; returns (total_bytes, n_chips)."""
  total_bytes_per_chip = calculate_total_bytes_per_chip(params)
  n_chips = len(params.addressable_shards)
  return total_bytes_per_chip * n_chips, n_chips


def activate_profiler(config, optional_postfix=""):
  if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0):
    output_path = os.path.join(config.tensorboard_dir, optional_postfix)
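The across-chips helpers multiply one addressable shard's footprint by the number of addressable shards, so they report physical usage, which exceeds the logical size when an array is replicated. A minimal sketch of how they fit together (assuming a JAX runtime and that MaxText/max_utils.py is importable; the mesh and array here are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

import max_utils

# Build a 1-D mesh over all local devices and shard an array's leading axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
arr = jax.device_put(
    jnp.ones((len(devices) * 4, 8), dtype=jnp.bfloat16),
    NamedSharding(mesh, PartitionSpec("x", None)),
)

logical_size = max_utils.calculate_num_params_from_pytree(arr)        # every element counted once
per_chip = max_utils.calculate_total_params_per_chip(arr)             # first addressable shard only
across, n_chips = max_utils.calculate_total_params_across_chips(arr)  # per-chip size * chip count
assert across == per_chip * n_chips
# With even sharding and no replication, across == logical_size;
# replication would make the physical total larger than the logical one.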
