Comparing changes

base repository: ultralytics/ultralytics
base: v8.3.84
head repository: ultralytics/ultralytics
compare: v8.3.85
  • 2 commits
  • 3 files changed
  • 3 contributors

Commits on Mar 6, 2025

  1. Cleanup and fix ONNX segment example (#19551)

    Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com>
    Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
    Y-T-G and UltralyticsAssistant authored Mar 6, 2025

    Verified · commit 082a58d
  2. ultralytics 8.3.85 TensorRT export max_shape fix (#19541)

    Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com>
    Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
    Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
    Y-T-G and glenn-jocher authored Mar 6, 2025

    Verified · commit 23a9014
Showing with 68 additions and 170 deletions.
  1. +65 βˆ’167 examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py
  2. +1 βˆ’1 ultralytics/__init__.py
  3. +2 βˆ’2 ultralytics/engine/exporter.py
232 changes: 65 additions & 167 deletions examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py
@@ -1,13 +1,11 @@
# Ultralytics πŸš€ AGPL-3.0 License - https://ultralytics.com/license

import argparse
from typing import List, Tuple, Union

import cv2
import numpy as np
import onnxruntime as ort
import torch
import torch.nn.functional as F

import ultralytics.utils.ops as ops
from ultralytics.engine.results import Results
@@ -18,238 +16,138 @@
class YOLOv8Seg:
"""YOLOv8 segmentation model."""

def __init__(self, onnx_model, conf_threshold=0.4):
def __init__(self, onnx_model, conf=0.25, iou=0.7, imgsz=640):
"""
Initializes the object detection model using an ONNX model.
Args:
onnx_model (str): Path to the ONNX model file.
conf_threshold (float, optional): Confidence threshold for detections. Defaults to 0.4.
conf (float, optional): Confidence threshold for detections. Defaults to 0.25.
iou (float, optional): IoU threshold for NMS. Defaults to 0.7.
imgsz (int | Tuple): Input image size of the model.
Attributes:
session (ort.InferenceSession): ONNX Runtime session for running inference.
ndtype (numpy.dtype): Data type for model input (FP16 or FP32).
model_height (int): Height of the model's input image.
model_width (int): Width of the model's input image.
classes (list): List of class names from the COCO dataset.
device (str): Specifies whether inference runs on CPU or GPU.
conf_threshold (float): Confidence threshold for filtering detections.
session (ort.InferenceSession): ONNX Runtime session.
imgsz (Tuple): Input image size of the model.
classes (dict): Class mappings from the COCO dataset.
conf (float): Confidence threshold for filtering detections.
iou (float): IoU threshold used by NMS.
"""
# Build Ort session
self.session = ort.InferenceSession(
onnx_model,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
if ort.get_device() == "GPU"
if torch.cuda.is_available()
else ["CPUExecutionProvider"],
)

# Numpy dtype: support both FP32 and FP16 onnx model
self.ndtype = np.half if self.session.get_inputs()[0].type == "tensor(float16)" else np.single

# Get model width and height(YOLOv8-seg only has one input)
self.model_height, self.model_width = [x.shape for x in self.session.get_inputs()][0][-2:]

# Load COCO class names
self.imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
self.classes = yaml_load(check_yaml("coco8.yaml"))["names"]
self.conf = conf
self.iou = iou

# Device
self.device = "cuda:0" if ort.get_device().lower() == "gpu" else "cpu"

# Confidence
self.conf_threshold = conf_threshold

def __call__(self, im0):
def __call__(self, img):
"""
Runs inference on the input image using the ONNX model.
Args:
im0 (numpy.ndarray): The original input image in BGR format.
img (numpy.ndarray): The original input image in BGR format.
Returns:
list: Processed detection results after post-processing.
Example:
>>> detector = Model("yolov8.onnx")
>>> results = detector(image) # Runs inference and returns detections.
"""
# Pre-process
processed_image = self.preprocess(im0)

# Ort inference
predictions = self.session.run(None, {self.session.get_inputs()[0].name: processed_image})

# Post-process
return self.postprocess(im0, processed_image, predictions)

def preprocess(self, image, new_shape: Union[Tuple, List] = (640, 640)):
"""
Preprocesses the input image before feeding it into the model.
prep_img = self.preprocess(img, self.imgsz)
outs = self.session.run(None, {self.session.get_inputs()[0].name: prep_img})
return self.postprocess(img, prep_img, outs)

Args:
image (np.ndarray): The input image in BGR format.
new_shape (Tuple or List, optional): The target shape for resizing. Defaults to (640, 640).
Returns:
np.ndarray: Preprocessed image ready for model inference.
def letterbox(self, img, new_shape=(640, 640)):
"""Resizes and reshapes images while maintaining aspect ratio by adding padding, suitable for YOLO models."""
shape = img.shape[:2] # current shape [height, width]

Example:
>>> processed_img = model.preprocess(image)
"""
image, _, _ = self.__resize_and_pad_image(image=image, new_shape=new_shape)
image = self.__reshape_image(image=image)
return image[None] if len(image.shape) == 3 else image

def __reshape_image(self, image: np.ndarray) -> np.ndarray:
"""
Reshapes the image by changing its layout and normalizing pixel values.
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

Args:
image (np.ndarray): The image to be reshaped.
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding

Returns:
np.ndarray: Reshaped and normalized image.
if shape[::-1] != new_unpad: # resize
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))

Example:
>>> reshaped_img = model.__reshape_image(image)
"""
image = image.transpose([2, 0, 1])
image = image[np.newaxis, ...]
return np.ascontiguousarray(image).astype(np.float32) / 255
return img

def __resize_and_pad_image(
self, image=np.ndarray, new_shape: Union[Tuple, List] = (640, 640), color: Union[Tuple, List] = (114, 114, 114)
):
def preprocess(self, img, new_shape):
"""
Resizes and pads the input image while maintaining the aspect ratio.
Preprocesses the input image before feeding it into the model.
Args:
image (np.ndarray): The input image.
new_shape (Tuple or List, optional): Target shape (width, height). Defaults to (640, 640).
color (Tuple or List, optional): Padding color. Defaults to (114, 114, 114).
img (np.ndarray): The input image in BGR format.
new_shape (Tuple or List, optional): The target shape for resizing. Defaults to (640, 640).
Returns:
Tuple[np.ndarray, float, float]: The resized image along with padding values.
Example:
>>> resized_img, dw, dh = model.__resize_and_pad_image(image)
np.ndarray: Preprocessed image ready for model inference.
"""
shape = image.shape[:2] # original image shape

if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)

# Scale ratio (new / old)
ratio = min(new_shape[0] / shape[1], new_shape[1] / shape[0])

new_unpad = int(round(shape[1] * ratio)), int(round(shape[0] * ratio))
delta_width, delta_height = new_shape[0] - new_unpad[0], new_shape[1] - new_unpad[1]
img = self.letterbox(img, new_shape)
img = img[..., ::-1].transpose([2, 0, 1])[None]
img = np.ascontiguousarray(img)
img = img.astype(np.float32) / 255
return img

# Divide padding into 2 sides
delta_width /= 2
delta_height /= 2

image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR) if shape[::-1] == new_unpad else image

top, bottom = int(round(delta_height - 0.1)), int(round(delta_height + 0.1))
left, right = int(round(delta_width - 0.1)), int(round(delta_width + 0.1))
image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
return image, delta_width, delta_height

def postprocess(self, image, processed_image, predictions):
def postprocess(self, img, prep_img, outs):
"""
Post-processes model predictions to extract meaningful results.
Args:
image (np.ndarray): The original input image.
processed_image (np.ndarray): The preprocessed image used for inference.
predictions (list): Model output predictions.
img (np.ndarray): The original input image.
prep_img (np.ndarray): The preprocessed image used for inference.
outs (list): Model outputs.
Returns:
list: Processed detection results.
Example:
>>> results = model.postprocess(image, processed_image, predictions)
"""
torch_tensor_predictions = [torch.from_numpy(output) for output in predictions]
torch_tensor_boxes_confidence_category_predictions = torch_tensor_predictions[0]
masks_predictions_tensor = torch_tensor_predictions[1].to(self.device)

nms_boxes_confidence_category_predictions_tensor = ops.non_max_suppression(
torch_tensor_boxes_confidence_category_predictions,
conf_thres=self.conf_threshold,
nc=len(self.classes),
agnostic=False,
max_det=100,
max_time_img=0.001,
max_nms=1000,
)
preds, protos = [torch.from_numpy(p) for p in outs]
preds = ops.non_max_suppression(preds, self.conf, self.iou, nc=len(self.classes))

results = []
for idx, predictions in enumerate(nms_boxes_confidence_category_predictions_tensor):
predictions = predictions.to(self.device)
masks = self.__process_mask(
masks_predictions_tensor[idx],
predictions[:, 6:],
predictions[:, :4],
processed_image.shape[2:],
upsample=True,
) # HWC
predictions[:, :4] = ops.scale_boxes(processed_image.shape[2:], predictions[:, :4], image.shape)
results.append(Results(image, path="", names=self.classes, boxes=predictions[:, :6], masks=masks))
for i, pred in enumerate(preds):
pred[:, :4] = ops.scale_boxes(prep_img.shape[2:], pred[:, :4], img.shape)
masks = self.process_mask(protos[i], pred[:, 6:], pred[:, :4], img.shape[:2])
results.append(Results(img, path="", names=self.classes, boxes=pred[:, :6], masks=masks))

return results

def __process_mask(self, protos, masks_in, bboxes, shape, upsample=False):
def process_mask(self, protos, masks_in, bboxes, shape):
"""
Processes segmentation masks from the model output.
It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
Args:
protos (torch.Tensor): The prototype mask predictions from the model.
masks_in (torch.Tensor): The raw mask predictions.
bboxes (torch.Tensor): Bounding boxes for the detected objects.
shape (Tuple): Target shape for mask resizing.
upsample (bool, optional): Whether to upscale masks to match the original image size. Defaults to False.
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
shape (Tuple): The size of the input image (h,w).
Returns:
torch.Tensor: Processed binary masks.
Example:
>>> masks = model.__process_mask(protos, masks_in, bboxes, shape, upsample=True)
masks (torch.Tensor): The returned masks with dimensions [h, w, n].
"""
c, mh, mw = protos.shape # CHW
ih, iw = shape
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
width_ratio = mw / iw
height_ratio = mh / ih

downsampled_bboxes = bboxes.clone()
downsampled_bboxes[:, 0] *= width_ratio
downsampled_bboxes[:, 2] *= width_ratio
downsampled_bboxes[:, 3] *= height_ratio
downsampled_bboxes[:, 1] *= height_ratio

masks = ops.crop_mask(masks, downsampled_bboxes) # CHW
if upsample:
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
return masks.gt_(0.5).to(self.device)
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
masks = ops.scale_masks(masks[None], shape)[0] # CHW
masks = ops.crop_mask(masks, bboxes) # CHW
return masks.gt_(0.0)


if __name__ == "__main__":
# Create an argument parser to handle command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True, help="Path to ONNX model")
parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image")
parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
parser.add_argument("--iou", type=float, default=0.7, help="NMS IoU threshold")
args = parser.parse_args()

# Build model
model = YOLOv8Seg(args.model, args.conf)

# Read image by OpenCV
model = YOLOv8Seg(args.model, args.conf, args.iou)
img = cv2.imread(args.source)
img = cv2.resize(img, (640, 640)) # Can be changed based on your model's expected size

# Inference
results = model(img)

cv2.imshow("Segmented Image", results[0].plot())
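
The rewritten example letterboxes inside preprocess() and upsamples masks with ops.scale_masks instead of torch.nn.functional, so a direct caller can hand the class any BGR image. A minimal sketch of driving the updated class, assuming a YOLOv8-seg model already exported to ONNX (the yolov8n-seg.onnx and bus.jpg file names are illustrative):

import cv2

from main import YOLOv8Seg  # the example module changed in this diff

model = YOLOv8Seg("yolov8n-seg.onnx", conf=0.25, iou=0.7, imgsz=640)
img = cv2.imread("bus.jpg")  # preprocess() letterboxes to imgsz internally
results = model(img)
cv2.imshow("Segmented Image", results[0].plot())
cv2.waitKey(0)  # keep the window open until a key is pressed
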
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics πŸš€ AGPL-3.0 License - https://ultralytics.com/license

__version__ = "8.3.84"
__version__ = "8.3.85"

import os

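A quick sanity check that an installed environment picked up the bump (not part of the diff):

import ultralytics

assert ultralytics.__version__ == "8.3.85"
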
4 changes: 2 additions & 2 deletions ultralytics/engine/exporter.py
@@ -867,7 +867,7 @@ def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
# Engine builder
builder = trt.Builder(logger)
config = builder.create_builder_config()
workspace = int(self.args.workspace * (1 << 30)) if self.args.workspace is not None else 0
workspace = int((self.args.workspace or 0) * (1 << 30))
if is_trt10 and workspace > 0:
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
elif workspace > 0: # TensorRT versions 7, 8
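
The two workspace expressions are equivalent (None and 0 both map to 0 bytes, and positive GiB values convert as before); the new one-liner just drops the conditional. A minimal check of the new form, assuming workspace is a float in GiB or None as on the exporter args:

def gib_to_bytes(workspace):
    # New expression from this hunk: a missing (None) workspace becomes 0 bytes,
    # which skips both memory-pool branches above.
    return int((workspace or 0) * (1 << 30))

assert gib_to_bytes(None) == 0
assert gib_to_bytes(0.5) == 1 << 29
assert gib_to_bytes(4) == 4 * (1 << 30)
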
@@ -909,7 +909,7 @@ def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
profile = builder.create_optimization_profile()
min_shape = (1, shape[1], 32, 32) # minimum input shape
max_shape = (*shape[:2], *(int(max(1, workspace) * d) for d in shape[2:])) # max input shape
max_shape = (*shape[:2], *(int(max(1, self.args.workspace or 1) * d) for d in shape[2:])) # max input shape
for inp in inputs:
profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
config.add_optimization_profile(profile)
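
This is the max_shape fix named in the commit title: workspace at this point holds a byte count (from the hunk above), so the old expression multiplied the spatial dimensions by bytes rather than GiB whenever a workspace was set, producing unusable profile maxima; the new expression scales by the raw GiB value and defaults it to 1 when unset. A sketch with illustrative numbers:

shape = (1, 3, 640, 640)  # example static input shape
args_workspace = 4  # GiB, as on the exporter args; may also be None
workspace = int((args_workspace or 0) * (1 << 30))  # bytes, per the hunk above

old_max = (*shape[:2], *(int(max(1, workspace) * d) for d in shape[2:]))
new_max = (*shape[:2], *(int(max(1, args_workspace or 1) * d) for d in shape[2:]))
print(old_max)  # (1, 3, 2748779069440, 2748779069440) -- far beyond any real input
print(new_max)  # (1, 3, 2560, 2560)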