Skip to content

Commit

Permalink
Support global.get in more constant expressions (#7996)
Browse files Browse the repository at this point in the history
This commit updates Wasmtime to support `global.get` in constant
expressions when located in table initializers and element segments.
Pre-reference-types this never came up because there was no valid
`global.get` that would typecheck. After the reference-types proposal
landed however this became possible but Wasmtime did not support it.
This was surfaced in #6705 when the spec test suite was updated and has
a new test that exercises this functionality.

This commit both updates the spec test suite and additionally adds
support for this new form of element segment and table initialization
expression.

The fact that Wasmtime hasn't supported this until now also means that
we have a gap in our fuzz-testing infrastructure. The `wasm-smith`
generator is being updated in bytecodealliance/wasm-tools#1426 to
generate modules with this particular feature and I've tested that with
that PR fuzzing here eventually generates an error before this PR.

Closes #6705
  • Loading branch information
alexcrichton committed Feb 26, 2024
1 parent 300fe46 commit 36fb62c
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 59 deletions.
53 changes: 48 additions & 5 deletions crates/environ/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ impl ModuleTranslation<'_> {
// Get the end of this segment. If out-of-bounds, or too
// large for our dense table representation, then skip the
// segment.
let top = match segment.offset.checked_add(segment.elements.len() as u32) {
let top = match segment.offset.checked_add(segment.elements.len()) {
Some(top) => top,
None => break,
};
Expand All @@ -482,6 +482,13 @@ impl ModuleTranslation<'_> {
WasmHeapType::Extern => break,
}

// Function indices can be optimized here, but fully general
// expressions are deferred to get evaluated at runtime.
let function_elements = match &segment.elements {
TableSegmentElements::Functions(indices) => indices,
TableSegmentElements::Expressions(_) => break,
};

let precomputed =
match &mut self.module.table_initialization.initial_values[defined_index] {
TableInitialValue::Null { precomputed } => precomputed,
Expand All @@ -492,7 +499,7 @@ impl ModuleTranslation<'_> {
// Technically this won't trap so it's possible to process
// further initializers, but that's left as a future
// optimization.
TableInitialValue::FuncRef(_) => break,
TableInitialValue::FuncRef(_) | TableInitialValue::GlobalGet(_) => break,
};

// At this point we're committing to pre-initializing the table
Expand All @@ -504,7 +511,7 @@ impl ModuleTranslation<'_> {
precomputed.resize(top as usize, FuncIndex::reserved_value());
}
let dst = &mut precomputed[(segment.offset as usize)..(top as usize)];
dst.copy_from_slice(&segment.elements[..]);
dst.copy_from_slice(&function_elements);

// advance the iterator to see the next segment
let _ = segments.next();
Expand Down Expand Up @@ -757,6 +764,10 @@ pub enum TableInitialValue {
/// Initialize each table element to the function reference given
/// by the `FuncIndex`.
FuncRef(FuncIndex),

/// At instantiation time this global is loaded and the funcref value is
/// used to initialize the table.
GlobalGet(GlobalIndex),
}

/// A WebAssembly table initializer segment.
Expand All @@ -769,7 +780,39 @@ pub struct TableSegment {
/// The offset to add to the base.
pub offset: u32,
/// The values to write into the table elements.
pub elements: Box<[FuncIndex]>,
pub elements: TableSegmentElements,
}

/// Elements of a table segment, either a list of functions or list of arbitrary
/// expressions.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TableSegmentElements {
/// A sequential list of functions where `FuncIndex::reserved_value()`
/// indicates a null function.
Functions(Box<[FuncIndex]>),
/// Arbitrary expressions, aka either functions, null or a load of a global.
Expressions(Box<[TableElementExpression]>),
}

impl TableSegmentElements {
/// Returns the number of elements in this segment.
pub fn len(&self) -> u32 {
match self {
Self::Functions(s) => s.len() as u32,
Self::Expressions(s) => s.len() as u32,
}
}
}

/// Different kinds of expression that can initialize table elements.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TableElementExpression {
/// `ref.func $f`
Function(FuncIndex),
/// `global.get $g`
GlobalGet(GlobalIndex),
/// `ref.null $ty`
Null,
}

