Skip to content

Commit

Permalink
prtest:full fix running WASI-NN ONNX tests across arch/OS combinations
Browse files Browse the repository at this point in the history
Signed-off-by: David Justice <david@devigned.com>
  • Loading branch information
devigned committed Mar 13, 2024
1 parent af7c140 commit 08c41f2
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 30 deletions.
93 changes: 65 additions & 28 deletions crates/wasi-nn/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,26 @@
//! - that WinML is available
//! - that some ML model artifacts can be downloaded and cached.

use anyhow::{anyhow, Context, Result};
use std::{env, fs, path::Path, path::PathBuf, process::Command, sync::Mutex};
#[cfg(all(feature = "winml", target_os = "windows"))]
use anyhow::Result;
use std::{env, path::Path, path::PathBuf, process::Command, sync::Mutex};

#[cfg(all(
feature = "openvino",
target_arch = "x86_64",
any(target_os = "linux", target_os = "windows")
))]
use {
anyhow::{anyhow, Context},
std::fs,
};

#[cfg(all(
feature = "onnx",
any(target_os = "linux", target_os = "windows", target_os = "macos")
))]
use {anyhow::Context, std::fs};

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
use windows::AI::MachineLearning::{LearningModelDevice, LearningModelDeviceKind};

/// Return the directory in which the test artifacts are stored.
Expand Down Expand Up @@ -37,25 +54,42 @@ macro_rules! check_test {

/// Return `Ok` if all checks pass.
///
/// Each backend is probed only on the feature/target combinations where it is
/// actually built; on other targets the corresponding section compiles to
/// nothing and `check` trivially succeeds.
pub fn check() -> Result<()> {
    // OpenVINO is only exercised on x86_64 Linux/Windows.
    #[cfg(all(
        feature = "openvino",
        target_arch = "x86_64",
        any(target_os = "linux", target_os = "windows")
    ))]
    {
        check_openvino_is_installed()?;
        check_openvino_artifacts_are_available()?;
    }
    // ONNX runs on Linux, Windows, and macOS regardless of architecture.
    #[cfg(all(
        feature = "onnx",
        any(target_os = "linux", target_os = "windows", target_os = "macos")
    ))]
    {
        check_onnx_artifacts_are_available()?;
    }
    // WinML is only available on x86_64 Windows.
    #[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
    {
        check_winml_is_available()?;
        check_winml_artifacts_are_available()?;
    }
    Ok(())
}

/// Protect `check_openvino_artifacts_are_available` from concurrent access;
/// when running tests in parallel, we want to avoid two threads attempting to
/// create the same directory or download the same file.
// `dead_code` is allowed because on feature/target combinations where none of
// the cfg-gated artifact checks are compiled in, nothing references this lock.
#[allow(dead_code)]
static ARTIFACTS: Mutex<()> = Mutex::new(());

/// Return `Ok` if we find a working OpenVINO installation.
#[cfg(feature = "openvino")]
#[cfg(all(
feature = "openvino",
target_arch = "x86_64",
any(target_os = "linux", target_os = "windows")
))]
fn check_openvino_is_installed() -> Result<()> {
match std::panic::catch_unwind(|| println!("> found openvino version: {}", openvino::version()))
{
Expand All @@ -64,27 +98,13 @@ fn check_openvino_is_installed() -> Result<()> {
}
}

// NOTE(review): this definition appears to be removed-side diff residue — a
// second `check_winml_is_available` (gated additionally on x86_64) appears
// later in the file; two same-named functions in one module will not compile.
// Confirm which copy should remain.
#[cfg(all(feature = "winml", target_os = "windows"))]
fn check_winml_is_available() -> Result<()> {
    // Creating the device may panic when WinML is unusable, so the probe runs
    // inside `catch_unwind` and a panic is converted into an error.
    match std::panic::catch_unwind(|| {
        println!(
            "> WinML learning device is available: {:?}",
            LearningModelDevice::Create(LearningModelDeviceKind::Default)
        )
    }) {
        Ok(_) => Ok(()),
        Err(e) => Err(anyhow!("WinML learning device is not available: {:?}", e)),
    }
}

/// Protect `check_openvino_artifacts_are_available` from concurrent access;
/// when running tests in parallel, we want to avoid two threads attempting to
/// create the same directory or download the same file.
// NOTE(review): this duplicates the `ARTIFACTS` static declared earlier in
// the file (removed-side diff residue?) — two statics with the same name in
// one module will not compile; confirm which copy should remain.
static ARTIFACTS: Mutex<()> = Mutex::new(());

/// Return `Ok` if we find the cached MobileNet test artifacts; this will
/// download the artifacts if necessary.
#[cfg(feature = "openvino")]
#[cfg(all(
feature = "openvino",
target_arch = "x86_64",
any(target_os = "linux", target_os = "windows")
))]
fn check_openvino_artifacts_are_available() -> Result<()> {
let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();
const BASE_URL: &str =
Expand All @@ -110,7 +130,23 @@ fn check_openvino_artifacts_are_available() -> Result<()> {
Ok(())
}

#[cfg(feature = "onnx")]
/// Probe for a usable WinML learning device.
///
/// Creating the default `LearningModelDevice` can panic when WinML is not
/// usable on this machine, so the probe is wrapped in `catch_unwind` and a
/// panic is reported as an error instead of aborting the test run.
#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
fn check_winml_is_available() -> Result<()> {
    let probe = std::panic::catch_unwind(|| {
        println!(
            "> WinML learning device is available: {:?}",
            LearningModelDevice::Create(LearningModelDeviceKind::Default)
        )
    });
    probe
        .map(|_| ())
        .map_err(|e| anyhow!("WinML learning device is not available: {:?}", e))
}

#[cfg(all(
feature = "onnx",
any(target_os = "linux", target_os = "windows", target_os = "macos")
))]
fn check_onnx_artifacts_are_available() -> Result<()> {
let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();

Expand Down Expand Up @@ -141,7 +177,7 @@ fn check_onnx_artifacts_are_available() -> Result<()> {
Ok(())
}

#[cfg(all(feature = "winml", target_os = "windows"))]
#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
fn check_winml_artifacts_are_available() -> Result<()> {
let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();
let artifacts_dir = artifacts_dir();
Expand All @@ -167,6 +203,7 @@ fn check_winml_artifacts_are_available() -> Result<()> {
}

/// Retrieve the bytes at the `from` URL and place them in the `to` file.
#[allow(dead_code)]
fn download(from: &str, to: &Path) -> anyhow::Result<()> {
let mut curl = Command::new("curl");
curl.arg("--location").arg(from).arg("--output").arg(to);
Expand Down
7 changes: 5 additions & 2 deletions crates/wasi-nn/tests/all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ fn nn_image_classification_winml() {
)]
#[test]
fn nn_image_classification_onnx() {
    // The pre-change body (unguarded backend construction) was fused with the
    // new cfg-guarded body in this diff; only the cfg-guarded version is kept.
    // When the `onnx` feature is disabled the body is empty and the test
    // trivially passes.
    #[cfg(feature = "onnx")]
    {
        let backend = Backend::from(backend::onnxruntime::OnnxBackend::default());
        run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false).unwrap()
    }
}

0 comments on commit 08c41f2

Please sign in to comment.