diff --git a/platforms/python/examples/simple-benchmark.py b/platforms/python/examples/simple-benchmark.py index 5246473..e0b3ceb 100755 --- a/platforms/python/examples/simple-benchmark.py +++ b/platforms/python/examples/simple-benchmark.py @@ -4,12 +4,11 @@ import wasm3 import time, timeit # WebAssembly binary -WASM = bytes.fromhex(""" - 00 61 73 6d 01 00 00 00 01 06 01 60 01 7e 01 7e - 03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01 - 1d 00 20 00 42 02 54 04 40 20 00 0f 0b 20 00 42 - 02 7d 10 00 20 00 42 01 7d 10 00 7c 0f 0b -""") +WASM = bytes.fromhex( + "00 61 73 6d 01 00 00 00 01 06 01 60 01 7e 01 7e" + "03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01" + "1d 00 20 00 42 02 54 04 40 20 00 0f 0b 20 00 42" + "02 7d 10 00 20 00 42 01 7d 10 00 7c 0f 0b") (N, RES, CYCLES) = (24, 46368, 1000) diff --git a/platforms/python/m3module.c b/platforms/python/m3module.c index 9f0425f..ab071d8 100644 --- a/platforms/python/m3module.c +++ b/platforms/python/m3module.c @@ -205,8 +205,8 @@ m3ApiRawFunction(CallImport) { PyObject *pFunc = (PyObject *)(_ctx->userdata); IM3Function f = _ctx->function; - int nArgs = m3_GetArgCount(_ctx->function); - int nRets = m3_GetRetCount(_ctx->function); + int nArgs = m3_GetArgCount(f); + int nRets = m3_GetRetCount(f); PyObject *pArgs = PyTuple_New(nArgs); if (!pArgs) { m3ApiTrap("python call: args not allocated"); diff --git a/platforms/python/test/test_m3.py b/platforms/python/test/test_m3.py index b4bca2a..95afaa7 100644 --- a/platforms/python/test/test_m3.py +++ b/platforms/python/test/test_m3.py @@ -2,25 +2,90 @@ import wasm3 as m3 import pytest FIB32_WASM = bytes.fromhex( - "00 61 73 6d 01 00 00 00 01 06 01 60 01 7f 01 7f" - "03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01" - "1d 00 20 00 41 02 49 04 40 20 00 0f 0b 20 00 41" - "02 6b 10 00 20 00 41 01 6b 10 00 6a 0f 0b") + "00 61 73 6d 01 00 00 00 01 06 01 60 01 7f 01 7f" + "03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01" + "1d 00 20 00 41 02 49 04 40 20 00 0f 0b 20 00 41" + "02 6b 10 00 20 00 41 01 6b 10 00 6a 0f 0b") FIB64_WASM = bytes.fromhex( - "00 61 73 6d 01 00 00 00 01 06 01 60 01 7e 01 7e" - "03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01" - "1d 00 20 00 42 02 54 04 40 20 00 0f 0b 20 00 42" - "02 7d 10 00 20 00 42 01 7d 10 00 7c 0f 0b") - -# (module -# (func (param i64 i64) (result i64) -# local.get 0 -# local.get 1 -# i64.add -# return -# ) -# (export "add" (func 0))) + "00 61 73 6d 01 00 00 00 01 06 01 60 01 7e 01 7e" + "03 02 01 00 07 07 01 03 66 69 62 00 00 0a 1f 01" + "1d 00 20 00 42 02 54 04 40 20 00 0f 0b 20 00 42" + "02 7d 10 00 20 00 42 01 7d 10 00 7c 0f 0b") + +""" +(type (;0;) (func (param i32 i32) (result i32))) +(func $i (import "env" "callback") (type 0)) +(func (export "run_callback") (type 0) + local.get 0 + local.get 1 + call $i) +""" +CALLBACK_WASM = bytes.fromhex( + "00 61 73 6d 01 00 00 00 01 07 01 60 02 7f 7f 01" + "7f 02 10 01 03 65 6e 76 08 63 61 6c 6c 62 61 63" + "6b 00 00 03 02 01 00 07 10 01 0c 72 75 6e 5f 63" + "61 6c 6c 62 61 63 6b 00 01 0a 0a 01 08 00 20 00" + "20 01 10 00 0b") + +""" +(module + (type $t0 (func (param i32 i32) (result i32))) + (type $t1 (func)) + (type $t2 (func (param i32))) + (type $t3 (func (param i32 i32 i32) (result i32))) + (import "env" "pass_fptr" (func $env.pass_fptr (type $t2))) + (import "env" "__table_base" (global $env.__table_base i32)) + (func $run_test (export "run_test") (type $t1) + global.get $env.__table_base + call $env.pass_fptr + global.get $env.__table_base + i32.const 1 + i32.add + call $env.pass_fptr) + (func $f2 (type $t0) (param $p0 i32) (param $p1 i32) (result i32) + local.get $p0 + local.get $p1 + i32.add) + (func $f3 (type $t0) (param $p0 i32) (param $p1 i32) (result i32) + local.get $p0 + local.get $p1 + i32.mul) + (func $test (export "call_pass_fptr") (type $t2) (param $p0 i32) + local.get $p0 + call $env.pass_fptr + ) + (func $dynCall_iii (export "dynCall_iii") (type $t3) (param $p0 i32) (param $p1 i32) (param $p2 i32) (result i32) + local.get $p1 + local.get $p2 + local.get $p0 + call_indirect $table (type $t0)) + (table $table (export "table") 2 funcref) + (elem (global.get $env.__table_base) func $f2 $f3)) +""" +DYN_CALLBACK_WASM = bytes.fromhex( + "00 61 73 6d 01 00 00 00 01 15 04 60 02 7f 7f 01" + "7f 60 00 00 60 01 7f 00 60 03 7f 7f 7f 01 7f 02" + "25 02 03 65 6e 76 09 70 61 73 73 5f 66 70 74 72" + "00 02 03 65 6e 76 0c 5f 5f 74 61 62 6c 65 5f 62" + "61 73 65 03 7f 00 03 06 05 01 00 00 02 03 04 04" + "01 70 00 02 07 33 04 08 72 75 6e 5f 74 65 73 74" + "00 01 0e 63 61 6c 6c 5f 70 61 73 73 5f 66 70 74" + "72 00 04 0b 64 79 6e 43 61 6c 6c 5f 69 69 69 00" + "05 05 74 61 62 6c 65 01 00 09 08 01 00 23 00 0b" + "02 02 03 0a 32 05 0d 00 23 00 10 00 23 00 41 01" + "6a 10 00 0b 07 00 20 00 20 01 6a 0b 07 00 20 00" + "20 01 6c 0b 06 00 20 00 10 00 0b 0b 00 20 01 20" + "02 20 00 11 00 00 0b") + +""" +(func (export "add") (param i64 i64) (result i64) + local.get 0 + local.get 1 + i64.add + return +) +""" ADD_WASM = bytes.fromhex( "00 61 73 6d 01 00 00 00 01 07 01 60 02 7e 7e 01" "7e 03 02 01 00 07 07 01 03 61 64 64 00 00 0a 0a" @@ -32,6 +97,71 @@ def test_classes(): assert isinstance(m3.Module, type) assert isinstance(m3.Function, type) +def test_callback(): + env = m3.Environment() + rt = env.new_runtime(1024) + mod = env.parse_module(CALLBACK_WASM) + rt.load(mod) + mem = rt.get_memory(0) + + def func(x, y): + assert x == 123 + assert y == 456 + return x*y + mod.link_function("env", "callback", "i(ii)", func) + run_callback = rt.find_function("run_callback") + assert run_callback(123, 456) == 123*456 + +def test_callback_member(): + class WasmRunner: + def __init__(self, wasm): + self.env = m3.Environment() + self.rt = self.env.new_runtime(1024) + self.mod = self.env.parse_module(wasm) + self.rt.load(self.mod) + self.mem = self.rt.get_memory(0) + self.mod.link_function("env", "callback", "i(ii)", self.func) + self.run_callback = self.rt.find_function("run_callback") + + def func(self, x, y): + assert x == 987 + assert y == 654 + return x+y + + inst = WasmRunner(CALLBACK_WASM) + assert inst.run_callback(987, 654) == 987+654 + +def test_dynamic_callback(): + env = m3.Environment() + rt = env.new_runtime(1024) + mod = env.parse_module(DYN_CALLBACK_WASM) + rt.load(mod) + dynCall_iii = rt.find_function("dynCall_iii") + + def pass_fptr(fptr): + if fptr == 0: + assert dynCall_iii(fptr, 12, 34) == 46 + elif fptr == 1: + # TODO: call by table index directly here + assert dynCall_iii(fptr, 12, 34) == 408 + else: + raise Exception("Strange function ptr") + + mod.link_function("env", "pass_fptr", "v(i)", pass_fptr) + + # Indirect calls + assert dynCall_iii(0, 12, 34) == 46 + assert dynCall_iii(1, 12, 34) == 408 + + # Recursive exported function call (single calls) + call_pass_fptr = rt.find_function("call_pass_fptr") + base = 0 + call_pass_fptr(base+0) + call_pass_fptr(base+1) + + # Recursive exported function call (multiple calls) + rt.find_function("run_test")() + def test_m3(capfd): env = m3.Environment() rt = env.new_runtime(1024) diff --git a/source/m3_env.h b/source/m3_env.h index feeb9a4..acc56f8 100644 --- a/source/m3_env.h +++ b/source/m3_env.h @@ -220,7 +220,7 @@ typedef struct M3Runtime void * stack; u32 stackSize; u32 numStackSlots; - IM3Function lastCalled; // last function that successfully executed + IM3Function lastCalled; // last function that successfully executed void * userdata; diff --git a/source/m3_exec.h b/source/m3_exec.h index fcc4050..177477d 100644 --- a/source/m3_exec.h +++ b/source/m3_exec.h @@ -545,6 +545,8 @@ d_m3Op (CallRawFunction) ctx.userdata = immediate (void *); u64* const sp = ((u64*)_sp); + IM3Runtime runtime = m3MemRuntime(_mem); + #if d_m3EnableStrace IM3FuncType ftype = ctx.function->funcType; @@ -570,7 +572,13 @@ d_m3Op (CallRawFunction) } #endif - m3ret_t possible_trap = call (m3MemRuntime(_mem), &ctx, sp, m3MemData(_mem)); + // m3_Call uses runtime->stack to set-up initial exported function stack. + // Reconfigure the stack to enable recursive invocations of m3_Call. + // I.e. exported/table function can be called from an impoted function. + void* stack_backup = runtime->stack; + runtime->stack = sp; + m3ret_t possible_trap = call (runtime, &ctx, sp, m3MemData(_mem)); + runtime->stack = stack_backup; #if d_m3EnableStrace if (possible_trap) {