|
| 1 | +""" |
| 2 | +Supports encoding and decoding vectors using the same approach as the vec:base64-encode and vec:base64-decode |
| 3 | +functions supported by the MarkLogic server. |
| 4 | +""" |
| 5 | + |
1 | 6 | import base64 |
2 | 7 | import struct |
3 | 8 | from typing import List |
4 | 9 |
|
5 | 10 |
|
6 | | -class VectorUtil: |
def base64_encode(vector: List[float]) -> str:
    """
    Encodes a list of floats as a base64 string compatible with MarkLogic's vec:base64-encode.

    The packed buffer is little-endian: an int32 version (always 0), an int32 element
    count, then one 32-bit float per element.

    :param vector: the float values to encode; an empty list yields a header-only buffer.
    :return: the base64 text of the packed buffer, as an ASCII string.
    :raises struct.error: if an element is not a number.
    :raises OverflowError: if a value does not fit in a 32-bit float.
    """
    dimensions = len(vector)
    # A repeat count ("3f") packs the same bytes as "fff" without building a
    # format string that grows with the vector; one pack call also avoids
    # concatenating two intermediate buffers.
    buffer = struct.pack(f"<ii{dimensions}f", 0, dimensions, *vector)
    return base64.b64encode(buffer).decode("ascii")
11 | 21 |
|
12 | | - @staticmethod |
13 | | - def base64_encode(vector: List[float]) -> str: |
14 | | - """ |
15 | | - Encodes a list of floats as a base64 string compatible with MarkLogic's vec:base64-encode. |
16 | | - """ |
17 | | - dimensions = len(vector) |
18 | | - # version (int32, 0) + dimensions (int32) + floats (little-endian) |
19 | | - buffer = struct.pack("<ii", 0, dimensions) + struct.pack( |
20 | | - "<" + "f" * dimensions, *vector |
21 | | - ) |
22 | | - return base64.b64encode(buffer).decode("ascii") |
23 | 22 |
|
24 | | - @staticmethod |
25 | | - def base64_decode(encoded_vector: str) -> List[float]: |
26 | | - """ |
27 | | - Decodes a base64 string to a list of floats compatible with MarkLogic's vec:base64-decode. |
28 | | - """ |
29 | | - buffer = base64.b64decode(encoded_vector) |
30 | | - if len(buffer) < 8: |
31 | | - raise ValueError( |
32 | | - "Buffer is too short to contain version and dimensions." |
33 | | - ) |
34 | | - version, dimensions = struct.unpack("<ii", buffer[:8]) |
35 | | - if version != 0: |
36 | | - raise ValueError(f"Unsupported vector version: {version}") |
37 | | - expected_length = 8 + 4 * dimensions |
38 | | - if len(buffer) < expected_length: |
39 | | - raise ValueError( |
40 | | - f"Buffer is too short for the specified dimensions: expected {expected_length}, got {len(buffer)}" |
41 | | - ) |
42 | | - floats = struct.unpack( |
43 | | - "<" + "f" * dimensions, buffer[8 : 8 + 4 * dimensions] |
def base64_decode(encoded_vector: str) -> List[float]:
    """
    Decodes a base64 string to a list of floats compatible with MarkLogic's vec:base64-decode.

    The decoded buffer must start with a little-endian int32 version (only 0 is
    accepted) and an int32 element count, followed by at least that many 32-bit
    floats. Bytes beyond the declared element count are ignored.

    :param encoded_vector: base64 text of a packed vector buffer.
    :return: the decoded float values, in order.
    :raises ValueError: if the buffer is shorter than the 8-byte header, the version
        is not 0, the element count is negative, or the buffer holds fewer floats
        than the header declares.
    """
    buffer = base64.b64decode(encoded_vector)
    if len(buffer) < 8:
        raise ValueError(
            "Buffer is too short to contain version and dimensions."
        )
    version, dimensions = struct.unpack_from("<ii", buffer)
    if version != 0:
        raise ValueError(f"Unsupported vector version: {version}")
    # The count is read as a signed int32; without this check a negative value
    # would pass the length test below and silently decode to an empty list.
    if dimensions < 0:
        raise ValueError(f"Invalid vector dimensions: {dimensions}")
    expected_length = 8 + 4 * dimensions
    if len(buffer) < expected_length:
        raise ValueError(
            f"Buffer is too short for the specified dimensions: expected {expected_length}, got {len(buffer)}"
        )
    # Read the floats in place, starting just past the 8-byte header.
    return list(struct.unpack_from(f"<{dimensions}f", buffer, 8))
0 commit comments