Enable recursive calls: calling exported function from an imported.

extensions
Volodymyr Shymanskyy 3 years ago
parent ca7af03efe
commit ccab6a438b

@ -4,12 +4,11 @@ import wasm3
import time, timeit import time, timeit
# WebAssembly binary # WebAssembly binary
WASM = bytes.fromhex(""" WASM = bytes.fromhex(
00 61 73 6d 01 00 00 00 01 06 01 60 01 7e 01 7e "00 61 73 6d 01 00 00 00 01 06 01 60 01 7e 01 7e"
03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01 "03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01"
1d 00 20 00 42 02 54 04 40 20 00 0f 0b 20 00 42 "1d 00 20 00 42 02 54 04 40 20 00 0f 0b 20 00 42"
02 7d 10 00 20 00 42 01 7d 10 00 7c 0f 0b "02 7d 10 00 20 00 42 01 7d 10 00 7c 0f 0b")
""")
(N, RES, CYCLES) = (24, 46368, 1000) (N, RES, CYCLES) = (24, 46368, 1000)

@ -205,8 +205,8 @@ m3ApiRawFunction(CallImport)
{ {
PyObject *pFunc = (PyObject *)(_ctx->userdata); PyObject *pFunc = (PyObject *)(_ctx->userdata);
IM3Function f = _ctx->function; IM3Function f = _ctx->function;
int nArgs = m3_GetArgCount(_ctx->function); int nArgs = m3_GetArgCount(f);
int nRets = m3_GetRetCount(_ctx->function); int nRets = m3_GetRetCount(f);
PyObject *pArgs = PyTuple_New(nArgs); PyObject *pArgs = PyTuple_New(nArgs);
if (!pArgs) { if (!pArgs) {
m3ApiTrap("python call: args not allocated"); m3ApiTrap("python call: args not allocated");

@ -2,25 +2,90 @@ import wasm3 as m3
import pytest import pytest
FIB32_WASM = bytes.fromhex( FIB32_WASM = bytes.fromhex(
"00 61 73 6d 01 00 00 00 01 06 01 60 01 7f 01 7f" "00 61 73 6d 01 00 00 00 01 06 01 60 01 7f 01 7f"
"03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01" "03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01"
"1d 00 20 00 41 02 49 04 40 20 00 0f 0b 20 00 41" "1d 00 20 00 41 02 49 04 40 20 00 0f 0b 20 00 41"
"02 6b 10 00 20 00 41 01 6b 10 00 6a 0f 0b") "02 6b 10 00 20 00 41 01 6b 10 00 6a 0f 0b")
FIB64_WASM = bytes.fromhex( FIB64_WASM = bytes.fromhex(
"00 61 73 6d 01 00 00 00 01 06 01 60 01 7e 01 7e" "00 61 73 6d 01 00 00 00 01 06 01 60 01 7e 01 7e"
"03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01" "03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01"
"1d 00 20 00 42 02 54 04 40 20 00 0f 0b 20 00 42" "1d 00 20 00 42 02 54 04 40 20 00 0f 0b 20 00 42"
"02 7d 10 00 20 00 42 01 7d 10 00 7c 0f 0b") "02 7d 10 00 20 00 42 01 7d 10 00 7c 0f 0b")
# (module """
# (func (param i64 i64) (result i64) (type (;0;) (func (param i32 i32) (result i32)))
# local.get 0 (func $i (import "env" "callback") (type 0))
# local.get 1 (func (export "run_callback") (type 0)
# i64.add local.get 0
# return local.get 1
# ) call $i)
# (export "add" (func 0))) """
CALLBACK_WASM = bytes.fromhex(
"00 61 73 6d 01 00 00 00 01 07 01 60 02 7f 7f 01"
"7f 02 10 01 03 65 6e 76 08 63 61 6c 6c 62 61 63"
"6b 00 00 03 02 01 00 07 10 01 0c 72 75 6e 5f 63"
"61 6c 6c 62 61 63 6b 00 01 0a 0a 01 08 00 20 00"
"20 01 10 00 0b")
"""
(module
(type $t0 (func (param i32 i32) (result i32)))
(type $t1 (func))
(type $t2 (func (param i32)))
(type $t3 (func (param i32 i32 i32) (result i32)))
(import "env" "pass_fptr" (func $env.pass_fptr (type $t2)))
(import "env" "__table_base" (global $env.__table_base i32))
(func $run_test (export "run_test") (type $t1)
global.get $env.__table_base
call $env.pass_fptr
global.get $env.__table_base
i32.const 1
i32.add
call $env.pass_fptr)
(func $f2 (type $t0) (param $p0 i32) (param $p1 i32) (result i32)
local.get $p0
local.get $p1
i32.add)
(func $f3 (type $t0) (param $p0 i32) (param $p1 i32) (result i32)
local.get $p0
local.get $p1
i32.mul)
(func $test (export "call_pass_fptr") (type $t2) (param $p0 i32)
local.get $p0
call $env.pass_fptr
)
(func $dynCall_iii (export "dynCall_iii") (type $t3) (param $p0 i32) (param $p1 i32) (param $p2 i32) (result i32)
local.get $p1
local.get $p2
local.get $p0
call_indirect $table (type $t0))
(table $table (export "table") 2 funcref)
(elem (global.get $env.__table_base) func $f2 $f3))
"""
DYN_CALLBACK_WASM = bytes.fromhex(
"00 61 73 6d 01 00 00 00 01 15 04 60 02 7f 7f 01"
"7f 60 00 00 60 01 7f 00 60 03 7f 7f 7f 01 7f 02"
"25 02 03 65 6e 76 09 70 61 73 73 5f 66 70 74 72"
"00 02 03 65 6e 76 0c 5f 5f 74 61 62 6c 65 5f 62"
"61 73 65 03 7f 00 03 06 05 01 00 00 02 03 04 04"
"01 70 00 02 07 33 04 08 72 75 6e 5f 74 65 73 74"
"00 01 0e 63 61 6c 6c 5f 70 61 73 73 5f 66 70 74"
"72 00 04 0b 64 79 6e 43 61 6c 6c 5f 69 69 69 00"
"05 05 74 61 62 6c 65 01 00 09 08 01 00 23 00 0b"
"02 02 03 0a 32 05 0d 00 23 00 10 00 23 00 41 01"
"6a 10 00 0b 07 00 20 00 20 01 6a 0b 07 00 20 00"
"20 01 6c 0b 06 00 20 00 10 00 0b 0b 00 20 01 20"
"02 20 00 11 00 00 0b")
"""
(func (export "add") (param i64 i64) (result i64)
local.get 0
local.get 1
i64.add
return
)
"""
ADD_WASM = bytes.fromhex( ADD_WASM = bytes.fromhex(
"00 61 73 6d 01 00 00 00 01 07 01 60 02 7e 7e 01" "00 61 73 6d 01 00 00 00 01 07 01 60 02 7e 7e 01"
"7e 03 02 01 00 07 07 01 03 61 64 64 00 00 0a 0a" "7e 03 02 01 00 07 07 01 03 61 64 64 00 00 0a 0a"
@ -32,6 +97,71 @@ def test_classes():
assert isinstance(m3.Module, type) assert isinstance(m3.Module, type)
assert isinstance(m3.Function, type) assert isinstance(m3.Function, type)
def test_callback():
env = m3.Environment()
rt = env.new_runtime(1024)
mod = env.parse_module(CALLBACK_WASM)
rt.load(mod)
mem = rt.get_memory(0)
def func(x, y):
assert x == 123
assert y == 456
return x*y
mod.link_function("env", "callback", "i(ii)", func)
run_callback = rt.find_function("run_callback")
assert run_callback(123, 456) == 123*456
def test_callback_member():
class WasmRunner:
def __init__(self, wasm):
self.env = m3.Environment()
self.rt = self.env.new_runtime(1024)
self.mod = self.env.parse_module(wasm)
self.rt.load(self.mod)
self.mem = self.rt.get_memory(0)
self.mod.link_function("env", "callback", "i(ii)", self.func)
self.run_callback = self.rt.find_function("run_callback")
def func(self, x, y):
assert x == 987
assert y == 654
return x+y
inst = WasmRunner(CALLBACK_WASM)
assert inst.run_callback(987, 654) == 987+654
def test_dynamic_callback():
env = m3.Environment()
rt = env.new_runtime(1024)
mod = env.parse_module(DYN_CALLBACK_WASM)
rt.load(mod)
dynCall_iii = rt.find_function("dynCall_iii")
def pass_fptr(fptr):
if fptr == 0:
assert dynCall_iii(fptr, 12, 34) == 46
elif fptr == 1:
# TODO: call by table index directly here
assert dynCall_iii(fptr, 12, 34) == 408
else:
raise Exception("Strange function ptr")
mod.link_function("env", "pass_fptr", "v(i)", pass_fptr)
# Indirect calls
assert dynCall_iii(0, 12, 34) == 46
assert dynCall_iii(1, 12, 34) == 408
# Recursive exported function call (single calls)
call_pass_fptr = rt.find_function("call_pass_fptr")
base = 0
call_pass_fptr(base+0)
call_pass_fptr(base+1)
# Recursive exported function call (multiple calls)
rt.find_function("run_test")()
def test_m3(capfd): def test_m3(capfd):
env = m3.Environment() env = m3.Environment()
rt = env.new_runtime(1024) rt = env.new_runtime(1024)

@ -220,7 +220,7 @@ typedef struct M3Runtime
void * stack; void * stack;
u32 stackSize; u32 stackSize;
u32 numStackSlots; u32 numStackSlots;
IM3Function lastCalled; // last function that successfully executed IM3Function lastCalled; // last function that successfully executed
void * userdata; void * userdata;

@ -545,6 +545,8 @@ d_m3Op (CallRawFunction)
ctx.userdata = immediate (void *); ctx.userdata = immediate (void *);
u64* const sp = ((u64*)_sp); u64* const sp = ((u64*)_sp);
IM3Runtime runtime = m3MemRuntime(_mem);
#if d_m3EnableStrace #if d_m3EnableStrace
IM3FuncType ftype = ctx.function->funcType; IM3FuncType ftype = ctx.function->funcType;
@ -570,7 +572,13 @@ d_m3Op (CallRawFunction)
} }
#endif #endif
m3ret_t possible_trap = call (m3MemRuntime(_mem), &ctx, sp, m3MemData(_mem)); // m3_Call uses runtime->stack to set-up initial exported function stack.
// Reconfigure the stack to enable recursive invocations of m3_Call.
// I.e. exported/table function can be called from an impoted function.
void* stack_backup = runtime->stack;
runtime->stack = sp;
m3ret_t possible_trap = call (runtime, &ctx, sp, m3MemData(_mem));
runtime->stack = stack_backup;
#if d_m3EnableStrace #if d_m3EnableStrace
if (possible_trap) { if (possible_trap) {

Loading…
Cancel
Save