Skip to content

Commit

Permalink
add an initial implementation for onnxruntime backend of wasi-nn
Browse files Browse the repository at this point in the history
Signed-off-by: David Justice <david@devigned.com>
  • Loading branch information
devigned committed Feb 8, 2024
1 parent 2c0e528 commit c438db1
Show file tree
Hide file tree
Showing 20 changed files with 11,161 additions and 26 deletions.
229 changes: 226 additions & 3 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions crates/test-programs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ futures = { workspace = true, default-features = false, features = ['alloc'] }
url = { workspace = true }
sha2 = "0.10.2"
base64 = "0.21.0"
ndarray = "0.15.3"
image = { version = "0.24.6", default-features = false, features = ["png"] }
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use std::fs;
use wasi_nn::*;

pub fn main() -> Result<()> {
// Load model from preloaded directory named "fixture" which contains a model.[bin|xml] mobilenet model.
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
.build_from_cache("mobilenet")?;
.build_from_cache("fixture")?;
println!("Loaded a graph: {:?}", graph);

let mut context = graph.init_execution_context()?;
Expand Down
107 changes: 107 additions & 0 deletions crates/test-programs/src/bin/nn_image_classification_onnx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use anyhow::Result;
use image::imageops::Triangle;
use image::io::Reader;
use image::{DynamicImage, RgbImage};
use ndarray::{Array, Dim};
use std::fs;
use std::io::BufRead;
use wasi_nn::*;

pub fn main() -> Result<()> {
    // Load the ONNX model - SqueezeNet 1.1-7
    // Full details: https://github.com/onnx/models/tree/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/squeezenet
    let model = fs::read("fixture/model.onnx").unwrap();
    println!("[ONNX] Read model, size in bytes: {}", model.len());

    let graph =
        GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?;

    let mut context = graph.init_execution_context()?;
    println!(
        "[ONNX] Created wasi-nn execution context with ID: {}",
        context
    );

    // Prepare the WASI-NN input tensor - tensor data is always a bytes vector;
    // here it is an NCHW (1x3x224x224) f32 image serialized to native-endian bytes.
    let data = image_to_tensor("fixture/mushroom.png".to_string(), 224, 224);
    context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &data)?;
    println!("[ONNX] Set input tensor");

    // Execute the inferencing
    context.compute()?;
    println!("[ONNX] Executed graph inference");

    // Retrieve the output: 1000 raw class scores (logits).
    let mut output_buffer = vec![0f32; 1000];
    context.get_output(0, &mut output_buffer[..])?;
    println!("[ONNX] Get output tensor");

    let output_shape = [1, 1000, 1, 1];
    let output_tensor = Array::from_shape_vec(output_shape, output_buffer).unwrap();

    // Post-processing requirement: apply softmax over the class axis (axis 1)
    // to turn the raw scores into probabilities.
    let exp_output = output_tensor.mapv(f32::exp);
    let sum_exp_output = exp_output.sum_axis(ndarray::Axis(1));
    let softmax_output = exp_output / &sum_exp_output;

    // Pair each class index with its probability, then sort descending by
    // probability. (`axis_iter` already yields an iterator — no `into_iter` needed.)
    let mut sorted = softmax_output
        .axis_iter(ndarray::Axis(1))
        .enumerate()
        .map(|(i, v)| (i, v[Dim([0, 0, 0])]))
        .collect::<Vec<(_, _)>>();
    sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    // Load the SqueezeNet 1000-class labels used for classification.
    // `read_to_string` + `str::lines` avoids per-line Result unwrapping.
    let labels = fs::read_to_string("fixture/squeezenet1.1-7.txt").unwrap();
    let class_labels: Vec<String> = labels.lines().map(String::from).collect();
    println!(
        "[ONNX] Read squeezenet Labels, # of labels: {}",
        class_labels.len()
    );

    // Print the top-3 classifications (label and probability).
    for (index, probability) in sorted.iter().take(3) {
        println!(
            "[ONNX] Index: {} - Probability: {}",
            class_labels[*index], probability
        );
    }

    Ok(())
}

// Take the image located at 'path', open it, resize it to height x width, and then convert
// the pixel precision to FP32, normalized with SqueezeNet's per-channel mean/std. The image
// is decoded as RGB (`to_rgb8`), and the result is laid out channel-major (CHW) as
// native-endian f32 bytes, matching the model's 1x3xHxW input tensor.
fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> {
    let pixels = Reader::open(path).unwrap().decode().unwrap();
    let dyn_img: DynamicImage = pixels.resize_exact(width, height, Triangle);
    let rgb_img: RgbImage = dyn_img.to_rgb8();

    // Interleaved RGB bytes (HWC): R,G,B,R,G,B,...
    let raw_u8_arr: &[u8] = &rgb_img.as_raw()[..];
    let pixel_count = raw_u8_arr.len() / 3; // = width * height

    // Output holds one f32 (4 bytes) per input byte, reordered to CHW.
    let mut u8_f32_arr: Vec<u8> = vec![0; raw_u8_arr.len() * 4];

    // Normalizing values for the model (ImageNet per-channel mean/std).
    let mean = [0.485, 0.456, 0.406];
    let std = [0.229, 0.224, 0.225];

    for (i, &byte) in raw_u8_arr.iter().enumerate() {
        let channel = i % 3; // 0 = R, 1 = G, 2 = B
        let pixel = i / 3; // flat (row-major) pixel index

        // Normalize the pixel value to f32.
        let normalized: f32 = (byte as f32 / 255.0 - mean[channel]) / std[channel];

        // Write the 4 f32 bytes at the CHW position: channel plane, then pixel.
        let dst = (channel * pixel_count + pixel) * 4;
        u8_f32_arr[dst..dst + 4].copy_from_slice(&normalized.to_ne_bytes());
    }

    u8_f32_arr
}
6 changes: 5 additions & 1 deletion crates/wasi-nn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ wasmtime = { workspace = true, features = ["component-model", "runtime"] }
tracing = { workspace = true }
openvino = { version = "0.6.0", features = ["runtime-linking"] }
thiserror = { workspace = true }
ort = { version = "2.0.0-rc.0" }
lazy_static = { version = "1.4" }
bytes = { version = "1.5" }
ndarray = { version = "0.15" }

[build-dependencies]
walkdir = { workspace = true }

[dev-dependencies]
cap-std = { workspace = true }
test-programs-artifacts = { workspace = true }
wasmtime-wasi = { workspace = true, features = ["sync"] }
wasmtime-wasi = { workspace = true, features = ["sync", "preview2"] }
wasmtime = { workspace = true, features = ["cranelift"] }
244 changes: 244 additions & 0 deletions crates/wasi-nn/examples/classification-component-onnx/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c438db1

Please sign in to comment.