Skip to content

Commit

Permalink
add debug for cache sharding and size across chips
Browse files Browse the repository at this point in the history
  • Loading branch information
morgandu committed Apr 26, 2024
1 parent 6570445 commit 357b36e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
26 changes: 25 additions & 1 deletion MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Inference microbenchmark for prefill and autoregressive steps."""
import datetime
import jax
import flax
import json
import sys

Expand All @@ -27,10 +28,31 @@
import maxtext_utils
import pyconfig


_WARMUP_ITERS = 2


def debug_kv_cache(kv_cache):
"""Debug KV Cache sizing and sharding across chips."""
singler_layer_kv_cache = kv_cache["cache"]["decoder"]["layers_0"]["self_attention"]["AttentionOp_0"]
for cache_key in singler_layer_kv_cache.keys():
cache_element = singler_layer_kv_cache[cache_key]
print(f"{cache_key=}")
if isinstance(cache_element, flax.linen.spmd.LogicallyPartitioned):
cache_element = cache_element.value
jax.debug.print(" shape: {}", cache_element.shape)
jax.debug.print(" sharding: {}", cache_element.sharding)
total_logical_sizes, total_logical_bytes, _ = max_utils.summarize_size_from_pytree(cache_element)
total_sizes_across_chips, sizes_per_chip, num_chips = max_utils.calculate_total_params_across_chip(cache_element)
total_bytes_across_chip, bytes_per_chip, _ = max_utils.calculate_total_bytes_across_chip(cache_element)
jax.debug.print(" total_logical_sizes: {}", total_logical_sizes)
jax.debug.print(" total_logical_bytes: {}", total_logical_bytes)
jax.debug.print(" num_chips: {}", num_chips)
jax.debug.print(" total_sizes_across_chips: {}", total_sizes_across_chips)
jax.debug.print(" sizes_per_chip: {}", sizes_per_chip)
jax.debug.print(" total_bytes_across_chip: {}", total_bytes_across_chip)
jax.debug.print(" bytes_per_chip: {}", bytes_per_chip)


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
"""Inner loop for benchmarking prefill step."""
start = datetime.datetime.now()
Expand Down Expand Up @@ -198,6 +220,7 @@ def summarize_prefill_result(engine, params, tokens, true_length):
print(f"Prefill result of length {tokens.size}:\n")
prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
jax.block_until_ready(prefill_result)
debug_kv_cache(prefill_result)
num_prefill_logits_params, total_prefill_logits_size, avg_prefill_logits_param_size = (
max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
)
Expand Down Expand Up @@ -227,6 +250,7 @@ def main(config):
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)

decode_state = engine.init_decode_state()

_, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")

Expand Down
20 changes: 20 additions & 0 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,26 @@ def summarize_size_from_pytree(params):
return num_params, num_bytes, num_bytes / num_params


def calculate_total_params_across_chip(params):
def calculate_sizes_per_chip(arr):
return [np.prod(shard.data.shape) for shard in arr.addressable_shards]
sizes_across_chips = jax.tree_util.tree_map(calculate_sizes_per_chip, params)
num_chips = len(sizes_across_chips)
total_sizes_across_chips = jax.tree_util.tree_reduce(lambda x, y: x + y, sizes_across_chips)
sizes_per_chip = total_sizes_across_chips / num_chips
return total_sizes_across_chips, sizes_per_chip, num_chips


def calculate_total_bytes_across_chip(params):
def calculate_bytes_across_chip(arr):
return [shard.data.nbytes for shard in arr.addressable_shards]
bytes_across_chips = jax.tree_util.tree_map(calculate_bytes_across_chip, params)
num_chips = len(bytes_across_chips)
total_bytes_across_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, bytes_across_chips)
bytes_per_chip = total_bytes_across_chip / num_chips
return total_bytes_across_chip, bytes_per_chip, num_chips


def activate_profiler(config, optional_postfix=""):
if config.enable_profiler and (config.upload_all_profiler_results or jax.process_index() == 0):
output_path = os.path.join(config.tensorboard_dir, optional_postfix)
Expand Down

0 comments on commit 357b36e

Please sign in to comment.