Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/core/tcp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ module type S = sig
val writev_nodelay: flow -> Cstruct.t list -> (unit, write_error) result Lwt.t
val create_connection: ?keepalive:Keepalive.t -> t -> ipaddr * int -> (flow, error) result Lwt.t
val listen : t -> port:int -> ?keepalive:Keepalive.t -> (flow -> unit Lwt.t) -> unit
val is_listening : t -> port:int -> (flow -> unit Lwt.t) option
val unlisten : t -> port:int -> unit
val input: t -> src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t
end
8 changes: 6 additions & 2 deletions src/core/tcp.mli
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,12 @@ module type S = sig
executed for each flow that was established. If [keepalive] is provided,
this configuration will be applied before calling [callback].

@raise Invalid_argument if [port < 0] or [port > 65535]
*)
@raise Invalid_argument if [port < 0] or [port > 65535] *)

val is_listening : t -> port:int -> (flow -> unit Lwt.t) option
(** [is_listening t ~port] returns the [callback] on [port], if it exists.

@raise Invalid_argument if [port < 0] or [port > 65535] *)

val unlisten : t -> port:int -> unit
(** [unlisten t ~port] stops any listener on [port]. *)
Expand Down
1 change: 1 addition & 0 deletions src/core/udp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module type S = sig
val disconnect : t -> unit Lwt.t
type callback = src:ipaddr -> dst:ipaddr -> src_port:int -> Cstruct.t -> unit Lwt.t
val listen : t -> port:int -> callback -> unit
val is_listening : t -> port:int -> callback option
val unlisten : t -> port:int -> unit
val input: t -> src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t
val write: ?src:ipaddr -> ?src_port:int -> ?ttl:int -> dst:ipaddr -> dst_port:int -> t -> Cstruct.t ->
Expand Down
5 changes: 5 additions & 0 deletions src/core/udp.mli
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ module type S = sig

@raise Invalid_argument if [port < 0] or [port > 65535] *)

val is_listening : t -> port:int -> callback option
(** [is_listening t ~port] returns the [callback] on [port], if it exists.

@raise Invalid_argument if [port < 0] or [port > 65535] *)

val unlisten : t -> port:int -> unit
(** [unlisten t ~port] stops any listeners on [port]. *)

Expand Down
11 changes: 7 additions & 4 deletions src/stack-unix/tcpv4v6_socket.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type flow = Lwt_unix.file_descr
type t = {
interface: [ `Any | `Ip of Unix.inet_addr * Unix.inet_addr | `V4_only of Unix.inet_addr | `V6_only of Unix.inet_addr ]; (* source ip to bind to *)
mutable active_connections : Lwt_unix.file_descr list;
listen_sockets : (int, Lwt_unix.file_descr list) Hashtbl.t;
listen_sockets : (int, Lwt_unix.file_descr list * (flow -> unit Lwt.t)) Hashtbl.t;
mutable switched_off : unit Lwt.t;
}

Expand Down Expand Up @@ -63,7 +63,7 @@ let connect ~ipv4_only ~ipv6_only ipv4 ipv6 =
let disconnect t =
Lwt_list.iter_p close t.active_connections >>= fun () ->
Lwt_list.iter_p close
(Hashtbl.fold (fun _ fd acc -> fd @ acc) t.listen_sockets []) >>= fun () ->
(Hashtbl.fold (fun _ (fds, _) acc -> fds @ acc) t.listen_sockets []) >>= fun () ->
Lwt.cancel t.switched_off ; Lwt.return_unit

let dst fd =
Expand Down Expand Up @@ -113,10 +113,13 @@ let create_connection ?keepalive t (dst,dst_port) =
let unlisten t ~port =
match Hashtbl.find_opt t.listen_sockets port with
| None -> ()
| Some fds ->
| Some (fds, _) ->
Hashtbl.remove t.listen_sockets port;
try List.iter (fun fd -> Unix.close (Lwt_unix.unix_file_descr fd)) fds with _ -> ()

let is_listening t ~port =
Option.map snd (Hashtbl.find_opt t.listen_sockets port)

