Skip to content

Commit 8a3b7f7

Browse files
committed
More precise return types
1 parent 0fa14b9 commit 8a3b7f7

File tree

8 files changed

+246
-53
lines changed

8 files changed

+246
-53
lines changed

compiler/lib-wasm/code_generation.ml

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,68 @@ let heap_type_sub (ty : W.heap_type) (ty' : W.heap_type) st =
193193
(* I31, struct, array and none have no other subtype *)
194194
| _, (I31 | Type _ | Struct | Array | None_) -> false, st
195195

196+
let rec type_index_lub ty ty' st =
197+
(* Find the LUB efficiently by taking advantage of the fact that
198+
types are defined after their supertypes, making their variables
199+
compare greater. *)
200+
let c = Var.compare ty ty' in
201+
if c > 0
202+
then type_index_lub ty' ty st
203+
else if c = 0
204+
then Some ty
205+
else
206+
let type_field = Var.Hashtbl.find st.context.types ty' in
207+
match type_field.supertype with
208+
| None -> None
209+
| Some ty'' ->
210+
assert (Var.compare ty'' ty' < 0);
211+
type_index_lub ty ty'' st
212+
213+
let heap_type_lub (ty : W.heap_type) (ty' : W.heap_type) =
214+
match ty, ty' with
215+
| (Func | Extern), _ | _, (Func | Extern) -> assert false
216+
| None_, _ -> return ty'
217+
| _, None_ | Struct, Struct | Array, Array -> return ty
218+
| Any, _ | _, Any -> return W.Any
219+
| Eq, _
220+
| _, Eq
221+
| (Struct | Array | Type _), I31
222+
| I31, (Struct | Array | Type _)
223+
| Struct, Array
224+
| Array, Struct -> return (Eq : W.heap_type)
225+
| Struct, Type t | Type t, Struct -> (
226+
fun st ->
227+
let type_field = Var.Hashtbl.find st.context.types t in
228+
match type_field.typ with
229+
| Struct _ -> W.Struct, st
230+
| Array _ | Func _ -> W.Eq, st)
231+
| Array, Type t | Type t, Array -> (
232+
fun st ->
233+
let type_field = Var.Hashtbl.find st.context.types t in
234+
match type_field.typ with
235+
| Array _ -> W.Struct, st
236+
| Struct _ | Func _ -> W.Eq, st)
237+
| Type t, Type t' -> (
238+
let* r = fun st -> type_index_lub t t' st, st in
239+
match r with
240+
| Some t'' -> return (Type t'' : W.heap_type)
241+
| None -> (
242+
fun st ->
243+
let type_field = Var.Hashtbl.find st.context.types t in
244+
let type_field' = Var.Hashtbl.find st.context.types t' in
245+
match type_field.typ, type_field'.typ with
246+
| Struct _, Struct _ -> (Struct : W.heap_type), st
247+
| Array _, Array _ -> W.Array, st
248+
| (Array _ | Struct _ | Func _), (Array _ | Struct _ | Func _) -> W.Eq, st))
249+
| I31, I31 -> return W.I31
250+
251+
let value_type_lub (ty : W.value_type) (ty' : W.value_type) =
252+
match ty, ty' with
253+
| Ref { nullable; typ }, Ref { nullable = nullable'; typ = typ' } ->
254+
let* typ = heap_type_lub typ typ' in
255+
return (W.Ref { nullable = nullable || nullable'; typ })
256+
| _ -> assert false
257+
196258
let register_global name ?exported_name ?(constant = false) typ init st =
197259
st.context.other_fields <-
198260
W.Global { name; exported_name; typ; init } :: st.context.other_fields;
@@ -705,7 +767,7 @@ let init_code context = instrs context.init_code
705767

706768
let function_body ~context ~param_names ~body =
707769
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
708-
let (), st = body st in
770+
let res, st = body st in
709771
let local_count, body = st.var_count, List.rev st.instrs in
710772
let local_types = Array.make local_count (Var.fresh (), None) in
711773
List.iteri ~f:(fun i x -> local_types.(i) <- x, None) param_names;
@@ -723,4 +785,10 @@ let function_body ~context ~param_names ~body =
723785
|> (fun a -> Array.sub a ~pos:param_count ~len:(Array.length a - param_count))
724786
|> Array.to_list
725787
in
726-
locals, body
788+
locals, res, body
789+
790+
let eval ~context e =
791+
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
792+
let r, st = e st in
793+
assert (st.var_count = 0 && List.is_empty st.instrs);
794+
r

compiler/lib-wasm/code_generation.mli

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ val register_type : string -> (unit -> type_def t) -> Wasm_ast.var t
155155

156156
val heap_type_sub : Wasm_ast.heap_type -> Wasm_ast.heap_type -> bool t
157157

158+
val value_type_lub : Wasm_ast.value_type -> Wasm_ast.value_type -> Wasm_ast.value_type t
159+
158160
val register_import :
159161
?import_module:string -> name:string -> Wasm_ast.import_desc -> Wasm_ast.var t
160162

@@ -195,8 +197,8 @@ val need_dummy_fun : cps:bool -> arity:int -> Code.Var.t t
195197
val function_body :
196198
context:context
197199
-> param_names:Code.Var.t list
198-
-> body:unit t
199-
-> (Wasm_ast.var * Wasm_ast.value_type) list * Wasm_ast.instruction list
200+
-> body:'a t
201+
-> (Wasm_ast.var * Wasm_ast.value_type) list * 'a * Wasm_ast.instruction list
200202

201203
val variable_type : Code.Var.t -> Wasm_ast.value_type option t
202204

@@ -207,3 +209,5 @@ val array_placeholder : Code.Var.t -> expression
207209
val default_value :
208210
Wasm_ast.value_type
209211
-> (Wasm_ast.expression * Wasm_ast.value_type * Wasm_ast.ref_type option) t
212+
213+
val eval : context:context -> 'a t -> 'a

compiler/lib-wasm/curry.ml

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ module Make (Target : Target_sig.S) = struct
9595
loop m [] f None
9696
in
9797
let param_names = args @ [ f ] in
98-
let locals, body = function_body ~context ~param_names ~body in
98+
let locals, _, body = function_body ~context ~param_names ~body in
9999
W.Function
100100
{ name
101101
; exported_name = None
102-
; typ = None
102+
; typ = Some (eval ~context (Type.function_type ~cps:false 1))
103103
; signature = Type.func_type 1
104104
; param_names
105105
; locals
@@ -130,11 +130,11 @@ module Make (Target : Target_sig.S) = struct
130130
push (Closure.curry_allocate ~cps:false ~arity m ~f:name' ~closure:f ~arg:x)
131131
in
132132
let param_names = [ x; f ] in
133-
let locals, body = function_body ~context ~param_names ~body in
133+
let locals, _, body = function_body ~context ~param_names ~body in
134134
W.Function
135135
{ name
136136
; exported_name = None
137-
; typ = None
137+
; typ = Some (eval ~context (Type.function_type ~cps:false 1))
138138
; signature = Type.func_type 1
139139
; param_names
140140
; locals
@@ -181,11 +181,11 @@ module Make (Target : Target_sig.S) = struct
181181
loop m [] f None
182182
in
183183
let param_names = args @ [ f ] in
184-
let locals, body = function_body ~context ~param_names ~body in
184+
let locals, _, body = function_body ~context ~param_names ~body in
185185
W.Function
186186
{ name
187187
; exported_name = None
188-
; typ = None
188+
; typ = Some (eval ~context (Type.function_type ~cps:true 1))
189189
; signature = Type.func_type 2
190190
; param_names
191191
; locals
@@ -220,11 +220,11 @@ module Make (Target : Target_sig.S) = struct
220220
instr (W.Return (Some c))
221221
in
222222
let param_names = [ x; cont; f ] in
223-
let locals, body = function_body ~context ~param_names ~body in
223+
let locals, _, body = function_body ~context ~param_names ~body in
224224
W.Function
225225
{ name
226226
; exported_name = None
227-
; typ = None
227+
; typ = Some (eval ~context (Type.function_type ~cps:true 1))
228228
; signature = Type.func_type 2
229229
; param_names
230230
; locals
@@ -264,7 +264,7 @@ module Make (Target : Target_sig.S) = struct
264264
build_applies (load f) l)
265265
in
266266
let param_names = l @ [ f ] in
267-
let locals, body = function_body ~context ~param_names ~body in
267+
let locals, _, body = function_body ~context ~param_names ~body in
268268
W.Function
269269
{ name
270270
; exported_name = None
@@ -312,7 +312,7 @@ module Make (Target : Target_sig.S) = struct
312312
push (call ~cps:true ~arity:2 (load f) [ x; iterate ]))
313313
in
314314
let param_names = l @ [ f ] in
315-
let locals, body = function_body ~context ~param_names ~body in
315+
let locals, _, body = function_body ~context ~param_names ~body in
316316
W.Function
317317
{ name
318318
; exported_name = None
@@ -347,11 +347,13 @@ module Make (Target : Target_sig.S) = struct
347347
instr (W.Return (Some e))
348348
in
349349
let param_names = l @ [ f ] in
350-
let locals, body = function_body ~context ~param_names ~body in
350+
let locals, _, body = function_body ~context ~param_names ~body in
351351
W.Function
352352
{ name
353353
; exported_name = None
354-
; typ = None
354+
; typ =
355+
Some
356+
(eval ~context (Type.function_type ~cps (if cps then arity - 1 else arity)))
355357
; signature = Type.func_type arity
356358
; param_names
357359
; locals

compiler/lib-wasm/gc_target.ml

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,35 @@ module Type = struct
205205
let primitive_type n =
206206
{ W.params = List.init ~len:n ~f:(fun _ -> value); result = [ value ] }
207207

208-
let func_type n = primitive_type (n + 1)
209-
210-
let function_type ~cps n =
211-
let n = if cps then n + 1 else n in
212-
register_type (Printf.sprintf "function_%d" n) (fun () ->
213-
return { supertype = None; final = true; typ = W.Func (func_type n) })
208+
let func_type ?(ret = value) n =
209+
{ W.params = List.init ~len:(n + 1) ~f:(fun _ -> value); result = [ ret ] }
210+
211+
let rec function_type ~cps ?ret n =
212+
let n' = if cps then n + 1 else n in
213+
let ret_str =
214+
match ret with
215+
| None -> ""
216+
| Some (W.Ref { nullable = false; typ }) -> (
217+
match typ with
218+
| Eq -> "_eq" (*ZZZ remove ret in that case*)
219+
| I31 -> "_i31"
220+
| Struct -> "_struct"
221+
| Array -> "_array"
222+
| None_ -> "_none"
223+
| Type v -> (
224+
match Code.Var.get_name v with
225+
| None -> assert false
226+
| Some name -> "_" ^ name)
227+
| _ -> assert false)
228+
| _ -> assert false
229+
in
230+
register_type (Printf.sprintf "function_%d%s" n' ret_str) (fun () ->
231+
match ret with
232+
| None -> return { supertype = None; final = false; typ = W.Func (func_type n') }
233+
| Some ret ->
234+
let* super = function_type ~cps n in
235+
return
236+
{ supertype = Some super; final = false; typ = W.Func (func_type ~ret n') })
214237

215238
let closure_common_fields ~cps =
216239
let* fun_ty = function_type ~cps 1 in

0 commit comments

Comments
 (0)