add debug functionality for per chip sizes and bytes #625
base: main
Conversation
Force-pushed from 0a1a790 to 8ee7739
Overall LGTM, some minor comments.
MaxText/inference_microbenchmark.py
Outdated
_WARMUP_ITERS = 2


def debug_kv_cache(kv_cache):
  singler_kv_cache = kv_cache["cache"]["decoder"]["layers_0"]["self_attention"]["AttentionOp_0"]
Nit: Is this supposed to be "single" or something else?
added "single_layer"
MaxText/inference_microbenchmark.py
Outdated
singler_kv_cache = kv_cache["cache"]["decoder"]["layers_0"]["self_attention"]["AttentionOp_0"]
for cache_key in singler_kv_cache.keys():
  cache_element = singler_kv_cache[cache_key]
  print(f"{cache_key}:")
Nit: It would be helpful to print out the variable name as well. You can do this in f-strings by adding an =, like this: print(f"{cache_key=}")
Done
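(For readers skimming the thread: a minimal sketch of the = specifier suggested above, with a made-up cache_key value.)

cache_key = "cached_ar_key"  # hypothetical value for illustration
print(f"{cache_key}:")   # prints: cached_ar_key:
print(f"{cache_key=}")   # prints: cache_key='cached_ar_key'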
MaxText/inference_microbenchmark.py
Outdated
print(f"{cache_key}:") | ||
if type(cache_element) == flax.linen.spmd.LogicallyPartitioned: | ||
cache_element = cache_element.value | ||
jax.debug.print(" shape: {shape}", shape=cache_element.shape) |
Nit: This is a dense series of lines; some whitespace would help make it more readable. A small thing related to density that you can take or leave: in jax.debug.print() you can skip the variable naming if you are only printing one variable, like this: jax.debug.print(" sharding: {}", cache_element.sharding)
Done
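(A quick sketch of the two jax.debug.print formatting styles discussed above, on a toy array rather than the PR's cache element.)

import jax
import jax.numpy as jnp

x = jnp.ones((4, 2))
jax.debug.print(" shape: {shape}", shape=x.shape)  # named field
jax.debug.print(" shape: {}", x.shape)             # positional; fine when printing one value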
MaxText/inference_microbenchmark.py
Outdated
@@ -227,6 +255,8 @@ def main(config):
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)

  decode_state = engine.init_decode_state()
  debug_kv_cache(decode_state)
Do we want to run this twice in the script?
We can make this optional. I was also checking decode_state, which was sharded correctly.
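(A sketch of one way to gate the call, assuming a hypothetical enable_kv_cache_debug config flag that is not part of this PR.)

decode_state = engine.init_decode_state()
if getattr(config, "enable_kv_cache_debug", False):  # hypothetical flag
  debug_kv_cache(decode_state)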
Force-pushed from 8ee7739 to e3c5fb3
Force-pushed from 4bdd2c5 to 357b36e
Some nits
MaxText/inference_microbenchmark.py
Outdated
print(f"{cache_key=}") | ||
if isinstance(cache_element, flax.linen.spmd.LogicallyPartitioned): | ||
cache_element = cache_element.value | ||
jax.debug.print(" shape: {}", cache_element.shape) |
Nit: these really shouldn't be jax.debug.print calls, because you aren't running them inside a jit; you can just use print.
Done
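(For context on this nit, a toy sketch of when each call matters; not the PR's code.)

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("trace time:", x)            # runs once during tracing and shows a tracer
  jax.debug.print("runtime: {}", x)  # shows the concrete value on every call
  return x * 2

f(jnp.arange(3))

x = jnp.arange(3)
print("eager:", x)  # outside jit, plain print already sees concrete values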
MaxText/max_utils.py
Outdated
@@ -87,6 +87,26 @@ def summarize_size_from_pytree(params):
  return num_params, num_bytes, num_bytes / num_params


def calculate_total_params_across_chip(params):
Sorry, what does this mean? I wonder if there is a clearer name (and possibly a docstring?).
Added a docstring.
MaxText/max_utils.py
Outdated
@@ -87,6 +87,26 @@ def summarize_size_from_pytree(params):
  return num_params, num_bytes, num_bytes / num_params


def calculate_total_params_across_chip(params):
  def calculate_sizes_per_chip(arr):
    return [np.prod(shard.data.shape) for shard in arr.addressable_shards]
np.prod(shard.data.shape) could be shard.data.size?
Done
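(Putting this thread's suggestions together: a simplified sketch of the helper with a docstring and shard.data.size swapped in, reconstructed from the snippets above rather than taken from the PR's exact code.)

import jax

def calculate_total_params_across_chip(params):
  """Sums the physical (per-shard) parameter counts materialized across all
  addressable chips; this exceeds the logical array size when an array is
  replicated rather than sharded."""
  def calculate_sizes_per_chip(arr):
    return sum(shard.data.size for shard in arr.addressable_shards)
  sizes = jax.tree_util.tree_map(calculate_sizes_per_chip, params)
  return jax.tree_util.tree_reduce(lambda x, y: x + y, sizes)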
MaxText/max_utils.py
Outdated
sizes_across_chips = jax.tree_util.tree_map(calculate_sizes_per_chip, params)
num_chips = len(sizes_across_chips)
total_sizes_across_chips = jax.tree_util.tree_reduce(lambda x, y: x + y, sizes_across_chips)
sizes_per_chip = total_sizes_across_chips / num_chips
This is INCREDIBLY paranoid code: because we're SPMD, calculating this for any one chip is adequate.
But no worries if you're paranoid!
I took a pass and changed it to a normal paranoia level. But let me share some context on why I had this incredibly paranoid code in the first place.
If you recall, a couple of weeks ago I mentioned some memory issues that were affecting the JetStream serving batch size. One of them came down to prefill_result initializing both a prefill cache and a generate cache. The prefill cache was properly sharded, but no sharding constraint was applied to the generate cache, so the generate cache was copied onto every TPU chip.
This was confirmed with the utils in this PR. For example, see ar_key's physical sizes/bytes versus prefill's below:
cached_ar_key:
  shape: (1024, 32, 1, 128)
  sharding: NamedSharding(mesh=Mesh('data': 1, 'fsdp': 1, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 8, 'autoregressive': 1), spec=PartitionSpec())
  total_logical_sizes: 4194304
  total_logical_bytes: 8388608
  n_chips: 8
  total_physical_sizes_across_chips: 33554432
  total_physical_bytes_across_chip: 67108864
cached_ar_value:
  ...... (same as cached_ar_key)
cached_prefill_key:
  shape: (1024, 32, 1, 128)
  sharding: NamedSharding(mesh=Mesh('data': 1, 'fsdp': 1, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 8, 'autoregressive': 1), spec=PartitionSpec(None, 'tensor'))
  total_logical_sizes: 4194304
  total_logical_bytes: 8388608
  n_chips: 8
  total_physical_sizes_across_chips: 4194304
  total_physical_bytes_across_chip: 8388608
cached_prefill_value:
  ...... (same as cached_prefill_key)
logits:
  shape: (1, 1, 32000)
  sharding: NamedSharding(mesh=Mesh('data': 1, 'fsdp': 1, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 8, 'autoregressive': 1), spec=PartitionSpec())
  total_logical_sizes: 32000
  total_logical_bytes: 128000
  n_chips: 8
  total_physical_sizes_across_chips: 256000
  total_physical_bytes_across_chip: 1024000
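(The replication effect shown above is easy to reproduce; a sketch using toy arrays, assuming the host's device count evenly divides the sharded dimension.)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("tensor",))
x = jnp.ones((1024, 128))

# PartitionSpec() places no constraint on any axis, so every chip gets a full copy.
replicated = jax.device_put(x, NamedSharding(mesh, PartitionSpec()))
# PartitionSpec(None, "tensor") splits dim 1 across the 'tensor' mesh axis.
sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None, "tensor")))

def physical_size(arr):
  return sum(shard.data.size for shard in arr.addressable_shards)

print(x.size)                     # logical size
print(physical_size(replicated))  # num_devices * logical size
print(physical_size(sharded))     # equal to the logical size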
MaxText/max_utils.py
Outdated
return total_sizes_across_chips, sizes_per_chip, num_chips


def calculate_total_bytes_across_chip(params):
Some similar feedback as above applies here.
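(In the same spirit, a hypothetical bytes variant, sketched here rather than taken from the PR, would swap size for nbytes.)

def calculate_total_bytes_across_chip(params):
  def bytes_per_chip(arr):
    return sum(shard.data.nbytes for shard in arr.addressable_shards)
  per_leaf = jax.tree_util.tree_map(bytes_per_chip, params)
  return jax.tree_util.tree_reduce(lambda x, y: x + y, per_leaf)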