Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

Commit 4dc634d

Browse files
slyubomirskyjunrushao
authored andcommitted
[Bugfix][VM] Properly convert tensor inputs in save_function (#257)
It was observed that closures saved using `save_function` would crash when used over RPC with the `time_evaluator`, whereas using `set_input` and `invoke_stateful` worked as normal. While I am not entirely sure why these failures happened over RPC only in `time_evaluator` (but not in other RPC trials), it became clear that `set_input` performs a conversion of input tensor values in `SetInputTensorWithIndex`, while `save_function` was not doing this. Adding this conversion fixed the observed bug.
1 parent 39448c3 commit 4dc634d

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/runtime/relax_vm/vm.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
127127
if (args.size() > 3) {
128128
inputs = std::vector<RegType>(args.size() - 3);
129129
for (int i = 3; i < args.size(); i++) {
130-
inputs[i - 3] = args[i];
130+
SetInputTensorWithIndex(inputs, args[i], i - 3, devices[0]);
131131
}
132132
}
133133
if (include_return) {

tests/python/relax/test_vm.py

+18
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,24 @@ def test_save_function_kwargs_rpc():
12081208
run_on_rpc(TestVMSetInput, save_function_kwargs_trial)
12091209

12101210

1211+
def save_function_time_evaluator_trial(
1212+
vm: relax.VirtualMachine, device: tvm.runtime.Device
1213+
) -> None:
1214+
# just checking that the saved function can be called in the time evaluator
1215+
a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
1216+
b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
1217+
vm.save_function("main", "saved_main", a, b)
1218+
vm.time_evaluator("saved_main", device)()
1219+
1220+
1221+
def test_save_function_time_evaluator():
1222+
save_function_time_evaluator_trial(*make_vm(TestVMSetInput))
1223+
1224+
1225+
def test_save_function_time_evaluator():
1226+
run_on_rpc(TestVMSetInput, save_function_time_evaluator_trial)
1227+
1228+
12111229
# if you set an input, you should not be able to call statelessly
12121230
@pytest.mark.xfail()
12131231
def test_set_input_stateless_failure():

0 commit comments

Comments
 (0)