diff --git a/lib/bandit.ex b/lib/bandit.ex index 36d99221..992d2f19 100644 --- a/lib/bandit.ex +++ b/lib/bandit.ex @@ -222,7 +222,7 @@ defmodule Bandit do @http_keys ~w(compress deflate_options log_exceptions_with_status_codes log_protocol_errors log_client_closures)a @http_1_keys ~w(enabled max_request_line_length max_header_length max_header_count max_requests clear_process_dict gc_every_n_keepalive_requests log_unknown_messages)a @http_2_keys ~w(enabled max_header_block_size max_requests default_local_settings)a - @websocket_keys ~w(enabled max_frame_size validate_text_frames compress)a + @websocket_keys ~w(enabled max_frame_size validate_text_frames compress primitive_ops_module)a @thousand_island_keys ThousandIsland.ServerConfig.__struct__() |> Map.from_struct() |> Map.keys() diff --git a/lib/bandit/extractor.ex b/lib/bandit/extractor.ex index 2ee206ec..004dc03e 100644 --- a/lib/bandit/extractor.ex +++ b/lib/bandit/extractor.ex @@ -9,7 +9,7 @@ defmodule Bandit.Extractor do | {:error, term()} | :more - @callback deserialize(binary()) :: deserialize_result() + @callback deserialize(binary(), primitive_ops_module :: module()) :: deserialize_result() @type t :: %__MODULE__{ header: binary(), @@ -18,7 +18,8 @@ defmodule Bandit.Extractor do required_length: non_neg_integer(), mode: :header_parsing | :payload_parsing, max_frame_size: non_neg_integer(), - frame_parser: atom() + frame_parser: atom(), + primitive_ops_module: module() } defstruct header: <<>>, @@ -27,15 +28,17 @@ defmodule Bandit.Extractor do required_length: 0, mode: :header_parsing, max_frame_size: 0, - frame_parser: nil + frame_parser: nil, + primitive_ops_module: nil - @spec new(module(), Keyword.t()) :: t() - def new(frame_parser, opts) do + @spec new(module(), module(), Keyword.t()) :: t() + def new(frame_parser, primitive_ops_module, opts) do max_frame_size = Keyword.get(opts, :max_frame_size, 0) %__MODULE__{ max_frame_size: max_frame_size, - frame_parser: frame_parser + frame_parser: frame_parser, + primitive_ops_module: primitive_ops_module } end @@ -79,7 +82,7 @@ defmodule Bandit.Extractor do <> = IO.iodata_to_binary(state.payload) - frame = state.frame_parser.deserialize(state.header <> payload) + frame = state.frame_parser.deserialize(state.header <> payload, state.primitive_ops_module) state = transition_to_header_parsing(state, rest) {state, frame} diff --git a/lib/bandit/primitive_ops/websocket.ex b/lib/bandit/primitive_ops/websocket.ex new file mode 100644 index 00000000..d9af1672 --- /dev/null +++ b/lib/bandit/primitive_ops/websocket.ex @@ -0,0 +1,34 @@ +defmodule Bandit.PrimitiveOps.WebSocket do + @moduledoc """ + WebSocket primitive operations behaviour and default implementation + """ + + @doc """ + WebSocket masking according to [RFC6455§5.3](https://www.rfc-editor.org/rfc/rfc6455#section-5.3) + """ + @callback ws_mask(payload :: binary(), mask :: integer()) :: binary() + + @behaviour __MODULE__ + + # Note that masking is an involution, so we don't need a separate unmask function + @impl true + def ws_mask(payload, mask) + when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do + ws_mask(<<>>, payload, mask) + end + + defp ws_mask(acc, <>, mask) do + ws_mask(<>)>>, rest, mask) + end + + for size <- [24, 16, 8] do + defp ws_mask(acc, <>, mask) do + <> = <> + <>)>> + end + end + + defp ws_mask(acc, <<>>, _mask) do + acc + end +end diff --git a/lib/bandit/websocket/frame.ex b/lib/bandit/websocket/frame.ex index 8d9fbe46..1d5b42c7 100644 --- a/lib/bandit/websocket/frame.ex +++ b/lib/bandit/websocket/frame.ex @@ -73,29 +73,32 @@ defmodule Bandit.WebSocket.Frame do end @impl Bandit.Extractor - @spec deserialize(binary()) :: {:ok, frame()} | {:error, term()} + @spec deserialize(binary(), module()) :: {:ok, frame()} | {:error, term()} def deserialize( <> + payload::binary-size(length)>>, + primitive_ops_module ) do - to_frame(fin, compressed, rsv, opcode, mask, payload) + to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module) end def deserialize( <> + payload::binary-size(length)>>, + primitive_ops_module ) do - to_frame(fin, compressed, rsv, opcode, mask, payload) + to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module) end def deserialize( <> + payload::binary-size(length)>>, + primitive_ops_module ) do - to_frame(fin, compressed, rsv, opcode, mask, payload) + to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module) end - def deserialize(_msg) do + def deserialize(_msg, _primitive_ops_module) do {:error, :deserialization_failed} end @@ -155,14 +158,15 @@ defmodule Bandit.WebSocket.Frame do end end - defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload) when rsv != 0x0 do + defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload, _primitive_ops_module) + when rsv != 0x0 do {:error, "Received unsupported RSV flags #{rsv}"} end - defp to_frame(fin, compressed, 0x0, opcode, mask, payload) do + defp to_frame(fin, compressed, 0x0, opcode, mask, payload, primitive_ops_module) do fin = fin == 0x1 compressed = compressed == 0x1 - unmasked_payload = mask(payload, mask) + unmasked_payload = primitive_ops_module.ws_mask(payload, mask) opcode |> case do @@ -198,26 +202,4 @@ defmodule Bandit.WebSocket.Frame do defp mask_and_length(length) when length <= 125, do: <<0::1, length::7>> defp mask_and_length(length) when length <= 65_535, do: <<0::1, 126::7, length::16>> defp mask_and_length(length), do: <<0::1, 127::7, length::64>> - - # Note that masking is an involution, so we don't need a separate unmask function - @spec mask(binary(), integer()) :: binary() - def mask(payload, mask) - when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do - mask(<<>>, payload, mask) - end - - defp mask(acc, <>, mask) do - mask(<>)>>, rest, mask) - end - - for size <- [24, 16, 8] do - defp mask(acc, <>, mask) do - <> = <> - <>)>> - end - end - - defp mask(acc, <<>>, _mask) do - acc - end end diff --git a/lib/bandit/websocket/handler.ex b/lib/bandit/websocket/handler.ex index 3d69c0fd..88e39c42 100644 --- a/lib/bandit/websocket/handler.ex +++ b/lib/bandit/websocket/handler.ex @@ -17,10 +17,13 @@ defmodule Bandit.WebSocket.Handler do connection_opts = Keyword.merge(state.opts.websocket, connection_opts) + primitive_ops_module = + Keyword.get(state.opts.websocket, :primitive_ops_module, Bandit.PrimitiveOps.WebSocket) + state = state |> Map.take([:handler_module]) - |> Map.put(:extractor, Extractor.new(Frame, connection_opts)) + |> Map.put(:extractor, Extractor.new(Frame, primitive_ops_module, connection_opts)) case Connection.init(websock, websock_opts, connection_opts, socket) do {:continue, connection} -> diff --git a/test/bandit/websocket/frame_deserialization_test.exs b/test/bandit/websocket/frame_deserialization_test.exs index e2be97a5..56b518de 100644 --- a/test/bandit/websocket/frame_deserialization_test.exs +++ b/test/bandit/websocket/frame_deserialization_test.exs @@ -1,100 +1,103 @@ defmodule WebSocketFrameDeserializationTest do use ExUnit.Case, async: true - import Bandit.WebSocket.Frame, only: [mask: 2] + import Bandit.PrimitiveOps.WebSocket, only: [ws_mask: 2] + alias Bandit.PrimitiveOps.WebSocket, as: WebSocketPrimitiveOps alias Bandit.WebSocket.Frame describe "reserved flag parsing" do test "errors on reserved flag 1 being set" do frame = <<0x1::1, 0x1::3, 0x1::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:error, "Received unsupported RSV flags 1"} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == + {:error, "Received unsupported RSV flags 1"} end test "errors on reserved flag 2 being set" do frame = <<0x1::1, 0x2::3, 0x1::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:error, "Received unsupported RSV flags 2"} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == + {:error, "Received unsupported RSV flags 2"} end end describe "frame size" do test "parses 2 byte frames" do payload = String.duplicate("a", 2) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 2::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses 10 byte frames" do payload = String.duplicate("a", 10) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 10::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames up to 125 bytes" do payload = String.duplicate("a", 125) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 126 bytes long" do payload = String.duplicate("a", 126) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 126::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 127 bytes long" do payload = String.duplicate("a", 127) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 127::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 16_000 bytes long" do payload = String.duplicate("a", 16_000) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 16_000::16, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "parses frames 1_000_000 bytes long" do payload = String.duplicate("a", 1_000_000) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 127::7, 1_000_000::64, 1234::32, masked_payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: payload}} end test "errors on frames over max_frame_size bytes with small frames" do payload = String.duplicate("a", 125) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 1234::32, masked_payload::binary>> @@ -104,7 +107,7 @@ defmodule WebSocketFrameDeserializationTest do test "errors on frames over max_frame_size bytes with medium frames" do payload = String.duplicate("a", 16_000) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 126::7, 16_000::16, 1234::32, masked_payload::binary>> @@ -115,7 +118,7 @@ defmodule WebSocketFrameDeserializationTest do test "errors on frames over max_frame_size bytes with large frames" do payload = String.duplicate("a", 1_000_000) - masked_payload = Frame.mask(payload, 1234) + masked_payload = ws_mask(payload, 1234) frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 127::7, 1_000_000::64, 1234::32, masked_payload::binary>> @@ -129,7 +132,7 @@ defmodule WebSocketFrameDeserializationTest do test "returns error" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 125::7, 0::32, 1, 2, 3>> - assert Frame.deserialize(frame) == {:error, :deserialization_failed} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, :deserialization_failed} end end @@ -137,7 +140,7 @@ defmodule WebSocketFrameDeserializationTest do test "returns error" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 1::7, 0::32, 1, 2, 3>> - assert Frame.deserialize(frame) == {:error, :deserialization_failed} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, :deserialization_failed} end end @@ -145,7 +148,7 @@ defmodule WebSocketFrameDeserializationTest do test "returns an Unknown frame" do frame = <<0x1::1, 0x0::3, 0xF::4, 1::1, 1::7, 0::32, 1>> - assert Frame.deserialize(frame) == {:error, "unknown opcode #{15}"} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "unknown opcode #{15}"} end end @@ -153,27 +156,27 @@ defmodule WebSocketFrameDeserializationTest do test "deserializes frames with fin bit set" do frame = <<0x1::1, 0x0::3, 0x0::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Continuation{fin: true, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x0::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Continuation{fin: false, data: <<1, 2, 3, 4, 5>>}} end test "refuses frame with per-message compressed bit set" do frame = <<0x0::1, 0x4::3, 0x0::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a compressed continuation frame (RFC7692§6.1)"} end end @@ -182,27 +185,27 @@ defmodule WebSocketFrameDeserializationTest do test "deserializes frames with fin and per-message compressed bits clear" do frame = <<0x0::1, 0x0::3, 0x1::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with fin bit set" do frame = <<0x1::1, 0x0::3, 0x1::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with per-message compressed bit set" do frame = <<0x0::1, 0x4::3, 0x1::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Text{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}} end end @@ -211,27 +214,27 @@ defmodule WebSocketFrameDeserializationTest do test "deserializes frames with fin and per-message compressed bits clear" do frame = <<0x0::1, 0x0::3, 0x2::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Binary{fin: false, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with fin bit set" do frame = <<0x1::1, 0x0::3, 0x2::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Binary{fin: true, compressed: false, data: <<1, 2, 3, 4, 5>>}} end test "deserializes frames with per-message compressed bit set" do frame = <<0x0::1, 0x4::3, 0x2::4, 1::1, 5::7, 0x01020304::32, - mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> + ws_mask(<<1, 2, 3, 4, 5>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Binary{fin: false, compressed: true, data: <<1, 2, 3, 4, 5>>}} end end @@ -242,30 +245,31 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 125::7, 0x01020304::32, - mask(<<1000::16, payload::binary>>, 0x01020304)::binary>> + ws_mask(<<1000::16, payload::binary>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.ConnectionClose{code: 1000, reason: payload}} end test "deserializes frames with code" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 2::7, 0x01020304::32, - mask(<<1000::16>>, 0x01020304)::binary>> + ws_mask(<<1000::16>>, 0x01020304)::binary>> - assert Frame.deserialize(frame) == {:ok, %Frame.ConnectionClose{code: 1000}} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == + {:ok, %Frame.ConnectionClose{code: 1000}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:ok, %Frame.ConnectionClose{}} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.ConnectionClose{}} end test "refuses frame with invalid payload" do frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 1::7, 0x01020304::32, 1>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Invalid connection close payload (RFC6455§5.5)"} end @@ -273,21 +277,21 @@ defmodule WebSocketFrameDeserializationTest do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0x8::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Invalid connection close payload (RFC6455§5.5)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a fragmented connection close frame (RFC6455§5.5)"} end test "refuses frame with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0x8::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a compressed connection close frame (RFC7692§6.1)"} end end @@ -298,37 +302,37 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 125::7, 0x01020304::32, - mask(payload, 0x01020304)::binary>> + ws_mask(payload, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Ping{data: payload}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:ok, %Frame.Ping{}} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Ping{}} end test "refuses frame with overly large payload" do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0x9::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Invalid ping payload (RFC6455§5.5.2)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a fragmented ping frame (RFC6455§5.5.2)"} end test "refuses frames with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0x9::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a compressed ping frame (RFC7692§6.1)"} end end @@ -339,37 +343,37 @@ defmodule WebSocketFrameDeserializationTest do frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 125::7, 0x01020304::32, - mask(payload, 0x01020304)::binary>> + ws_mask(payload, 0x01020304)::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Pong{data: payload}} end test "deserializes frames with no payload" do frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == {:ok, %Frame.Pong{}} + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:ok, %Frame.Pong{}} end test "refuses frame with overly large payload" do payload = String.duplicate("a", 126) frame = <<0x1::1, 0x0::3, 0xA::4, 1::1, 126::7, 126::16, 0x01020304::32, payload::binary>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Invalid pong payload (RFC6455§5.5.3)"} end test "refuses frames with fin bit clear" do frame = <<0x0::1, 0x0::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a fragmented pong frame (RFC6455§5.5.3)"} end test "refuses frames with per-message compressed bit set" do frame = <<0x1::1, 0x4::3, 0xA::4, 1::1, 0::7, 0x01020304::32>> - assert Frame.deserialize(frame) == + assert Frame.deserialize(frame, WebSocketPrimitiveOps) == {:error, "Cannot have a compressed pong frame (RFC7692§6.1)"} end end diff --git a/test/support/simple_websocket_client.ex b/test/support/simple_websocket_client.ex index 07938dbe..c15c23bb 100644 --- a/test/support/simple_websocket_client.ex +++ b/test/support/simple_websocket_client.ex @@ -1,7 +1,7 @@ defmodule SimpleWebSocketClient do @moduledoc false - alias Bandit.WebSocket.Frame + alias Bandit.PrimitiveOps.WebSocket, as: WebSocketPrimitiveOps defdelegate tcp_client(context), to: Transport @@ -140,7 +140,7 @@ defmodule SimpleWebSocketClient do defp send_frame(client, flags, opcode, data) do mask = :rand.uniform(1_000_000) - masked_data = Frame.mask(data, mask) + masked_data = WebSocketPrimitiveOps.ws_mask(data, mask) mask_flag_and_size = case byte_size(masked_data) do