This safe wrapper around Unix.openfile ensures that exceptions
escaping cannot leave unclosed files.
There are only a few places in the code where this wrapper can be used
currently. There are other occurrences of Unix.openfile but they are
not suitable for replacement.
---
common/mlstdutils/std_utils.ml | 4 ++++
common/mlstdutils/std_utils.mli | 6 ++++++
daemon/devsparts.ml | 5 ++---
daemon/inspect_fs_windows.ml | 18 ++++++++----------
4 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/common/mlstdutils/std_utils.ml b/common/mlstdutils/std_utils.ml
index ee6bea5af..32944ed27 100644
--- a/common/mlstdutils/std_utils.ml
+++ b/common/mlstdutils/std_utils.ml
@@ -662,6 +662,10 @@ let with_open_out filename f =
let chan = open_out filename in
protect ~f:(fun () -> f chan) ~finally:(fun () -> close_out chan)
+let with_openfile filename flags perms f =
+  let fd = Unix.openfile filename flags perms in
+  protect ~f:(fun () -> f fd) ~finally:(fun () -> Unix.close fd)
+
let read_whole_file path =
let buf = Buffer.create 16384 in
with_open_in path (
diff --git a/common/mlstdutils/std_utils.mli b/common/mlstdutils/std_utils.mli
index 7af6c2111..178762819 100644
--- a/common/mlstdutils/std_utils.mli
+++ b/common/mlstdutils/std_utils.mli
@@ -399,6 +399,12 @@ val with_open_out : string -> (out_channel -> 'a) -> 'a
return or if the function [f] throws an exception, so this is
both safer and more concise than the regular function. *)
+val with_openfile : string -> Unix.open_flag list -> Unix.file_perm -> (Unix.file_descr -> 'a) -> 'a
+(** [with_openfile filename flags perms f] calls function [f] with
+    [filename] opened by {!Unix.openfile}.  The file descriptor is always
+    closed, either on normal return or if [f] throws an exception, so
+    this is both safer and more concise than the regular function. *)
+
val read_whole_file : string -> string
(** Read in the whole file as a string. *)
diff --git a/daemon/devsparts.ml b/daemon/devsparts.ml
index 7395de923..0eb7c1282 100644
--- a/daemon/devsparts.ml
+++ b/daemon/devsparts.ml
@@ -49,9 +49,8 @@ let map_block_devices ~return_md f =
List.filter (
fun dev ->
try
- let fd = openfile ("/dev/" ^ dev) [O_RDONLY; O_CLOEXEC] 0 in
- close fd;
- true
+ with_openfile ("/dev/" ^ dev) [O_RDONLY; O_CLOEXEC] 0
+ (fun _ -> true)
with _ -> false
) devs in
diff --git a/daemon/inspect_fs_windows.ml b/daemon/inspect_fs_windows.ml
index 7c42fc5d7..112cc2f92 100644
--- a/daemon/inspect_fs_windows.ml
+++ b/daemon/inspect_fs_windows.ml
@@ -429,16 +429,14 @@ and extract_guid_from_registry_blob blob =
(data4 &^ 0xffffffffffff_L)
and pread device size offset =
- let fd = Unix.openfile device [Unix.O_RDONLY; Unix.O_CLOEXEC] 0 in
- let ret =
- protect ~f:(
- fun () ->
- ignore (Unix.lseek fd offset Unix.SEEK_SET);
- let ret = Bytes.create size in
- if Unix.read fd ret 0 size < size then
- failwithf "pread: %s: short read" device;
- ret
- ) ~finally:(fun () -> Unix.close fd) in
+  let ret = with_openfile device [Unix.O_RDONLY; Unix.O_CLOEXEC] 0 (
+    fun fd ->
+      ignore (Unix.lseek fd offset Unix.SEEK_SET);
+      let ret = Bytes.create size in
+      if Unix.read fd ret 0 size < size then
+        failwithf "pread: %s: short read" device;
+      ret
+  ) in
Bytes.to_string ret
(* Get the hostname. *)
--
2.13.2