Skip to content

Commit

Permalink
Stop depending on avrt.dll statically on Windows
Browse files Browse the repository at this point in the history
* Load `avrt.dll` dynamically with `LoadLibraryW`
* Fail with an `AudioThreadPriorityError`
* Ensure thread-safety with a warmup call

See also https://bugzilla.mozilla.org/show_bug.cgi?id=1884214
  • Loading branch information
yjugl committed Mar 13, 2024
1 parent 10c8fc3 commit 783d697
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ libc = "0.2"
version = "0.52"
features = [
"Win32_Foundation",
"Win32_System_Threading",
"Win32_System_LibraryLoader",
]

[target.'cfg(target_os = "linux")'.dependencies]
Expand Down
247 changes: 208 additions & 39 deletions src/rt_win.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use windows_sys::s;
use windows_sys::Win32::Foundation::GetLastError;
use windows_sys::Win32::Foundation::FALSE;
use windows_sys::Win32::Foundation::HANDLE;
use windows_sys::Win32::System::Threading::{
AvRevertMmThreadCharacteristics, AvSetMmThreadCharacteristicsA,
};

use self::avrt_lib::AvRtLibrary;
use crate::AudioThreadPriorityError;

use log::info;
use std::sync::OnceLock;
use windows_sys::{
core::PCWSTR,
w,
Win32::Foundation::{HANDLE, WIN32_ERROR},
};

const MMCSS_TASK_NAME: PCWSTR = w!("Audio");

