From 8dc43488a5a801760edfc84e91df727f1b088260 Mon Sep 17 00:00:00 2001 From: Alisina Bahadori Date: Sat, 2 Nov 2024 22:42:10 -0400 Subject: [PATCH 1/3] Add PrimitiveOps behaviour and default impl --- lib/bandit.ex | 2 +- lib/bandit/extractor.ex | 14 +- lib/bandit/primitive_ops.ex | 10 ++ lib/bandit/primitive_ops/default.ex | 29 ++++ lib/bandit/websocket/frame.ex | 48 ++----- lib/bandit/websocket/handler.ex | 5 +- .../websocket/frame_deserialization_test.exs | 128 +++++++++--------- test/support/simple_websocket_client.ex | 4 +- 8 files changed, 136 insertions(+), 104 deletions(-) create mode 100644 lib/bandit/primitive_ops.ex create mode 100644 lib/bandit/primitive_ops/default.ex diff --git a/lib/bandit.ex b/lib/bandit.ex index 36d99221..590f4b6d 100644 --- a/lib/bandit.ex +++ b/lib/bandit.ex @@ -220,7 +220,7 @@ defmodule Bandit do @top_level_keys ~w(plug scheme port ip keyfile certfile otp_app cipher_suite display_plug startup_log thousand_island_options http_options http_1_options http_2_options websocket_options)a @http_keys ~w(compress deflate_options log_exceptions_with_status_codes log_protocol_errors log_client_closures)a - @http_1_keys ~w(enabled max_request_line_length max_header_length max_header_count max_requests clear_process_dict gc_every_n_keepalive_requests log_unknown_messages)a + @http_1_keys ~w(enabled max_request_line_length max_header_length max_header_count max_requests clear_process_dict gc_every_n_keepalive_requests log_unknown_messages primitive_ops_module)a @http_2_keys ~w(enabled max_header_block_size max_requests default_local_settings)a @websocket_keys ~w(enabled max_frame_size validate_text_frames compress)a @thousand_island_keys ThousandIsland.ServerConfig.__struct__() diff --git a/lib/bandit/extractor.ex b/lib/bandit/extractor.ex index 2ee206ec..09065a89 100644 --- a/lib/bandit/extractor.ex +++ b/lib/bandit/extractor.ex @@ -9,7 +9,7 @@ defmodule Bandit.Extractor do | {:error, term()} | :more - @callback deserialize(binary()) :: deserialize_result() + @callback deserialize(binary(), primitive_ops_module :: module()) :: deserialize_result() @type t :: %__MODULE__{ header: binary(), @@ -18,7 +18,8 @@ defmodule Bandit.Extractor do required_length: non_neg_integer(), mode: :header_parsing | :payload_parsing, max_frame_size: non_neg_integer(), - frame_parser: atom() + frame_parser: atom(), + primitive_ops_module: module() } defstruct header: <<>>, @@ -27,15 +28,18 @@ defmodule Bandit.Extractor do required_length: 0, mode: :header_parsing, max_frame_size: 0, - frame_parser: nil + frame_parser: nil, + primitive_ops_module: nil @spec new(module(), Keyword.t()) :: t() def new(frame_parser, opts) do max_frame_size = Keyword.get(opts, :max_frame_size, 0) + primitive_ops_module = Keyword.get(opts, :primitive_ops_module) || Bandit.PrimitiveOps.Default %__MODULE__{ max_frame_size: max_frame_size, - frame_parser: frame_parser + frame_parser: frame_parser, + primitive_ops_module: primitive_ops_module } end @@ -79,7 +83,7 @@ defmodule Bandit.Extractor do <> = IO.iodata_to_binary(state.payload) - frame = state.frame_parser.deserialize(state.header <> payload) + frame = state.frame_parser.deserialize(state.header <> payload, state.primitive_ops_module) state = transition_to_header_parsing(state, rest) {state, frame} diff --git a/lib/bandit/primitive_ops.ex b/lib/bandit/primitive_ops.ex new file mode 100644 index 00000000..dee80917 --- /dev/null +++ b/lib/bandit/primitive_ops.ex @@ -0,0 +1,10 @@ +defmodule Bandit.PrimitiveOps do + @moduledoc """ + Primitive operations behaviour + """ + + @doc """ + WebSocket masking according to [RFC6455§5.3](https://www.rfc-editor.org/rfc/rfc6455#section-5.3) + """ + @callback ws_mask(payload :: binary(), mask :: integer()) :: binary() +end diff --git a/lib/bandit/primitive_ops/default.ex b/lib/bandit/primitive_ops/default.ex new file mode 100644 index 00000000..9fba223c --- /dev/null +++ b/lib/bandit/primitive_ops/default.ex @@ -0,0 +1,29 @@ +defmodule Bandit.PrimitiveOps.Default do + @moduledoc """ + Default implementation of `Bandit.PrimitiveOps` + """ + + @behaviour Bandit.PrimitiveOps + + # Note that masking is an involution, so we don't need a separate unmask function + @impl Bandit.PrimitiveOps + def ws_mask(payload, mask) + when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do + ws_mask(<<>>, payload, mask) + end + + defp ws_mask(acc, <>, mask) do + ws_mask(<>)>>, rest, mask) + end + + for size <- [24, 16, 8] do + defp ws_mask(acc, <>, mask) do + <> = <> + <>)>> + end + end + + defp ws_mask(acc, <<>>, _mask) do + acc + end +end diff --git a/lib/bandit/websocket/frame.ex b/lib/bandit/websocket/frame.ex index 8d9fbe46..1d5b42c7 100644 --- a/lib/bandit/websocket/frame.ex +++ b/lib/bandit/websocket/frame.ex @@ -73,29 +73,32 @@ defmodule Bandit.WebSocket.Frame do end @impl Bandit.Extractor - @spec deserialize(binary()) :: {:ok, frame()} | {:error, term()} + @spec deserialize(binary(), module()) :: {:ok, frame()} | {:error, term()} def deserialize( <> + payload::binary-size(length)>>, + primitive_ops_module ) do - to_frame(fin, compressed, rsv, opcode, mask, payload) + to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module) end def deserialize( <> + payload::binary-size(length)>>, + primitive_ops_module ) do - to_frame(fin, compressed, rsv, opcode, mask, payload) + to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module) end def deserialize( <> + payload::binary-size(length)>>, + primitive_ops_module ) do - to_frame(fin, compressed, rsv, opcode, mask, payload) + to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module) end - def deserialize(_msg) do + def deserialize(_msg, _primitive_ops_module) do {:error, :deserialization_failed} end @@ -155,14 +158,15 @@ defmodule Bandit.WebSocket.Frame do end end - defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload) when rsv != 0x0 do + defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload, _primitive_ops_module) + when rsv != 0x0 do {:error, "Received unsupported RSV flags #{rsv}"} end - defp to_frame(fin, compressed, 0x0, opcode, mask, payload) do + defp to_frame(fin, compressed, 0x0, opcode, mask, payload, primitive_ops_module) do fin = fin == 0x1 compressed = compressed == 0x1 - unmasked_payload = mask(payload, mask) + unmasked_payload = primitive_ops_module.ws_mask(payload, mask) opcode |> case do @@ -198,26 +202,4 @@ defmodule Bandit.WebSocket.Frame do defp mask_and_length(length) when length <= 125, do: <<0::1, length::7>> defp mask_and_length(length) when length <= 65_535, do: <<0::1, 126::7, length::16>> defp mask_and_length(length), do: <<0::1, 127::7, length::64>> - - # Note that masking is an involution, so we don't need a separate unmask function - @spec mask(binary(), integer()) :: binary() - def mask(payload, mask) - when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do - mask(<<>>, payload, mask) - end - - defp mask(acc, <>, mask) do - mask(<>)>>, rest, mask) - end - - for size <- [24, 16, 8] do - defp mask(acc, <>, mask) do - <> = <> - <>)>> - end - end - - defp mask(acc, <<>>, _mask) do - acc - end end diff --git a/lib/bandit/websocket/handler.ex b/lib/bandit/websocket/handler.ex index 3d69c0fd..919aed3c 100644 --- a/lib/bandit/websocket/handler.ex +++ b/lib/bandit/websocket/handler.ex @@ -15,7 +15,10 @@ defmodule Bandit.WebSocket.Handler do |> Keyword.take([:fullsweep_after, :max_heap_size]) |> Enum.each(fn {key, value} -> :erlang.process_flag(key, value) end) - connection_opts = Keyword.merge(state.opts.websocket, connection_opts) + connection_opts = + state.opts.websocket + |> Keyword.merge(connection_opts) + |> Keyword.put(:primitive_ops_module, Keyword.get(state.opts.http_1, :primitive_ops_module)) state = state diff --git a/test/bandit/websocket/frame_deserialization_test.exs b/test/bandit/websocket/frame_deserialization_test.exs index e2be97a5..9daabc42 100644 --- a/test/bandit/websocket/frame_deserialization_test.exs +++ b/test/bandit/websocket/frame_deserialization_test.exs @@ -1,100 +1,103 @@ defmodule WebSocketFrameDeserializationTest do use ExUnit.Case, async: true - import Bandit.WebSocket.Frame, only: [mask: 2] + import Bandit.PrimitiveOps.Default, only: [ws_mask: 2] + alias Bandit.PrimitiveOps.Default, as: DefaultPrimitiveOps alias Bandit.WebSocket.Frame describe "reserved flag parsing" do test "errors on reserved flag 1 being set" do frame = <<0x1::1, 0x1::3, 0x1::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:error, "Received unsupported RSV flags 1"} + assert Frame.deserialize(frame, DefaultPrimitiveOps) == + {:error, "Received unsupported RSV flags 1"} end test "errors on reserved flag 2 being set" do frame = <<0x1::1, 0x2::3, 0x1::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:error, "Received unsupported RSV flags 2"} + assert Frame.deserialize(frame, DefaultPrimitiveOps) == + {:error, "Received unsupported RSV flags 2"} end end describe "frame size" do test "parses 2 byte frames" do payload = String.duplicate("a", 2) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 2::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses 10 byte frames" do payload = String.duplicate("a", 10) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 10::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames up to 125 bytes" do payload = String.duplicate("a", 125) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 126 bytes long" do payload = String.duplicate("a", 126) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 126::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 127 bytes long" do payload = String.duplicate("a", 127) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 127::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 16_000 bytes long" do payload = String.duplicate("a", 16_000) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 16_000::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 1_000_000 bytes long" do payload = String.duplicate("a", 1_000_000) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 127::7, 1_000_000::64, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "errors on frames over max_frame_size bytes with small frames" do payload = String.duplicate("a", 125) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 1234::32, masked_payload::binary>> @@ -104,7 +107,7 @@ defmodule WebSocketFrameDeserializationTest do test "errors on frames over max_frame_size bytes with medium frames" do payload = String.duplicate("a", 16_000) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 16_000::16, 1234::32, masked_payload::binary>> @@ -115,7 +118,7 @@ defmodule WebSocketFrameDeserializationTest do test "errors on frames over max_frame_size bytes with large frames" do payload = String.duplicate("a", 1_000_000) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 127::7, 1_000_000::64, 1234::32, masked_payload::binary>> @@ -129,7 +132,7 @@ defmodule WebSocketFrameDeserializationTest do test "returns error" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 0::32, 1, 2, 3>> - assert Frame.deserialize(frame) == {:error, :deserialization_failed} + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, :deserialization_failed} end end @@ -137,7 +140,7 @@ defmodule WebSocketFrameDeserializationTest do test "returns error" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 1::7, 0::32, 1, 2, 3>> - assert Frame.deserialize(frame) == {:error, :deserialization_failed} + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, :deserialization_failed} end end @@ -145,7 +148,7 @@ defmodule WebSocketFrameDeserializationTest do test "returns an Unknown frame" do frame = <<0x1::1, 0x0::3, 0xF::4, 1::1, 1::7, 0::32, 1>> - assert Frame.deserialize(frame) == {:error, "unknown opcode #{15}"} + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "unknown opcode #{15}"} end end @@ -153,27 +156,27 @@ defmodule WebSocketFrameDeserializationTest do test "deserializes frames with fin bit set" do frame = <<0x1::1, 0x0::3, 0x0::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Continuation{fin: true, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x0::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Continuation{fin: false, data: <<1, 2, 3, 4, 5>>}} end test "refuses frame with per-message compressed bit set" do frame = <<0x0::1, 0x4::3, 0x0::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Cannot have a compressed continuation frame (RFC7692§6.1)"} end end @@ -182,27 +185,27 @@ defmodule WebSocketFrameDeserializationTest do test "deserializes frames with fin and per-message compressed bits clear" do frame = <<0x0::1, 0x0::3, 0x1::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with fin bit set" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with per-message compressed bit set" do frame = <<0x0::1, 0x4::3, 0x1::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Text{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}} end end @@ -211,27 +214,27 @@ defmodule WebSocketFrameDeserializationTest do test "deserializes frames with fin and per-message compressed bits clear" do frame = <<0x0::1, 0x0::3, 0x2::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Binary{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with fin bit set" do frame = <<0x1::1, 0x0::3, 0x2::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Binary{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with per-message compressed bit set" do frame = <<0x0::1, 0x4::3, 0x2::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Binary{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}} end end @@ -242,30 +245,31 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 125::7, 0x01020304::32, - mask(<<1000::16, payload::binary>>, 0x01020304)::binary>> + ws_mask(<<1000::16, payload::binary>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.ConnectionClose{code: 1000, reason: payload}} end test "deserializes frames with code" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 2::7, 0x01020304::32, - mask(<<1000::16>>, 0x01020304)::binary>> + ws_mask(<<1000::16>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == {:ok, %Frame.ConnectionClose{code: 1000}} + assert Frame.deserialize(frame, DefaultPrimitiveOps) == + {:ok, %Frame.ConnectionClose{code: 1000}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:ok, %Frame.ConnectionClose{}} + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.ConnectionClose{}} end test "refuses frame with invalid payload" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 1::7, 0x01020304::32, 1>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Invalid connection close payload (RFC6455§5.5)"} end @@ -273,21 +277,21 @@ defmodule WebSocketFrameDeserializationTest do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Invalid connection close payload (RFC6455§5.5)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Cannot have a fragmented connection close frame (RFC6455§5.5)"} end test "refuses frame with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Cannot have a compressed connection close frame (RFC7692§6.1)"} end end @@ -298,37 +302,37 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 125::7, 0x01020304::32, - mask(payload, 0x01020304)::binary>> + ws_mask(payload, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Ping{data: payload}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:ok, %Frame.Ping{}} + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Ping{}} end test "refuses frame with overly large payload" do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Invalid ping payload (RFC6455§5.5.2)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Cannot have a fragmented ping frame (RFC6455§5.5.2)"} end test "refuses frames with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Cannot have a compressed ping frame (RFC7692§6.1)"} end end @@ -339,37 +343,37 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 125::7, 0x01020304::32, - mask(payload, 0x01020304)::binary>> + ws_mask(payload, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Pong{data: payload}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:ok, %Frame.Pong{}} + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Pong{}} end test "refuses frame with overly large payload" do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Invalid pong payload (RFC6455§5.5.3)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Cannot have a fragmented pong frame (RFC6455§5.5.3)"} end test "refuses frames with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "Cannot have a compressed pong frame (RFC7692§6.1)"} end end diff --git a/test/support/simple_websocket_client.ex b/test/support/simple_websocket_client.ex index 07938dbe..faea81ba 100644 --- a/test/support/simple_websocket_client.ex +++ b/test/support/simple_websocket_client.ex @@ -1,7 +1,7 @@ defmodule SimpleWebSocketClient do @moduledoc false - alias Bandit.WebSocket.Frame + alias Bandit.PrimitiveOps defdelegate tcp_client(context), to: Transport @@ -140,7 +140,7 @@ defmodule SimpleWebSocketClient do defp send_frame(client, flags, opcode, data) do mask = :rand.uniform(1_000_000) - masked_data = Frame.mask(data, mask) + masked_data = PrimitiveOps.Default.ws_mask(data, mask) mask_flag_and_size = case byte_size(masked_data) do From 053ee99fa222e39d9ab979156bc71685efe4ddda Mon Sep 17 00:00:00 2001 From: Alisina Bahadori Date: Thu, 14 Nov 2024 20:38:25 -0500 Subject: [PATCH 2/3] Move into websocket options --- lib/bandit.ex | 4 +- lib/bandit/extractor.ex | 2 +- lib/bandit/primitive_ops.ex | 10 --- .../{default.ex => websocket.ex} | 13 ++- lib/bandit/websocket/handler.ex | 2 +- .../websocket/frame_deserialization_test.exs | 80 +++++++++---------- test/support/simple_websocket_client.ex | 4 +- 7 files changed, 55 insertions(+), 60 deletions(-) delete mode 100644 lib/bandit/primitive_ops.ex rename lib/bandit/primitive_ops/{default.ex => websocket.ex} (66%) diff --git a/lib/bandit.ex b/lib/bandit.ex index 590f4b6d..992d2f19 100644 --- a/lib/bandit.ex +++ b/lib/bandit.ex @@ -220,9 +220,9 @@ defmodule Bandit do @top_level_keys ~w(plug scheme port ip keyfile certfile otp_app cipher_suite display_plug startup_log thousand_island_options http_options http_1_options http_2_options websocket_options)a @http_keys ~w(compress deflate_options log_exceptions_with_status_codes log_protocol_errors log_client_closures)a - @http_1_keys ~w(enabled max_request_line_length max_header_length max_header_count max_requests clear_process_dict gc_every_n_keepalive_requests log_unknown_messages primitive_ops_module)a + @http_1_keys ~w(enabled max_request_line_length max_header_length max_header_count max_requests clear_process_dict gc_every_n_keepalive_requests log_unknown_messages)a @http_2_keys ~w(enabled max_header_block_size max_requests default_local_settings)a - @websocket_keys ~w(enabled max_frame_size validate_text_frames compress)a + @websocket_keys ~w(enabled max_frame_size validate_text_frames compress primitive_ops_module)a @thousand_island_keys ThousandIsland.ServerConfig.__struct__() |> Map.from_struct() |> Map.keys() diff --git a/lib/bandit/extractor.ex b/lib/bandit/extractor.ex index 09065a89..c470eb67 100644 --- a/lib/bandit/extractor.ex +++ b/lib/bandit/extractor.ex @@ -34,7 +34,7 @@ defmodule Bandit.Extractor do @spec new(module(), Keyword.t()) :: t() def new(frame_parser, opts) do max_frame_size = Keyword.get(opts, :max_frame_size, 0) - primitive_ops_module = Keyword.get(opts, :primitive_ops_module) || Bandit.PrimitiveOps.Default + primitive_ops_module = Keyword.fetch!(opts, :primitive_ops_module) %__MODULE__{ max_frame_size: max_frame_size, diff --git a/lib/bandit/primitive_ops.ex b/lib/bandit/primitive_ops.ex deleted file mode 100644 index dee80917..00000000 --- a/lib/bandit/primitive_ops.ex +++ /dev/null @@ -1,10 +0,0 @@ -defmodule Bandit.PrimitiveOps do - @moduledoc """ - Primitive operations behaviour - """ - - @doc """ - WebSocket masking according to [RFC6455§5.3](https://www.rfc-editor.org/rfc/rfc6455#section-5.3) - """ - @callback ws_mask(payload :: binary(), mask :: integer()) :: binary() -end diff --git a/lib/bandit/primitive_ops/default.ex b/lib/bandit/primitive_ops/websocket.ex similarity index 66% rename from lib/bandit/primitive_ops/default.ex rename to lib/bandit/primitive_ops/websocket.ex index 9fba223c..d9af1672 100644 --- a/lib/bandit/primitive_ops/default.ex +++ b/lib/bandit/primitive_ops/websocket.ex @@ -1,12 +1,17 @@ -defmodule Bandit.PrimitiveOps.Default do +defmodule Bandit.PrimitiveOps.WebSocket do @moduledoc """ - Default implementation of `Bandit.PrimitiveOps` + WebSocket primitive operations behaviour and default implementation """ - @behaviour Bandit.PrimitiveOps + @doc """ + WebSocket masking according to [RFC6455§5.3](https://www.rfc-editor.org/rfc/rfc6455#section-5.3) + """ + @callback ws_mask(payload :: binary(), mask :: integer()) :: binary() + + @behaviour __MODULE__ # Note that masking is an involution, so we don't need a separate unmask function - @impl Bandit.PrimitiveOps + @impl true def ws_mask(payload, mask) when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do ws_mask(<<>>, payload, mask) diff --git a/lib/bandit/websocket/handler.ex b/lib/bandit/websocket/handler.ex index 919aed3c..7daaa401 100644 --- a/lib/bandit/websocket/handler.ex +++ b/lib/bandit/websocket/handler.ex @@ -18,7 +18,7 @@ defmodule Bandit.WebSocket.Handler do connection_opts = state.opts.websocket |> Keyword.merge(connection_opts) - |> Keyword.put(:primitive_ops_module, Keyword.get(state.opts.http_1, :primitive_ops_module)) + |> Keyword.put_new(:primitive_ops_module, Bandit.PrimitiveOps.WebSocket) state = state diff --git a/test/bandit/websocket/frame_deserialization_test.exs b/test/bandit/websocket/frame_deserialization_test.exs index 9daabc42..56b518de 100644 --- a/test/bandit/websocket/frame_deserialization_test.exs +++ b/test/bandit/websocket/frame_deserialization_test.exs @@ -1,23 +1,23 @@ defmodule WebSocketFrameDeserializationTest do use ExUnit.Case, async: true - import Bandit.PrimitiveOps.Default, only: [ws_mask: 2] + import Bandit.PrimitiveOps.WebSocket, only: [ws_mask: 2] - alias Bandit.PrimitiveOps.Default, as: DefaultPrimitiveOps + alias Bandit.PrimitiveOps.WebSocket, as: WebSocketPrimitiveOps alias Bandit.WebSocket.Frame describe "reserved flag parsing" do test "errors on reserved flag 1 being set" do frame = <<0x1::1, 0x1::3, 0x1::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Received unsupported RSV flags 1"} end test "errors on reserved flag 2 being set" do frame = <<0x1::1, 0x2::3, 0x1::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Received unsupported RSV flags 2"} end end @@ -29,7 +29,7 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 2::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end @@ -39,7 +39,7 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 10::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end @@ -49,7 +49,7 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end @@ -59,7 +59,7 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 126::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end @@ -69,7 +69,7 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 127::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end @@ -80,7 +80,7 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 16_000::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end @@ -91,7 +91,7 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 127::7, 1_000_000::64, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end @@ -132,7 +132,7 @@ defmodule WebSocketFrameDeserializationTest do test "returns error" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 0::32, 1, 2, 3>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, :deserialization_failed} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, :deserialization_failed} end end @@ -140,7 +140,7 @@ defmodule WebSocketFrameDeserializationTest do test "returns error" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 1::7, 0::32, 1, 2, 3>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, :deserialization_failed} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, :deserialization_failed} end end @@ -148,7 +148,7 @@ defmodule WebSocketFrameDeserializationTest do test "returns an Unknown frame" do frame = <<0x1::1, 0x0::3, 0xF::4, 1::1, 1::7, 0::32, 1>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:error, "unknown opcode #{15}"} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "unknown opcode #{15}"} end end @@ -158,7 +158,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x0::4, 1::1, 5::7, 0x01020304::32, ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Continuation{fin: true, data: <<1, 2, 3, 4, 5>>}} end @@ -167,7 +167,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x0::3, 0x0::4, 1::1, 5::7, 0x01020304::32, ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Continuation{fin: false, data: <<1, 2, 3, 4, 5>>}} end @@ -176,7 +176,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x4::3, 0x0::4, 1::1, 5::7, 0x01020304::32, ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a compressed continuation frame (RFC7692§6.1)"} end end @@ -187,7 +187,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x0::3, 0x1::4, 1::1, 5::7, 0x01020304::32, ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}} end @@ -196,7 +196,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x1::4, 1::1, 5::7, 0x01020304::32, ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}} end @@ -205,7 +205,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x4::3, 0x1::4, 1::1, 5::7, 0x01020304::32, ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}} end end @@ -216,7 +216,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x0::3, 0x2::4, 1::1, 5::7, 0x01020304::32, ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Binary{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}} end @@ -225,7 +225,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x2::4, 1::1, 5::7, 0x01020304::32, ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Binary{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}} end @@ -234,7 +234,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x0::1, 0x4::3, 0x2::4, 1::1, 5::7, 0x01020304::32, ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Binary{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}} end end @@ -247,7 +247,7 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x8::4, 1::1, 125::7, 0x01020304::32, ws_mask(<<1000::16, payload::binary>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.ConnectionClose{code: 1000, reason: payload}} end @@ -256,20 +256,20 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x8::4, 1::1, 2::7, 0x01020304::32, ws_mask(<<1000::16>>, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.ConnectionClose{code: 1000}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.ConnectionClose{}} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.ConnectionClose{}} end test "refuses frame with invalid payload" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 1::7, 0x01020304::32, 1>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Invalid connection close payload (RFC6455§5.5)"} end @@ -277,21 +277,21 @@ defmodule WebSocketFrameDeserializationTest do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Invalid connection close payload (RFC6455§5.5)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a fragmented connection close frame (RFC6455§5.5)"} end test "refuses frame with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a compressed connection close frame (RFC7692§6.1)"} end end @@ -304,35 +304,35 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0x9::4, 1::1, 125::7, 0x01020304::32, ws_mask(payload, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Ping{data: payload}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Ping{}} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Ping{}} end test "refuses frame with overly large payload" do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Invalid ping payload (RFC6455§5.5.2)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a fragmented ping frame (RFC6455§5.5.2)"} end test "refuses frames with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a compressed ping frame (RFC7692§6.1)"} end end @@ -345,35 +345,35 @@ defmodule WebSocketFrameDeserializationTest do <<0x1::1, 0x0::3, 0xA::4, 1::1, 125::7, 0x01020304::32, ws_mask(payload, 0x01020304)::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Pong{data: payload}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == {:ok, %Frame.Pong{}} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Pong{}} end test "refuses frame with overly large payload" do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Invalid pong payload (RFC6455§5.5.3)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a fragmented pong frame (RFC6455§5.5.3)"} end test "refuses frames with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame, DefaultPrimitiveOps) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a compressed pong frame (RFC7692§6.1)"} end end diff --git a/test/support/simple_websocket_client.ex b/test/support/simple_websocket_client.ex index faea81ba..c15c23bb 100644 --- a/test/support/simple_websocket_client.ex +++ b/test/support/simple_websocket_client.ex @@ -1,7 +1,7 @@ defmodule SimpleWebSocketClient do @moduledoc false - alias Bandit.PrimitiveOps + alias Bandit.PrimitiveOps.WebSocket, as: WebSocketPrimitiveOps defdelegate tcp_client(context), to: Transport @@ -140,7 +140,7 @@ defmodule SimpleWebSocketClient do defp send_frame(client, flags, opcode, data) do mask = :rand.uniform(1_000_000) - masked_data = PrimitiveOps.Default.ws_mask(data, mask) + masked_data = WebSocketPrimitiveOps.ws_mask(data, mask) mask_flag_and_size = case byte_size(masked_data) do From 028e89071b733d7955fbdec8a88b13b8120ff8f9 Mon Sep 17 00:00:00 2001 From: Alisina Bahadori Date: Fri, 15 Nov 2024 10:53:34 -0500 Subject: [PATCH 3/3] Move primitive_ops_module in Extractor.new/3 --- lib/bandit/extractor.ex | 5 ++--- lib/bandit/websocket/handler.ex | 10 +++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/lib/bandit/extractor.ex b/lib/bandit/extractor.ex index c470eb67..004dc03e 100644 --- a/lib/bandit/extractor.ex +++ b/lib/bandit/extractor.ex @@ -31,10 +31,9 @@ defmodule Bandit.Extractor do frame_parser: nil, primitive_ops_module: nil - @spec new(module(), Keyword.t()) :: t() - def new(frame_parser, opts) do + @spec new(module(), module(), Keyword.t()) :: t() + def new(frame_parser, primitive_ops_module, opts) do max_frame_size = Keyword.get(opts, :max_frame_size, 0) - primitive_ops_module = Keyword.fetch!(opts, :primitive_ops_module) %__MODULE__{ max_frame_size: max_frame_size, diff --git a/lib/bandit/websocket/handler.ex b/lib/bandit/websocket/handler.ex index 7daaa401..88e39c42 100644 --- a/lib/bandit/websocket/handler.ex +++ b/lib/bandit/websocket/handler.ex @@ -15,15 +15,15 @@ defmodule Bandit.WebSocket.Handler do |> Keyword.take([:fullsweep_after, :max_heap_size]) |> Enum.each(fn {key, value} -> :erlang.process_flag(key, value) end) - connection_opts = - state.opts.websocket - |> Keyword.merge(connection_opts) - |> Keyword.put_new(:primitive_ops_module, Bandit.PrimitiveOps.WebSocket) + connection_opts = Keyword.merge(state.opts.websocket, connection_opts) + + primitive_ops_module = + Keyword.get(state.opts.websocket, :primitive_ops_module, Bandit.PrimitiveOps.WebSocket) state = state |> Map.take([:handler_module]) - |> Map.put(:extractor, Extractor.new(Frame, connection_opts)) + |> Map.put(:extractor, Extractor.new(Frame, primitive_ops_module, connection_opts)) case Connection.init(websock, websock_opts, connection_opts, socket) do {:continue, connection} ->