-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add an initial implementation for onnxruntime backend of wasi-nn
Signed-off-by: David Justice <david@devigned.com>
- Loading branch information
Showing
20 changed files
with
11,161 additions
and
26 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
107 changes: 107 additions & 0 deletions
107
crates/test-programs/src/bin/nn_image_classification_onnx.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
use anyhow::Result; | ||
use image::imageops::Triangle; | ||
use image::io::Reader; | ||
use image::{DynamicImage, RgbImage}; | ||
use ndarray::{Array, Dim}; | ||
use std::fs; | ||
use std::io::BufRead; | ||
use wasi_nn::*; | ||
|
||
pub fn main() -> Result<()> { | ||
// Load the ONNX model - SqueezeNet 1.1-7 | ||
// Full details: https://github.com/onnx/models/tree/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/squeezenet | ||
let model = fs::read("fixture/model.onnx").unwrap(); | ||
println!("[ONNX] Read model, size in bytes: {}", model.len()); | ||
|
||
let graph = | ||
GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?; | ||
|
||
let mut context = graph.init_execution_context()?; | ||
println!( | ||
"[ONNX] Created wasi-nn execution context with ID: {}", | ||
context | ||
); | ||
|
||
// Prepare WASI-NN tensor - Tensor data is always a bytes vector | ||
let data = image_to_tensor("fixture/mushroom.png".to_string(), 224, 224); | ||
context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &data)?; | ||
println!("[ONNX] Set input tensor"); | ||
|
||
// Execute the inferencing | ||
context.compute()?; | ||
println!("[ONNX] Executed graph inference"); | ||
|
||
// Retrieve the output. | ||
let mut output_buffer = vec![0f32; 1000]; | ||
context.get_output(0, &mut output_buffer[..])?; | ||
println!("[ONNX] Get output tensor"); | ||
|
||
let output_shape = [1, 1000, 1, 1]; | ||
let output_tensor = Array::from_shape_vec(output_shape, output_buffer).unwrap(); | ||
|
||
// Post-Processing requirement: compute softmax to inferencing output | ||
let exp_output = output_tensor.mapv(|x| x.exp()); | ||
let sum_exp_output = exp_output.sum_axis(ndarray::Axis(1)); | ||
let softmax_output = exp_output / &sum_exp_output; | ||
|
||
let mut sorted = softmax_output | ||
.axis_iter(ndarray::Axis(1)) | ||
.enumerate() | ||
.into_iter() | ||
.map(|(i, v)| (i, v[Dim([0, 0, 0])])) | ||
.collect::<Vec<(_, _)>>(); | ||
sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); | ||
|
||
// Load SquezeNet 1000 labels used for classification | ||
let labels = fs::read("fixture/squeezenet1.1-7.txt").unwrap(); | ||
let class_labels: Vec<String> = labels.lines().map(|line| line.unwrap()).collect(); | ||
println!( | ||
"[ONNX] Read squeezenet Labels, # of labels: {}", | ||
class_labels.len() | ||
); | ||
|
||
for (index, probability) in sorted.iter().take(3) { | ||
println!( | ||
"[ONNX] Index: {} - Probability: {}", | ||
class_labels[*index], probability | ||
); | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
// Take the image located at 'path', open it, resize it to height x width, and then converts | ||
// the pixel precision to FP32. The resulting BGR pixel vector is then returned. | ||
fn image_to_tensor(path: String, height: u32, width: u32) -> Vec<u8> { | ||
let pixels = Reader::open(path).unwrap().decode().unwrap(); | ||
let dyn_img: DynamicImage = pixels.resize_exact(width, height, Triangle); | ||
let bgr_img: RgbImage = dyn_img.to_rgb8(); | ||
|
||
// Get an array of the pixel values | ||
let raw_u8_arr: &[u8] = &bgr_img.as_raw()[..]; | ||
|
||
// Create an array to hold the f32 value of those pixels | ||
let bytes_required = raw_u8_arr.len() * 4; | ||
let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required]; | ||
|
||
// Normalizing values for the model | ||
let mean = [0.485, 0.456, 0.406]; | ||
let std = [0.229, 0.224, 0.225]; | ||
|
||
// Read the number as a f32 and break it into u8 bytes | ||
for i in 0..raw_u8_arr.len() { | ||
let u8_f32: f32 = raw_u8_arr[i] as f32; | ||
let rgb_iter = i % 3; | ||
|
||
// Normalize the pixel | ||
let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter]; | ||
|
||
// Convert it to u8 bytes and write it with new shape | ||
let u8_bytes = norm_u8_f32.to_ne_bytes(); | ||
for j in 0..4 { | ||
u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j]; | ||
} | ||
} | ||
|
||
return u8_f32_arr; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
244 changes: 244 additions & 0 deletions
244
crates/wasi-nn/examples/classification-component-onnx/Cargo.lock
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Oops, something went wrong.