/// Different types that can appear in a module.
Expand Down Expand Up @@ -815,7 +858,7 @@ pub struct Module {
pub memory_initialization: MemoryInitialization,

/// WebAssembly passive elements.
pub passive_elements: Vec<Box<[FuncIndex]>>,
pub passive_elements: Vec<TableSegmentElements>,

/// The map from passive element index (element segment index space) to index in `passive_elements`.
pub passive_elements_map: BTreeMap<ElemIndex, usize>,
Expand Down
38 changes: 24 additions & 14 deletions crates/environ/src/module_environ.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use crate::module::{
FuncRefIndex, Initializer, MemoryInitialization, MemoryInitializer, MemoryPlan, Module,
ModuleType, TablePlan, TableSegment,
ModuleType, TableElementExpression, TablePlan, TableSegment, TableSegmentElements,
};
use crate::{
DataIndex, DefinedFuncIndex, ElemIndex, EntityIndex, EntityType, FuncIndex, GlobalIndex,
GlobalInit, MemoryIndex, ModuleTypesBuilder, PrimaryMap, TableIndex, TableInitialValue,
Tunables, TypeConvert, TypeIndex, Unsigned, WasmError, WasmHeapType, WasmResult, WasmValType,
WasmparserTypeConverter,
};
use cranelift_entity::packed_option::ReservedValue;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
Expand Down Expand Up @@ -320,6 +319,10 @@ impl<'a, 'data> ModuleEnvironment<'a, 'data> {
self.flag_func_escaped(index);
TableInitialValue::FuncRef(index)
}
Operator::GlobalGet { global_index } => {
let index = GlobalIndex::from_u32(global_index);
TableInitialValue::GlobalGet(index)
}
s => {
return Err(WasmError::Unsupported(format!(
"unsupported init expr in table section: {:?}",
Expand Down Expand Up @@ -449,25 +452,31 @@ impl<'a, 'data> ModuleEnvironment<'a, 'data> {
// possible to create anything other than a `ref.null
// extern` for externref segments, so those just get
// translated to the reserved value of `FuncIndex`.
let mut elements = Vec::new();
match items {
let elements = match items {
ElementItems::Functions(funcs) => {
elements.reserve(usize::try_from(funcs.count()).unwrap());
let mut elems =
Vec::with_capacity(usize::try_from(funcs.count()).unwrap());
for func in funcs {
let func = FuncIndex::from_u32(func?);
self.flag_func_escaped(func);
elements.push(func);
elems.push(func);
}
TableSegmentElements::Functions(elems.into())
}
ElementItems::Expressions(_ty, funcs) => {
elements.reserve(usize::try_from(funcs.count()).unwrap());
for func in funcs {
let func = match func?.get_binary_reader().read_operator()? {
Operator::RefNull { .. } => FuncIndex::reserved_value(),
ElementItems::Expressions(_ty, items) => {
let mut exprs =
Vec::with_capacity(usize::try_from(items.count()).unwrap());
for expr in items {
let expr = match expr?.get_binary_reader().read_operator()? {
Operator::RefNull { .. } => TableElementExpression::Null,
Operator::RefFunc { function_index } => {
let func = FuncIndex::from_u32(function_index);
self.flag_func_escaped(func);
func
TableElementExpression::Function(func)
}
Operator::GlobalGet { global_index } => {
let global = GlobalIndex::from_u32(global_index);
TableElementExpression::GlobalGet(global)
}
s => {
return Err(WasmError::Unsupported(format!(
Expand All @@ -476,10 +485,11 @@ impl<'a, 'data> ModuleEnvironment<'a, 'data> {
)));
}
};
elements.push(func);
exprs.push(expr);
}
TableSegmentElements::Expressions(exprs.into())
}
}
};

match kind {
ElementKind::Active {
Expand Down
86 changes: 61 additions & 25 deletions crates/runtime/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ use wasmtime_environ::ModuleInternedTypeIndex;
use wasmtime_environ::{
packed_option::ReservedValue, DataIndex, DefinedGlobalIndex, DefinedMemoryIndex,
DefinedTableIndex, ElemIndex, EntityIndex, EntityRef, EntitySet, FuncIndex, GlobalIndex,
GlobalInit, HostPtr, MemoryIndex, MemoryPlan, Module, PrimaryMap, TableIndex,
TableInitialValue, Trap, VMOffsets, WasmHeapType, WasmRefType, WasmValType, VMCONTEXT_MAGIC,
GlobalInit, HostPtr, MemoryIndex, MemoryPlan, Module, PrimaryMap, TableElementExpression,
TableIndex, TableInitialValue, TableSegmentElements, Trap, VMOffsets, WasmHeapType,
WasmRefType, WasmValType, VMCONTEXT_MAGIC,
};
#[cfg(feature = "wmemcheck")]
use wasmtime_wmemcheck::Wmemcheck;
Expand Down Expand Up @@ -804,50 +805,83 @@ impl Instance {
// disconnected from the lifetime of `self`.
let module = self.module().clone();

let empty = TableSegmentElements::Functions(Box::new([]));
let elements = match module.passive_elements_map.get(&elem_index) {
Some(index) if !self.dropped_elements.contains(elem_index) => {
module.passive_elements[*index].as_ref()
&module.passive_elements[*index]
}
_ => &[],
_ => &empty,
};
self.table_init_segment(table_index, elements, dst, src, len)
}

pub(crate) fn table_init_segment(
&mut self,
table_index: TableIndex,
elements: &[FuncIndex],
elements: &TableSegmentElements,
dst: u32,
src: u32,
len: u32,
) -> Result<(), Trap> {
// https://webassembly.github.io/bulk-memory-operations/core/exec/instructions.html#exec-table-init

let table = unsafe { &mut *self.get_table(table_index) };

let elements = match elements
.get(usize::try_from(src).unwrap()..)
.and_then(|s| s.get(..usize::try_from(len).unwrap()))
{
Some(elements) => elements,
None => return Err(Trap::TableOutOfBounds),
};

match table.element_type() {
TableElementType::Func => {
table.init_funcs(
let src = usize::try_from(src).map_err(|_| Trap::TableOutOfBounds)?;
let len = usize::try_from(len).map_err(|_| Trap::TableOutOfBounds)?;

match elements {
TableSegmentElements::Functions(funcs) => {
let elements = funcs
.get(src..)
.and_then(|s| s.get(..len))
.ok_or(Trap::TableOutOfBounds)?;
table.init(
dst,
elements
.iter()
.map(|idx| self.get_func_ref(*idx).unwrap_or(std::ptr::null_mut())),
elements.iter().map(|idx| {
TableElement::FuncRef(
self.get_func_ref(*idx).unwrap_or(std::ptr::null_mut()),
)
}),
)?;
}

TableElementType::Extern => {
debug_assert!(elements.iter().all(|e| *e == FuncIndex::reserved_value()));
table.fill(dst, TableElement::ExternRef(None), len)?;
TableSegmentElements::Expressions(exprs) => {
let ty = table.element_type();
let exprs = exprs
.get(src..)
.and_then(|s| s.get(..len))
.ok_or(Trap::TableOutOfBounds)?;
table.init(
dst,
exprs.iter().map(|expr| match ty {
TableElementType::Func => {
let funcref = match expr {
TableElementExpression::Null => std::ptr::null_mut(),
TableElementExpression::Function(idx) => {
self.get_func_ref(*idx).unwrap()
}
TableElementExpression::GlobalGet(idx) => {
let global = self.defined_or_imported_global_ptr(*idx);
unsafe { (*global).as_func_ref() }
}
};
TableElement::FuncRef(funcref)
}
TableElementType::Extern => {
let externref = match expr {
TableElementExpression::Null => None,
TableElementExpression::Function(_) => unreachable!(),
TableElementExpression::GlobalGet(idx) => {
let global = self.defined_or_imported_global_ptr(*idx);
unsafe { (*global).as_externref().clone() }
}
};
TableElement::ExternRef(externref)
}
}),
)?;
}
}

Ok(())
}

Expand Down Expand Up @@ -1060,7 +1094,9 @@ impl Instance {
let module = self.module();
let precomputed = match &module.table_initialization.initial_values[idx] {
TableInitialValue::Null { precomputed } => precomputed,
TableInitialValue::FuncRef(_) => unreachable!(),
TableInitialValue::FuncRef(_) | TableInitialValue::GlobalGet(_) => {
unreachable!()
}
};
let func_index = precomputed.get(i as usize).cloned();
let func_ref = func_index
Expand Down
11 changes: 9 additions & 2 deletions crates/runtime/src/instance/allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ fn check_table_init_bounds(instance: &mut Instance, module: &Module) -> Result<(
let table = unsafe { &*instance.get_table(segment.table_index) };
let start = get_table_init_start(segment, instance)?;
let start = usize::try_from(start).unwrap();
let end = start.checked_add(segment.elements.len());
let end = start.checked_add(usize::try_from(segment.elements.len()).unwrap());

match end {
Some(end) if end <= table.size() as usize => {
Expand All @@ -533,6 +533,13 @@ fn initialize_tables(instance: &mut Instance, module: &Module) -> Result<()> {
let table = unsafe { &mut *instance.get_defined_table(table) };
table.init_func(funcref)?;
}

TableInitialValue::GlobalGet(idx) => unsafe {
let global = instance.defined_or_imported_global_ptr(*idx);
let funcref = (*global).as_func_ref();
let table = &mut *instance.get_defined_table(table);
table.init_func(funcref)?;
},
}
}

Expand All @@ -550,7 +557,7 @@ fn initialize_tables(instance: &mut Instance, module: &Module) -> Result<()> {
&segment.elements,
start,
0,
segment.elements.len() as u32,
segment.elements.len(),
)?;
}

Expand Down

0 comments on commit 36fb62c

Please sign in to comment.