[WASI-NN] Add support for an ONNX Runtime backend using ort #7691

Merged 8 commits on Mar 13, 2024
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -559,7 +559,7 @@ jobs:
           fi
 
       # Build and test all features
-      - run: ./ci/run-tests.sh --locked
+      - run: ./ci/run-tests.sh ${{ matrix.extra_features }} --locked
         env:
           RUST_BACKTRACE: 1

104 changes: 102 additions & 2 deletions Cargo.lock

(Generated lockfile; diff not shown.)

14 changes: 9 additions & 5 deletions ci/build-test-matrix.js
@@ -39,7 +39,8 @@ const array = [
"os": "ubuntu-latest",
"name": "Test Linux x86_64",
"filter": "linux-x64",
"isa": "x64"
"isa": "x64",
"extra_features": "--features wasmtime-wasi-nn/onnx"
},
{
"os": "ubuntu-latest",
@@ -57,18 +58,21 @@ const array = [
   {
     "os": "macos-latest",
     "name": "Test macOS x86_64",
-    "filter": "macos-x64"
+    "filter": "macos-x64",
+    "extra_features": "--features wasmtime-wasi-nn/onnx"
   },
   {
     "os": "macos-14",
     "name": "Test macOS arm64",
     "filter": "macos-arm64",
-    "target": "aarch64-apple-darwin"
+    "target": "aarch64-apple-darwin",
+    "extra_features": "--features wasmtime-wasi-nn/onnx"
   },
   {
     "os": "windows-latest",
     "name": "Test Windows MSVC x86_64",
-    "filter": "windows-x64"
+    "filter": "windows-x64",
+    "extra_features": "--features wasmtime-wasi-nn/onnx"
   },
   {
     "os": "windows-latest",
@@ -85,7 +89,7 @@ const array = [
"qemu_target": "aarch64-linux-user",
"name": "Test Linux arm64",
"filter": "linux-arm64",
"isa": "aarch64"
"isa": "aarch64",
},
{
"os": "ubuntu-latest",
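
Combined with the workflow change above, each matrix entry's `extra_features` value is spliced into the test invocation through `${{ matrix.extra_features }}`. For the entries that set it, the CI command therefore expands to the line below; entries without `extra_features` keep the plain `./ci/run-tests.sh --locked`.

    ./ci/run-tests.sh --features wasmtime-wasi-nn/onnx --locked
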
@@ -3,8 +3,9 @@ use std::fs;
 use wasi_nn::*;
 
 pub fn main() -> Result<()> {
+    // Load the model from a preloaded directory named "fixtures", which contains a mobilenet model as model.[bin|xml].
     let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
-        .build_from_cache("mobilenet")?;
+        .build_from_cache("fixtures")?;
     println!("Loaded a graph: {:?}", graph);
 
     let mut context = graph.init_execution_context()?;

55 changes: 55 additions & 0 deletions crates/test-programs/src/bin/nn_image_classification_onnx.rs
@@ -0,0 +1,55 @@
use anyhow::Result;
use std::fs;
use wasi_nn::*;

pub fn main() -> Result<()> {
    let model = fs::read("fixture/model.onnx").unwrap();
    println!("[ONNX] Read model, size in bytes: {}", model.len());

    let graph =
        GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?;

    let mut context = graph.init_execution_context()?;
    println!(
        "[ONNX] Created wasi-nn execution context with ID: {}",
        context
    );

    // Prepare the WASI-NN tensor: tensor data is always a byte vector.
    // Load a tensor that precisely matches the graph input tensor.
    let data = fs::read("fixture/tensor.bgr").unwrap();
    println!("[ONNX] Read input tensor, size in bytes: {}", data.len());
    context.set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &data)?;

    // Execute the inference.
    context.compute()?;
    println!("[ONNX] Executed graph inference");

    // Retrieve the output.
    let mut output_buffer = vec![0f32; 1000];
    context.get_output(0, &mut output_buffer[..])?;
    println!(
        "[ONNX] Found results, sorted top 5: {:?}",
        &sort_results(&output_buffer)[..5]
    );

    Ok(())
}

// Sort the buffer of probabilities. The graph places the match probability for
// each class at the index for that class (e.g., the probability of class 42 is
// placed at buffer[42]). Here we convert to a wrapping InferenceResult and sort
// the results.
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probability.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
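
As a sanity check on the shapes above: the 1x3x224x224 f32 input tensor works out to 1 × 3 × 224 × 224 × 4 = 602,112 bytes, which is the size `fixture/tensor.bgr` must have for `set_input` to match the graph's input. The `sort_results` helper can also be exercised in isolation; a minimal sketch with made-up probabilities (not part of this PR):

    // Hypothetical check: class 1 carries the highest probability, so it sorts first.
    let buffer = [0.1f32, 0.7, 0.2];
    let sorted = sort_results(&buffer);
    assert_eq!(sorted[0], InferenceResult(1, 0.7));
    assert_eq!(sorted.len(), 3);
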
4 changes: 4 additions & 0 deletions crates/wasi-nn/Cargo.toml
@@ -29,6 +29,8 @@ openvino = { version = "0.6.0", features = [
"runtime-linking",
], optional = true }

ort = { version = "2.0.0-rc.0", default-features = false, features = ["copy-dylibs", "download-binaries"], optional = true }

[target.'cfg(windows)'.dependencies.windows]
version = "0.52"
features = [
@@ -51,5 +53,7 @@ wasmtime = { workspace = true, features = ["cranelift"] }
 default = ["openvino"]
 # openvino is available on all platforms, but it requires OpenVINO to be installed.
 openvino = ["dep:openvino"]
+# onnx is available on all platforms.
+onnx = ["dep:ort"]
 # winml is only available on Windows 10 1809 and later.
 winml = ["dep:windows"]
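
Since `ort` is declared optional and pulled in through the new `onnx` feature (`onnx = ["dep:ort"]`), the ONNX backend only compiles when that feature is enabled, which is what the CI matrix does via `--features wasmtime-wasi-nn/onnx`. A minimal sketch of how such a gate typically looks on the Rust side; the module and type names here are hypothetical and not taken from this PR:

    // Compile the ort-backed backend only when the "onnx" Cargo feature
    // (and therefore the optional `ort` dependency) is enabled.
    #[cfg(feature = "onnx")]
    pub mod onnx;

    #[cfg(feature = "onnx")]
    pub use onnx::OnnxBackend;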