#[derive(Debug)]
pub struct RtPriorityHandleInternal {
Expand All @@ -17,53 +17,222 @@ pub struct RtPriorityHandleInternal {
}

impl RtPriorityHandleInternal {
pub fn new(mmcss_task_index: u32, task_handle: HANDLE) -> RtPriorityHandleInternal {
fn new(mmcss_task_index: u32, task_handle: HANDLE) -> RtPriorityHandleInternal {
RtPriorityHandleInternal {
mmcss_task_index,
task_handle,
}
}
}

pub fn demote_current_thread_from_real_time_internal(
rt_priority_handle: RtPriorityHandleInternal,
) -> Result<(), AudioThreadPriorityError> {
let rv = unsafe { AvRevertMmThreadCharacteristics(rt_priority_handle.task_handle) };
if rv == FALSE {
return Err(AudioThreadPriorityError::new(&format!(
"Unable to restore the thread priority ({:?})",
unsafe { GetLastError() }
)));
}

info!(
"task {} priority restored.",
rt_priority_handle.mmcss_task_index
);
fn avrt() -> Result<&'static AvRtLibrary, AudioThreadPriorityError> {
static AV_RT_LIBRARY: OnceLock<Result<AvRtLibrary, WIN32_ERROR>> = OnceLock::new();
AV_RT_LIBRARY
.get_or_init(AvRtLibrary::try_new)
.as_ref()
.map_err(|win32_error| {
AudioThreadPriorityError::new(&format!("Unable to load avrt.dll ({win32_error})"))
})
}

Ok(())
// We don't expect to fail to load the library on test machines
#[test]
fn test_successful_avrt_library_load_as_static_ref() {
assert!(avrt().is_ok())
}

pub fn promote_current_thread_to_real_time_internal(
_audio_buffer_frames: u32,
_audio_samplerate_hz: u32,
) -> Result<RtPriorityHandleInternal, AudioThreadPriorityError> {
let mut task_index = 0u32;
avrt()?
.set_mm_thread_characteristics(MMCSS_TASK_NAME)
.map(|(mmcss_task_index, task_handle)| {
info!("task {mmcss_task_index} bumped to real time priority.");
RtPriorityHandleInternal::new(mmcss_task_index, task_handle)
})
.map_err(|win32_error| {
AudioThreadPriorityError::new(&format!(
"Unable to bump the thread priority ({win32_error})"
))
})
}

let handle = unsafe { AvSetMmThreadCharacteristicsA(s!("Audio"), &mut task_index) };
let handle = RtPriorityHandleInternal::new(task_index, handle);
pub fn demote_current_thread_from_real_time_internal(
rt_priority_handle: RtPriorityHandleInternal,
) -> Result<(), AudioThreadPriorityError> {
let RtPriorityHandleInternal {
mmcss_task_index,
task_handle,
} = rt_priority_handle;
avrt()?
.revert_mm_thread_characteristics(task_handle)
.map(|_| {
info!("task {mmcss_task_index} priority restored.");
})
.map_err(|win32_error| {
AudioThreadPriorityError::new(&format!(
"Unable to restore the thread priority for task {mmcss_task_index} ({win32_error})"
))
})
}

// We don't expect to see API failures on test machines
#[test]
fn test_successful_api_use() {
let handle = promote_current_thread_to_real_time_internal(512, 44100);
println!("handle: {handle:?}");
assert!(handle.is_ok());

let result = demote_current_thread_from_real_time_internal(handle.unwrap());
println!("result: {result:?}");
assert!(result.is_ok());
}

if handle.task_handle == 0 {
return Err(AudioThreadPriorityError::new(&format!(
"Unable to restore the thread priority ({:?})",
unsafe { GetLastError() }
)));
mod avrt_lib {
use super::{
win32_utils::{win32_error_if, OwnedLibrary},
MMCSS_TASK_NAME,
};
use windows_sys::{
core::PCWSTR,
s, w,
Win32::Foundation::{BOOL, FALSE, HANDLE, WIN32_ERROR},
};

type AvSetMmThreadCharacteristicsWFn = unsafe extern "system" fn(PCWSTR, *mut u32) -> HANDLE;
type AvRevertMmThreadCharacteristicsFn = unsafe extern "system" fn(HANDLE) -> BOOL;

#[derive(Debug)]
pub(super) struct AvRtLibrary {
// This field is never read because only used for its Drop behavior
#[allow(dead_code)]
module: OwnedLibrary,

av_set_mm_thread_characteristics_w: AvSetMmThreadCharacteristicsWFn,
av_revert_mm_thread_characteristics: AvRevertMmThreadCharacteristicsFn,
}

info!(
"task {} bumped to real time priority.",
handle.mmcss_task_index
);
impl AvRtLibrary {
pub(super) fn try_new() -> Result<Self, WIN32_ERROR> {
let module = OwnedLibrary::try_new(w!("avrt.dll"))?;
let av_set_mm_thread_characteristics_w = unsafe {
std::mem::transmute::<_, AvSetMmThreadCharacteristicsWFn>(
module.get_proc(s!("AvSetMmThreadCharacteristicsW"))?,
)
};
let av_revert_mm_thread_characteristics = unsafe {
std::mem::transmute::<_, AvRevertMmThreadCharacteristicsFn>(
module.get_proc(s!("AvRevertMmThreadCharacteristics"))?,
)
};
Ok(Self::new(
module,
av_set_mm_thread_characteristics_w,
av_revert_mm_thread_characteristics,
))
}

fn new(
module: OwnedLibrary,
av_set_mm_thread_characteristics_w: AvSetMmThreadCharacteristicsWFn,
av_revert_mm_thread_characteristics: AvRevertMmThreadCharacteristicsFn,
) -> Self {
let library = AvRtLibrary {
module,
av_set_mm_thread_characteristics_w,
av_revert_mm_thread_characteristics,
};

// Warmup code to ensure that the MMCSS service will already be active once
// we return from this function.
//
// Note: This warmup code seems necessary to guarantee the thread safety of
// the avrt functions. Removing this warmup code can result in calls to
// AvSetMmThreadCharacteristicsW failing with ERROR_PATH_NOT_FOUND.
if let Ok((_, handle)) = library.set_mm_thread_characteristics(MMCSS_TASK_NAME) {
let _ = library.revert_mm_thread_characteristics(handle);
}

Ok(handle)
library
}

pub(super) fn set_mm_thread_characteristics(
&self,
task_name: PCWSTR,
) -> Result<(u32, HANDLE), WIN32_ERROR> {
let mut mmcss_task_index = 0u32;
let task_handle = unsafe {
(self.av_set_mm_thread_characteristics_w)(task_name, &mut mmcss_task_index)
};
win32_error_if(task_handle == 0)?;
Ok((mmcss_task_index, task_handle))
}

pub(super) fn revert_mm_thread_characteristics(
&self,
handle: HANDLE,
) -> Result<(), WIN32_ERROR> {
let rv = unsafe { (self.av_revert_mm_thread_characteristics)(handle) };
win32_error_if(rv == FALSE)
}
}

// We don't expect to fail to load the library on test machines
#[test]
fn test_successful_temporary_avrt_library_load() {
assert!(AvRtLibrary::try_new().is_ok())
}
}

mod win32_utils {
use windows_sys::{
core::{PCSTR, PCWSTR},
Win32::{
Foundation::{FreeLibrary, GetLastError, HMODULE, WIN32_ERROR},
System::LibraryLoader::{GetProcAddress, LoadLibraryW},
},
};

pub(super) fn win32_error_if(condition: bool) -> Result<(), WIN32_ERROR> {
if condition {
Err(unsafe { GetLastError() })
} else {
Ok(())
}
}

#[derive(Debug)]
pub(super) struct OwnedLibrary(HMODULE);

impl OwnedLibrary {
pub(super) fn try_new(lib_file_name: PCWSTR) -> Result<Self, WIN32_ERROR> {
let module = unsafe { LoadLibraryW(lib_file_name) };
win32_error_if(module == 0)?;
Ok(OwnedLibrary(module))
}

fn raw(&self) -> HMODULE {
self.0
}

/// SAFETY: The caller must transmute the value wrapped in a Ok(_) to the correct
/// function type, with the correct extern specifier.
pub(super) unsafe fn get_proc(
&self,
proc_name: PCSTR,
) -> Result<unsafe extern "system" fn() -> isize, WIN32_ERROR> {
let proc = unsafe { GetProcAddress(self.raw(), proc_name) };
win32_error_if(proc.is_none())?;
Ok(proc.unwrap())
}
}

impl Drop for OwnedLibrary {
fn drop(&mut self) {
unsafe {
FreeLibrary(self.raw());
}
}
}
}

0 comments on commit 783d697

Please sign in to comment.