Skip to content
35 changes: 28 additions & 7 deletions bin/lwt_to_direct_style/ast_rewrite.ml
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,22 @@ let lwt_io_mode_of_ast ~state mode =
| _ -> None)
| _ -> None

let lwt_io_of_fd ~backend ~state ~mode fd =
let lwt_io_open ~backend ~state ~mode src =
match lwt_io_mode_of_ast ~state mode with
| Some `Input -> Some (backend#input_io_of_fd fd)
| Some `Output -> Some (backend#output_io_of_fd fd)
| Some `Input -> Some (backend#input_io src)
| Some `Output -> Some (backend#output_io src)
| None ->
add_comment state
"Couldn't translate this call to [Lwt_io.of_fd] because the [~mode] \
argument couldn't be decoded. Directly use [Lwt_io.input] or \
[Lwt_io.output].";
None

let lwt_io_read ~backend ~state:_ count in_chan =
match count with
| Some count_arg -> backend#io_read_string_count in_chan count_arg
| None -> Some (backend#io_read_all in_chan)

let mk_cstr c = Some (mk_constr_exp [ c ])

(* Rewrite calls to functions from the [Lwt] module. See [rewrite_apply] for
Expand Down Expand Up @@ -322,10 +327,9 @@ let rewrite_apply ~backend ~state full_ident args =
take @@ fun d ->
take @@ fun f -> return (Some (backend#with_timeout d f))
| "Lwt_unix", "of_unix_file_descr" ->
take @@ fun fd ->
take_lblopt "blocking" @@ fun blocking ->
ignore_lblarg "set_flags"
@@ return (Some (backend#of_unix_file_descr ?blocking fd))
ignore_lblarg "set_flags" @@ take
@@ fun fd -> return (Some (backend#of_unix_file_descr ?blocking fd))
| "Lwt_unix", "close" -> take @@ fun fd -> return (Some (backend#fd_close fd))
(* [Lwt_unix] contains functions exactly equivalent to functions of the same
name in [Unix]. *)
Expand All @@ -334,6 +338,10 @@ let rewrite_apply ~backend ~state full_ident args =
"This call to [Unix.%s] was [Lwt_unix.%s] before the rewrite." fname
fname;
transparent [ "Unix"; fname ]
| "Lwt_unix", "stat" ->
take @@ fun path -> return (Some (backend#path_stat ~follow:true path))
| "Lwt_unix", "lstat" ->
take @@ fun path -> return (Some (backend#path_stat ~follow:false path))
| "Lwt_condition", "create" ->
take @@ fun _unit -> return (Some (backend#condition_create ()))
| "Lwt_condition", "wait" ->
Expand All @@ -359,10 +367,23 @@ let rewrite_apply ~backend ~state full_ident args =
@@ ignore_lblarg ~cmt:"Will behave as if it was [true]." "close"
@@ take_lbl "mode"
@@ fun mode ->
take @@ fun fd -> return (lwt_io_of_fd ~backend ~state ~mode fd)
take @@ fun fd -> return (lwt_io_open ~backend ~state ~mode (`Of_fd fd))
| "Lwt_io", "open_file" ->
ignore_lblarg "buffer" @@ ignore_lblarg "flags" @@ ignore_lblarg "perm"
@@ take_lbl "mode"
@@ fun mode ->
take @@ fun fname ->
return (lwt_io_open ~backend ~state ~mode (`Fname fname))
| "Lwt_io", "read_line" ->
take @@ fun in_chan -> return (Some (backend#io_read_line in_chan))
| "Lwt_io", "read" ->
take_lblopt "count" @@ fun count ->
take @@ fun in_chan -> return (lwt_io_read ~backend ~state count in_chan)
| "Lwt_io", "write" ->
take @@ fun chan ->
take @@ fun str -> return (Some (backend#io_write_str chan str))
| "Lwt_io", "length" -> take @@ fun fd -> return (Some (backend#io_length fd))
| "Lwt_io", "close" -> take @@ fun fd -> return (Some (backend#io_close fd))
| "Lwt_main", "run" ->
take @@ fun promise -> return (Some (backend#main_run promise))
| _ -> return None
Expand Down
160 changes: 122 additions & 38 deletions bin/lwt_to_direct_style/concurrency_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ let eio ~eio_sw_as_fiber_var ~eio_env_as_fiber_var add_comment =
in
let fiber_ident = eio_std_ident "Fiber"
and promise_ident = eio_std_ident "Promise"
and switch_ident = eio_std_ident "Switch" in
and switch_ident = eio_std_ident "Switch"
and std_ident i =
used_eio_std := true;
[ i ]
in
let add_comment fmt = Format.kasprintf add_comment fmt in
let add_comment_dropped_exp ~label exp =
add_comment "Dropped expression (%s): [%s]." label
Expand Down Expand Up @@ -47,6 +51,39 @@ let eio ~eio_sw_as_fiber_var ~eio_env_as_fiber_var add_comment =
in
Exp.send env_exp (mk_loc field)
in
let buf_read_of_flow flow =
mk_apply_ident
[ "Eio"; "Buf_read"; "of_flow" ]
[
(Labelled (mk_loc "max_size"), mk_const_int "1_000_000"); (Nolabel, flow);
]
in
let buf_write_of_flow flow =
add_comment
"Write operations to buffered IO should be moved inside [with_flow].";
mk_apply_simple
[ "Eio"; "Buf_write"; "with_flow" ]
[
flow;
mk_fun ~arg_name:"outbuf" (fun _outbuf ->
mk_variant_exp "Move_writing_code_here");
]
in
let import_socket_stream ~r_or_w fd =
(* Used by [input_io] and [output_io]. *)
Exp.constraint_
(mk_apply_ident
[ "Eio_unix"; "Net"; "import_socket_stream" ]
[
get_current_switch_arg ();
(Labelled (mk_loc "close_unix"), mk_constr_exp [ "true" ]);
(Nolabel, fd);
])
(mk_typ_constr
~params:
[ mk_poly_variant [ (r_or_w, []); ("Flow", []); ("Close", []) ] ]
(std_ident "r"))
in
object
method both ~left ~right =
mk_apply_simple (fiber_ident "pair") [ left; right ]
Expand Down Expand Up @@ -150,31 +187,51 @@ let eio ~eio_sw_as_fiber_var ~eio_env_as_fiber_var add_comment =

method direct_style_type param = param

method of_unix_file_descr ?blocking fd =
let blocking_arg =
let lbl = mk_loc "blocking" in
match blocking with
| Some (expr, `Lbl) -> [ (Labelled lbl, expr) ]
| Some (expr, `Opt) -> [ (Optional lbl, expr) ]
| None -> []
in
mk_apply_ident
[ "Eio_unix"; "Fd"; "of_unix" ]
([ get_current_switch_arg () ]
@ blocking_arg
@ [
(Labelled (mk_loc "close_unix"), mk_constr_exp [ "true" ]);
(Nolabel, fd);
])
method of_unix_file_descr ?blocking:_ fd =
(* TODO: We don't use [Eio_unix.Fd.t] because there is no conversion to [Flow.sink]. *)
(* let blocking_arg = *)
(* let lbl = mk_loc "blocking" in *)
(* match blocking with *)
(* | Some (expr, `Lbl) -> [ (Labelled lbl, expr) ] *)
(* | Some (expr, `Opt) -> [ (Optional lbl, expr) ] *)
(* | None -> [] *)
(* in *)
(* mk_apply_ident *)
(* [ "Eio_unix"; "Fd"; "of_unix" ] *)
(* ([ get_current_switch_arg () ] *)
(* @ blocking_arg *)
(* @ [ *)
(* (Labelled (mk_loc "close_unix"), mk_constr_exp [ "true" ]); *)
(* (Nolabel, fd); *)
(* ]) *)
fd

method io_read input buffer buf_offset buf_len =
add_comment "[%s] should be a [Cstruct.t]."
(Ocamlformat_utils.format_expression buffer);
add_comment
"[Eio.Flow.single_read] operates on a [Flow.source] but [%s] is likely \
of type [Eio.Buf_read.t]. Rewrite this code to use [Buf_read] (which \
contains an internal buffer) or change the call to \
[Eio.Buf_read.of_flow] used to create the buffer."
(Ocamlformat_utils.format_expression input);
add_comment_dropped_exp ~label:"buffer offset" buf_offset;
add_comment_dropped_exp ~label:"buffer length" buf_len;
mk_apply_simple [ "Eio"; "Flow"; "single_read" ] [ input; buffer ]

method fd_close fd = mk_apply_simple [ "Eio_unix"; "Fd" ] [ fd ]
method io_read_all input =
mk_apply_simple [ "Eio"; "Buf_read"; "take_all" ] [ input ]

method io_read_string_count _input _count_arg =
add_comment
"Eio doesn't have a direct equivalent of [Lwt_io.read ~count]. Rewrite \
the code using [Eio.Buf_read]'s lower level API or switch to \
unbuffered IO.";
None

method fd_close fd =
(* TODO: See [of_unix_file_descr]. mk_apply_simple [ "Eio_unix"; "Fd" ] [ fd ] *)
mk_apply_simple [ "Unix"; "close" ] [ fd ]

method main_run promise =
let with_binding var_ident x body =
Expand Down Expand Up @@ -209,29 +266,56 @@ let eio ~eio_sw_as_fiber_var ~eio_env_as_fiber_var add_comment =
wrap_env_fiber_var env (wrap_sw_fiber_var promise));
]

method input_io_of_fd fd =
Exp.constraint_
(mk_apply_simple [ "Eio_unix"; "Net"; "import_socket_stream" ] [ fd ])
(mk_typ_constr
~params:
[ mk_poly_variant [ ("R", []); ("Flow", []); ("Close", []) ] ]
[ "Std"; "r" ])

method output_io_of_fd fd =
add_comment
"This creates a closeable [Flow.sink] resource but write operations \
are rewritten to calls to [Buf_write].\n\
\ You might want to use [Buf_write.with_flow sink (fun \
buf_write -> ...)].";
Exp.constraint_
(mk_apply_simple [ "Eio_unix"; "Net"; "import_socket_stream" ] [ fd ])
(mk_typ_constr
~params:
[ mk_poly_variant [ ("W", []); ("Flow", []); ("Close", []) ] ]
[ "Std"; "r" ])
method input_io =
function
| `Of_fd fd -> buf_read_of_flow (import_socket_stream ~r_or_w:"R" fd)
| `Fname fname ->
buf_read_of_flow
@@ mk_apply_ident
[ "Eio"; "Path"; "open_in" ]
[
get_current_switch_arg ();
( Nolabel,
mk_apply_simple [ "Eio"; "Path"; "/" ] [ env "cwd"; fname ]
);
]

method output_io =
function
| `Of_fd fd -> buf_write_of_flow (import_socket_stream ~r_or_w:"W" fd)
| `Fname fname ->
add_comment
"[flags] and [perm] arguments were dropped. The [~create] was \
added by default and might not match the previous flags. Use \
[~append:true] for [O_APPEND].";
buf_write_of_flow
@@ mk_apply_ident
[ "Eio"; "Path"; "open_out" ]
[
get_current_switch_arg ();
( Labelled (mk_loc "create"),
mk_variant_exp ~arg:(mk_const_int "0o666") "If_missing" );
( Nolabel,
mk_apply_simple [ "Eio"; "Path"; "/" ] [ env "cwd"; fname ]
);
]

method io_read_line chan =
mk_apply_simple [ "Eio"; "Buf_read"; "line" ] [ chan ]

(* This is of type [Optint.Int63.t] instead of [int] with Lwt. *)
method io_length fd = mk_apply_simple [ "Eio"; "File"; "size" ] [ fd ]

method io_write_str chan str =
mk_apply_simple [ "Eio"; "Buf_write"; "string" ] [ chan; str ]

method io_close fd = mk_apply_simple [ "Eio"; "Resource"; "close" ] [ fd ]
method type_out_channel = mk_typ_constr [ "Eio"; "Buf_write"; "t" ]

method path_stat ~follow path =
mk_apply_ident [ "Eio"; "Path"; "stat" ]
[
(Labelled (mk_loc "follow"), mk_constr_of_bool follow);
(Nolabel, mk_apply_simple [ "Eio"; "Path"; "/" ] [ env "cwd"; path ]);
]
end
2 changes: 2 additions & 0 deletions lib/ocamlformat_utils/ast_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ let mk_poly_variant ?(open_ = false) ?labels ?(inherit_ = []) vars =
let constrs = List.map mk_rtag vars @ List.map mk_inherit inherit_ in
Typ.variant constrs flag labels

let mk_constr_of_bool b = mk_constr_exp [ (if b then "true" else "false") ]

(* Exp *)

let mk_const_string s = Exp.constant (Const.string s)
Expand Down
32 changes: 15 additions & 17 deletions test/lwt_to_direct_style/eio-switch.t/run.t
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@ Make a writable directory tree:
$ dune build @ocaml-index
$ lwt-to-direct-style --migrate --eio-sw-as-fiber-var Fiber_var.sw --eio-env-as-fiber-var Fiber_var.env
Formatted 1 files
Warning: main.ml: 5 occurrences have not been rewritten.
Lwt_io.open_file (line 8 column 13)
Lwt_io.input (line 8 column 36)
Lwt_io.close (line 9 column 3)
Lwt_io.read (line 15 column 12)
Warning: main.ml: 1 occurrences have not been rewritten.
Lwt_io.printf (line 16 column 3)

$ cat main.ml
Expand All @@ -22,25 +18,27 @@ Make a writable directory tree:
(Stdlib.Option.get (Fiber.get Fiber_var.env))#mono_clock 1.0 (fun () -> 42)

let _f fname =
let fd = Lwt_io.open_file ~mode:Lwt_io.input fname in
Lwt_io.close fd
let fd =
Eio.Buf_read.of_flow ~max_size:1_000_000
(Eio.Path.open_in
~sw:(Stdlib.Option.get (Fiber.get Fiber_var.sw))
(Eio.Path.( / ) (Stdlib.Option.get (Fiber.get Fiber_var.env))#cwd fname))
in
Eio.Resource.close fd

let main () =
Fiber.fork
~sw:(Stdlib.Option.get (Fiber.get Fiber_var.sw))
(fun () -> async_process 1);
let fd =
fun ?blocking:x1 ?set_flags:x2 ->
Eio_unix.Fd.of_unix
~sw:(Stdlib.Option.get (Fiber.get Fiber_var.sw))
?blocking:x1 ~close_unix:true
(* TODO: lwt-to-direct-style: Labelled argument ?set_flags was dropped. *)
Unix.stdin
in
let fd = Unix.stdin in
let in_chan =
(Eio_unix.Net.import_socket_stream fd : [ `R | `Flow | `Close ] Std.r)
Eio.Buf_read.of_flow ~max_size:1_000_000
(Eio_unix.Net.import_socket_stream
~sw:(Stdlib.Option.get (Fiber.get Fiber_var.sw))
~close_unix:true fd
: [ `R | `Flow | `Close ] r)
in
let s = Lwt_io.read in_chan in
let s = Eio.Buf_read.take_all in_chan in
Lwt_io.printf "%s" s

let () =
Expand Down
Loading