let listen t ~port ?keepalive callback =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port));
Expand Down Expand Up @@ -147,7 +150,7 @@ let listen t ~port ?keepalive callback =
in
List.iter (fun (fd, addr) ->
Unix.bind (Lwt_unix.unix_file_descr fd) addr;
Hashtbl.replace t.listen_sockets port (List.map fst fds);
Hashtbl.replace t.listen_sockets port (List.map fst fds, callback);
Lwt_unix.listen fd 10;
(* FIXME: we should not ignore the result *)
Lwt.async (fun () ->
Expand Down
26 changes: 16 additions & 10 deletions src/stack-unix/udpv4v6_socket.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ let any_v6 = Ipaddr_unix.V6.to_inet_addr Ipaddr.V6.unspecified

type t = {
interface: [ `Any | `Ip of Unix.inet_addr * Unix.inet_addr | `V4_only of Unix.inet_addr | `V6_only of Unix.inet_addr ]; (* source ip to bind to *)
listen_fds: (int, Lwt_unix.file_descr * Lwt_unix.file_descr option) Hashtbl.t; (* UDP fds bound to a particular port *)
listen_fds: (int, Lwt_unix.file_descr * Lwt_unix.file_descr option * callback) Hashtbl.t; (* UDP fds bound to a particular port *)
mutable switched_off : unit Lwt.t;
}

Expand All @@ -38,12 +38,12 @@ let ignore_canceled = function
| Lwt.Canceled -> Lwt.return_unit
| exn -> raise exn

let get_udpv4v6_listening_fd ?(preserve = true) ?(v4_or_v6 = `Both) {listen_fds;interface;_} port =
let get_udpv4v6_listening_fd ?preserve ?(v4_or_v6 = `Both) {listen_fds;interface;_} port =
try
Lwt.return
(match Hashtbl.find listen_fds port with
| (fd, None) -> false, [ fd ]
| (fd, Some fd') -> false, [ fd ; fd' ])
| (fd, None, _) -> false, [ fd ]
| (fd, Some fd', _) -> false, [ fd ; fd' ])
with Not_found ->
(match interface with
| `Any ->
Expand Down Expand Up @@ -76,8 +76,8 @@ let get_udpv4v6_listening_fd ?(preserve = true) ?(v4_or_v6 = `Both) {listen_fds;
| `V6_only ip ->
let fd = Lwt_unix.(socket PF_INET6 SOCK_DGRAM 0) in
Lwt_unix.bind fd (Lwt_unix.ADDR_INET (ip, port)) >|= fun () ->
((fd, None), [ fd ])) >|= fun (fds, r) ->
if preserve then Hashtbl.add listen_fds port fds;
((fd, None), [ fd ])) >|= fun ((fd1, fd2), r) ->
Option.iter (fun cb -> Hashtbl.add listen_fds port (fd1, fd2, cb)) preserve;
true, r


Expand Down Expand Up @@ -121,7 +121,7 @@ let connect ~ipv4_only ~ipv6_only ipv4 ipv6 =
Lwt.return { interface; listen_fds; switched_off = fst (Lwt.wait ()) }

let disconnect t =
Hashtbl.fold (fun _ (fd, fd') r ->
Hashtbl.fold (fun _ (fd, fd', _) r ->
r >>= fun () ->
close fd >>= fun () ->
match fd' with None -> Lwt.return_unit | Some fd -> close fd)
Expand All @@ -146,7 +146,7 @@ let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf =
match t.interface, v4_or_v6 with
| `Any, _ | `Ip _, _ | `V4_only _, `V4 | `V6_only _, `V6 ->
let p = match src_port with None -> 0 | Some x -> x in
get_udpv4v6_listening_fd ~preserve:false ~v4_or_v6 t p >>= fun (created, fds) ->
get_udpv4v6_listening_fd ~v4_or_v6 t p >>= fun (created, fds) ->
((match fds, v4_or_v6 with
| [ fd ], _ -> Lwt.return (Ok fd)
| [ v4 ; _v6 ], `V4 -> Lwt.return (Ok v4)
Expand All @@ -161,19 +161,25 @@ let write ?src:_ ?src_port ?ttl:_ttl ~dst ~dst_port t buf =

let unlisten t ~port =
try
let fd, fd' = Hashtbl.find t.listen_fds port in
let fd, fd', _ = Hashtbl.find t.listen_fds port in
Hashtbl.remove t.listen_fds port;
(match fd' with None -> () | Some fd' -> Unix.close (Lwt_unix.unix_file_descr fd'));
Unix.close (Lwt_unix.unix_file_descr fd)
with _ -> ()

let is_listening t ~port =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
else
Option.map (fun (_, _, cb) -> cb) (Hashtbl.find_opt t.listen_fds port)

let listen t ~port callback =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
else
(* FIXME: we should not ignore the result *)
Lwt.async (fun () ->
get_udpv4v6_listening_fd t port >|= fun (_, fds) ->
get_udpv4v6_listening_fd ~preserve:callback t port >|= fun (_, fds) ->
List.iter (fun fd ->
Lwt.async (fun () ->
let buf = Cstruct.create 4096 in
Expand Down
6 changes: 6 additions & 0 deletions src/tcp/flow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ struct
else
Hashtbl.replace t.listeners port (keepalive, cb)

let is_listening t ~port =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
else
Option.map snd (Hashtbl.find_opt t.listeners port)

let unlisten t ~port = Hashtbl.remove t.listeners port

let _pp_pcb fmt pcb =
Expand Down
6 changes: 6 additions & 0 deletions src/udp/udp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ module Make (Ip : Tcpip.Ip.S) (Random : Mirage_random.S) = struct
else
Hashtbl.replace t.listeners port callback

let is_listening t ~port =
if port < 0 || port > 65535 then
raise (Invalid_argument (Printf.sprintf "invalid port number (%d)" port))
else
Hashtbl.find_opt t.listeners port

let unlisten t ~port = Hashtbl.remove t.listeners port

(* TODO: ought we to check to make sure the destination is relevant
Expand Down