Refactor Logging Mechanisms, Enhance Error Handling, and Improve Documentation #28

Open · wants to merge 1 commit into main
corenet/engine/utils.py · 48 changes: 20 additions & 28 deletions
@@ -1,7 +1,3 @@
-#
-# For licensing see accompanying LICENSE file.
-# Copyright (C) 2024 Apple Inc. All Rights Reserved.
-#
 import argparse
 import os
 from typing import Dict, List, Optional, Union
@@ -15,30 +11,28 @@
 from corenet.utils.ddp_utils import is_master
 from corenet.utils.file_logger import FileLogger

+# Define supported torch dtypes for mixed-precision training
 str_to_torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}


-def autocast_fn(enabled: bool, amp_precision: Optional[str] = "float16"):
+def autocast_fn(enabled: bool, amp_precision: Optional[str] = "float16") -> autocast:
+    """Enable autocasting for mixed-precision training."""
     if enabled:
-        # If AMP is enabled, ensure that:
-        # 1. Device is CUDA
-        # 2. dtype is FLOAT16 or BFLOAT16
         if amp_precision not in str_to_torch_dtype:
-            logger.error(
-                "For Mixed-precision training, supported dtypes are {}. Got: {}".format(
-                    list(str_to_torch_dtype.keys()), amp_precision
-                )
+            raise ValueError(
+                f"For Mixed-precision training, supported dtypes are {list(str_to_torch_dtype.keys())}. Got: {amp_precision}"
             )

         if not torch.cuda.is_available():
-            logger.error("For mixed-precision training, CUDA device is required.")
+            raise RuntimeError("CUDA device is required for mixed-precision training.")

         return autocast(enabled=enabled, dtype=str_to_torch_dtype[amp_precision])
     else:
         return autocast(enabled=False)


 def get_batch_size(x: Union[Tensor, Dict, List]) -> int:
+    """Get batch size from tensor, dictionary, or list."""
     if isinstance(x, Tensor):
         return x.shape[0]
     elif isinstance(x, Dict):
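
Since the refactor turns the misconfiguration paths of autocast_fn into raised exceptions (ValueError for an unsupported dtype, RuntimeError when no CUDA device is available) instead of logger.error calls, callers can now react to them explicitly. A minimal sketch under that assumption; the training-step wrapper, model, batch, and fallback policy below are illustrative and not part of this diff:

import torch

from corenet.engine.utils import autocast_fn


def forward_step(model: torch.nn.Module, batch: torch.Tensor, use_amp: bool = True) -> torch.Tensor:
    # With this change, an unsupported dtype raises ValueError and a missing
    # CUDA device raises RuntimeError, so the caller can fall back explicitly.
    try:
        amp_ctx = autocast_fn(enabled=use_amp, amp_precision="bfloat16")
    except (ValueError, RuntimeError) as err:
        print(f"AMP disabled, falling back to full precision: {err}")
        amp_ctx = autocast_fn(enabled=False)
    with amp_ctx:
        return model(batch)
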
@@ -65,33 +59,35 @@ def log_metrics(
     val_ckpt_metric: Optional[float] = None,
     val_ema_ckpt_metric: Optional[float] = None,
 ) -> None:
+    """Log training and validation metrics."""
     if not isinstance(lrs, list):
         lrs = [lrs]
     for g_id, lr_val in enumerate(lrs):
-        log_writer.add_scalar("LR/Group-{}".format(g_id), round(lr_val, 6), epoch)
+        log_writer.add_scalar(f"LR/Group-{g_id}", round(lr_val, 6), epoch)

     log_writer.add_scalar("Common/Best Metric", round(best_metric, 2), epoch)


 def get_log_writers(opts: argparse.Namespace, save_location: Optional[str]):
+    """Get log writers for various logging mechanisms."""
     is_master_node = is_master(opts)

     log_writers = []

     if not is_master_node:
         return log_writers

     tensorboard_logging = getattr(opts, "common.tensorboard_logging", False)
     if tensorboard_logging and save_location is not None:
         try:
             from torch.utils.tensorboard import SummaryWriter
-        except ImportError as e:
-            logger.log(
+        except ImportError:
+            logger.error(
                 "Unable to import SummaryWriter from torch.utils.tensorboard. Disabling tensorboard logging"
             )
             SummaryWriter = None

         if SummaryWriter is not None:
-            exp_dir = "{}/tb_logs".format(save_location)
+            exp_dir = os.path.join(save_location, "tb_logs")
             create_directories(dir_path=exp_dir, is_master_node=is_master_node)
             log_writers.append(
                 SummaryWriter(log_dir=exp_dir, comment="Training and Validation logs")
@@ -102,34 +98,30 @@ def get_log_writers(opts: argparse.Namespace, save_location: Optional[str]):
         try:
             from corenet.internal.utils.bolt_logger import BoltLogger
         except ModuleNotFoundError:
+            logger.error("Unable to import bolt. Disabling bolt logging")
             BoltLogger = None

-        if BoltLogger is None:
-            logger.log("Unable to import bolt. Disabling bolt logging")
-        else:
+        if BoltLogger is not None:
             log_writers.append(BoltLogger())

     hub_logging = getattr(opts, "common.hub.logging", False)
     if hub_logging:
         try:
             from corenet.internal.utils.hub_logger import HubLogger
         except ModuleNotFoundError:
+            logger.error("Unable to import hub. Disabling hub logging")
             HubLogger = None

-        if HubLogger is None:
-            logger.log("Unable to import hub. Disabling hub logging")
-        else:
+        if HubLogger is not None:
             try:
                 hub_logger = HubLogger(opts)
             except Exception as ex:
-                logger.log(
-                    f"Unable to initialize hub logger. Disabling hub logging: {ex}"
-                )
+                logger.error(f"Unable to initialize hub logger: {ex}. Disabling hub logging")
                 hub_logger = None
             if hub_logger is not None:
                 log_writers.append(hub_logger)

-    file_logging = getattr(opts, "common.file_logging")
+    file_logging = getattr(opts, "common.file_logging", False)
     if file_logging and save_location is not None:
         log_writers.append(FileLogger(os.path.join(save_location, "stats.pt")))

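For reference, a minimal sketch of how the refactored get_log_writers might be exercised, assuming a single-process run where is_master(opts) evaluates to true; in a real run, opts would come from corenet's own argument parser, and the bare Namespace and option values below are illustrative only:

import argparse

from corenet.engine.utils import get_log_writers

opts = argparse.Namespace()
# The dotted option names match the getattr keys read in this file; they are
# not valid attribute syntax, so set them via setattr.
setattr(opts, "common.tensorboard_logging", True)  # SummaryWriter under <save_location>/tb_logs
setattr(opts, "common.hub.logging", False)
setattr(opts, "common.file_logging", True)         # FileLogger writing <save_location>/stats.pt

writers = get_log_writers(opts, save_location="results/run_1")

# log_metrics drives each writer through a TensorBoard-style
# add_scalar(tag, value, step) call, for example:
for writer in writers:
    writer.add_scalar("LR/Group-0", 0.1, 0)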