[InferenceSlicer] - add segmentation models support #678

Closed
SkalskiP opened this issue Dec 15, 2023 · 10 comments
Labels
enhancement New feature or request
Q2.2024 Tasks planned for execution in Q2 2024.

@SkalskiP (Collaborator) commented Dec 15, 2023

Description

Currently, sv.InferenceSlicer supports only object detection models. Adding support for instance segmentation would require the following changes:

  • sv.InferenceSlicer uses Non-Max Suppression (NMS) to sift out duplicate detections at tile intersections. At the moment, Supervision only has a box-based NMS. A segmentation-based NMS would be almost identical; the only change would be to replace box_iou_batch with a new mask_iou_batch.
  • The segmentation-based NMS must then be plugged into sv.InferenceSlicer: we would check whether the detections have masks and, if so, use the new NMS.

API

# create
def mask_iou_batch(masks_true: np.ndarray, masks_detection: np.ndarray) -> np.ndarray:
    pass

# rename non_max_suppression -> box_non_max_suppression

# create
def mask_non_max_suppression(predictions: np.ndarray, iou_threshold: float = 0.5) -> np.ndarray:
    pass

# change InferenceSlicer
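
For reference, here is a minimal sketch of how these two functions could work — an illustration of the idea, not the final implementation. It assumes boolean masks of shape (N, H, W) and, mirroring the box-based NMS, predictions whose column 4 holds a confidence score; the sketch takes the masks as a separate argument for clarity.

import numpy as np

def mask_iou_batch_sketch(masks_true: np.ndarray, masks_detection: np.ndarray) -> np.ndarray:
    # Flatten each mask to a 1-D {0, 1} vector so all pairwise intersections
    # reduce to a single matrix product.
    flat_true = masks_true.reshape(masks_true.shape[0], -1).astype(float)
    flat_detection = masks_detection.reshape(masks_detection.shape[0], -1).astype(float)
    intersection = flat_true @ flat_detection.T
    union = flat_true.sum(axis=1)[:, None] + flat_detection.sum(axis=1)[None, :] - intersection
    return intersection / np.maximum(union, 1e-9)

def mask_non_max_suppression_sketch(
    predictions: np.ndarray, masks: np.ndarray, iou_threshold: float = 0.5
) -> np.ndarray:
    # Sort by confidence (descending) and greedily keep each mask unless it
    # overlaps an already-kept mask above the IoU threshold.
    order = predictions[:, 4].argsort()[::-1]
    ious = mask_iou_batch_sketch(masks[order], masks[order])
    keep = np.ones(len(order), dtype=bool)
    for i in range(len(order)):
        if keep[i]:
            keep[i + 1:] &= ious[i, i + 1:] <= iou_threshold
    # Map the keep flags back to the original prediction order.
    return keep[order.argsort()]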

Usage example

import cv2
import numpy as np
import supervision as sv
from ultralytics import YOLO

image = cv2.imread(<SOURCE_IMAGEPATH>)
model = YOLO("yolov8x-seg.pt")

def callback(image_slice: np.ndarray) -> sv.Detections:
    result = model(image_slice)[0]
    return sv.Detections.from_ultralytics(result)

slicer = sv.InferenceSlicer(
    callback=callback,
    slice_wh=(512, 512),
    iou_threshold=0.5,
)

detections = slicer(image)

Additional

  • Note: Please share a Google Colab with minimal code to test the new feature. We know it's additional work, but it will definitely speed up the review process. Each change must be tested by the reviewer. Setting up a local environment to do this is time-consuming. Please ensure that Google Colab can be accessed without any issues (make it public). Thank you! 🙏🏻
@SkalskiP SkalskiP added the enhancement New feature or request label Dec 15, 2023
@SkalskiP SkalskiP self-assigned this Dec 28, 2023
@SkalskiP SkalskiP added the Q1.2024 Tasks planned for execution in Q1 2024. label Jan 18, 2024
@SkalskiP SkalskiP changed the title Enhance SAHI with support for segmentation models [InferenceSlicer] - add segmentation models support Jan 26, 2024
@SkalskiP SkalskiP removed their assignment Jan 26, 2024
@AdonaiVera (Contributor)

Hi @SkalskiP ,

I've been working on integrating advanced mask handling capabilities into the InferenceSlicer class,
including mask_non_max_suppression and mask_iou_batch. As part of these enhancements, I've also
renamed non_max_suppression to box_non_max_suppression for clarity.

I've hit a snag while trying to merge detection objects with variable-sized masks in the Detections.merge
function. The issue arises because numpy arrays require uniform dimensions for stacking, but our masks,
being tied to detected objects, vary in size. This discrepancy causes the stacking operation to fail.

I'm considering a few approaches to address this, such as resizing, padding, or storing masks individually,
but each has its trade-offs regarding efficiency, complexity, and fidelity to the original mask shapes.

I'd appreciate your thoughts on the best path forward. Should we prioritize memory efficiency, ease of
implementation, or mask integrity? Or is there an alternative solution you'd recommend?

Thank you for your guidance. 👋
Ado,

@SkalskiP (Collaborator, Author)

Hi @AdonaiVera 👋🏻 We expect the exact dimensions of masks (width and height) to be the same as the source image. This is our approach for now. We are, of course, aware of potential memory and speed optimizations.

If your masks are variable-sized, you should pad them and make them all equally sized.
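
One way to do that, as a hypothetical helper (pad_mask is illustrative, not part of supervision): grow each mask at the bottom and right so all masks share one canvas size.

import numpy as np

def pad_mask(mask: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    # Pad a (h, w) boolean mask with False at the bottom/right so it becomes
    # (target_h, target_w); the mask content keeps its top-left position.
    pad_bottom = target_h - mask.shape[0]
    pad_right = target_w - mask.shape[1]
    return np.pad(mask, ((0, pad_bottom), (0, pad_right)), constant_values=False)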

@AdonaiVera (Contributor)

Yes @SkalskiP, it's true that masks match the dimensions of the slice image. However, InferenceSlicer creates slices of variable sizes, particularly at image boundaries, depending on slice_wh and the actual dimensions of the image, which leads to masks that don't all share the same dimensions. This variability breaks merging via np.vstack, which requires uniform dimensions.

def stack_or_none(name: str):
    if all(getattr(d, name) is None for d in detections_list):
        return None
    if any(getattr(d, name) is None for d in detections_list):
        raise ValueError(f"All or none of the '{name}' fields must be None")
    # Masks are (N, H, W) arrays, so they stack along axis 0 with vstack;
    # the remaining fields are 1-D and concatenate with hstack.
    stack = np.vstack if name == "mask" else np.hstack
    return stack([getattr(d, name) for d in detections_list])

mask = stack_or_none("mask")

I am considering two potential solutions to the issue at hand. The first approach is to scale masks to match the largest dimensions and then resize them as needed (as you suggest, but we need to add two resizing steps). The second approach is to store the masks in a different format that can accommodate variable sizes.

I would appreciate your thoughts on these approaches or any other suggestions you might have. I want to make sure that I fully understand the problem before creating the PR.

Thank you 🚀

@SkalskiP (Collaborator, Author) commented Feb 1, 2024

@AdonaiVera oooooh! Now I understand what you mean. I did not foresee this complexity when dissecting this task. I'm sorry. This is quite obvious in hindsight. 🤦🏻‍♂️

First of all, I'm happy to limit the scope of this task to just the segmentation NMS. That would still be a big win for supervision. Then, we could work on the segmentation slicing logic separately. Let me know what you think.

Even if you opt to continue, we should split work into two PRs, separate for segmentation NMS and slicing logic. Please open it if you can. 🙏🏻

As for the potential solution, here is what I think:

Some time ago, one of the users opened this issue. There was no follow-up, so we closed it, but I think a particular group of models expects images in specific shapes, for example, 1024x1024, and if that's the case, those models will complain. We can solve both problems with a single solution.

- load image
- slice the image into NxN tiles
- surround smaller size slices (the ones close to the edges) with a letterbox so that all tiles are NxN
- loop over slices
   - run inference
- update box coordinate values to match the image coordinate system, not the slice coordinate system
- pad masks to match the image coordinate system, not the slice coordinate system
- merge detections
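
A rough sketch of the two coordinate-system steps, assuming each slice's offset within the source image is known (all names here are illustrative):

import numpy as np

def move_mask_to_image(
    mask: np.ndarray, offset_x: int, offset_y: int, image_h: int, image_w: int
) -> np.ndarray:
    # Place a slice-sized mask onto a full-image canvas at the slice offset,
    # moving it from the slice coordinate system to the image one.
    canvas = np.zeros((image_h, image_w), dtype=bool)
    h, w = mask.shape
    canvas[offset_y : offset_y + h, offset_x : offset_x + w] = mask
    return canvas

# Boxes only need the offset added to their corner coordinates:
# xyxy_image = xyxy_slice + np.array([offset_x, offset_y, offset_x, offset_y])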

@SkalskiP (Collaborator, Author) commented Feb 2, 2024

@AdonaiVera let me know what you decided 🙏

@AdonaiVera (Contributor)

Hi @SkalskiP 👋

I like the idea of splitting the task into two PRs; it will help us organize the work better. The first PR will focus on segmentation NMS, and the second on segmentation slicing. I will work on the NMS feature first and open the PR once it's done.

Regarding your solution, I like it a lot, as it solves the issue of different slice sizes. However, the user then cannot pass the slice size as an input; it has to be inferred from the image size to guarantee that every slice has the same dimensions. I can test this idea and see how it works. 💪
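
A hypothetical sketch of that inference step (infer_slice_size is illustrative): choose the tile count closest to the requested size, then split the image evenly.

import math

def infer_slice_size(image_dim: int, requested_slice: int) -> int:
    # Pick the number of tiles that best matches the requested slice size,
    # then divide the image evenly across them. When image_dim is not
    # divisible by n_tiles, the last tile comes out up to n_tiles - 1 pixels
    # smaller and would still need a small letterbox.
    n_tiles = max(1, round(image_dim / requested_slice))
    return math.ceil(image_dim / n_tiles)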

@SkalskiP (Collaborator, Author) commented Feb 2, 2024

@AdonaiVera I'm waiting for Segmentation NMS PR! 🙏🏻

@AdonaiVera (Contributor)

Hi @SkalskiP 👋
I have implemented NMS for segmentation. When you have a chance, can you please check it? After we make all the corrections, I will continue with the InferenceSlicer. 🚀 💪

@AdonaiVera (Contributor) commented Feb 9, 2024

Hi @SkalskiP 👋
I hope you're well. I've made some updates to how we handle the InferenceSlicer logic.

- load image ✅ 
- slice the image into NxN tiles ✅ 
- surround smaller size slices (the ones close to the edges) with a letterbox so that all tiles are NxN
- loop over slices ✅ 
   - run inference ✅ 
- update box coordinate values to match the image coordinate system, not the slice coordinate system ✅ 
- pad masks to match the image coordinate system, not the slice coordinate system
- merge detections

First, I integrated a conditional to pad images smaller than the slice size (the ones at the corners). Then I created the function '_apply_padding_to_slice', which pads a slice using letterbox resizing.
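
A minimal sketch of that kind of letterbox padding (illustrative; the actual _apply_padding_to_slice may differ):

import cv2
import numpy as np

def letterbox_slice(image_slice: np.ndarray, target_wh: tuple) -> np.ndarray:
    # Pad a boundary slice at the bottom/right with a constant black border
    # so the tile ends up exactly (target_w, target_h).
    target_w, target_h = target_wh
    h, w = image_slice.shape[:2]
    return cv2.copyMakeBorder(
        image_slice, 0, target_h - h, 0, target_w - w,
        cv2.BORDER_CONSTANT, value=(0, 0, 0),
    )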

However, aligning mask data to the full image is more complex than aligning bounding boxes, because masks cover every pixel.

Here's what I'm thinking:
Instead of saving masks just for the image slices, we save them as if they're part of the whole image. This means most of each mask would be empty, except for the region we're interested in. This makes it easy to combine all the masks, but it uses more memory, because each detection's mask becomes the same size as the original image.

Another option is sparse mask storage: while processing each slice, we store the mask in a "sparse" format, keeping only the parts that aren't empty. This saves a lot of memory. Later, when we merge everything, we convert each sparse mask back into a regular one covering the entire image.
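
A rough sketch of what that sparse format could look like (the SparseMask name is hypothetical):

import numpy as np

class SparseMask:
    # Store only the non-empty crop plus its offset in the full image;
    # densify on merge.
    def __init__(self, crop: np.ndarray, offset_x: int, offset_y: int):
        self.crop = crop  # boolean (h, w) array covering the slice region only
        self.offset_x = offset_x
        self.offset_y = offset_y

    def to_dense(self, image_h: int, image_w: int) -> np.ndarray:
        dense = np.zeros((image_h, image_w), dtype=bool)
        h, w = self.crop.shape
        dense[self.offset_y : self.offset_y + h, self.offset_x : self.offset_x + w] = self.crop
        return dense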

I would like to hear your thoughts. 🥷

UPDATE:
I'm finishing the second approach and will create the PR here for us to discuss. 💪

@SkalskiP SkalskiP added Q2.2024 Tasks planned for execution in Q2 2024. and removed Q1.2024 Tasks planned for execution in Q1 2024. labels Apr 8, 2024
@SkalskiP (Collaborator, Author) commented Jun 6, 2024

released via supervision-0.21.0

@SkalskiP SkalskiP closed this as completed Jun 6, 2024