diff --git a/.devops/openvino.Dockerfile b/.devops/openvino.Dockerfile new file mode 100644 index 0000000000000..16924e3937c90 --- /dev/null +++ b/.devops/openvino.Dockerfile @@ -0,0 +1,134 @@ +ARG OPENVINO_VERSION_MAJOR=2025.2 +ARG OPENVINO_VERSION_FULL=2025.2.0.19140.c01cd93e24d +ARG UBUNTU_VERSION=24.04 + +# Optional proxy build arguments - empty by default +ARG http_proxy= +ARG https_proxy= + +## Build Image +FROM ubuntu:${UBUNTU_VERSION} AS build + +# Pass proxy args to build stage +ARG http_proxy +ARG https_proxy + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + ca-certificates \ + gnupg \ + wget \ + git \ + cmake \ + ninja-build \ + build-essential \ + libtbb12 \ + libcurl4-openssl-dev && \ + rm -rf /var/lib/apt/lists/* + +# Install OpenVINO for Ubuntu 24.04 +ARG OPENVINO_VERSION_MAJOR +ARG OPENVINO_VERSION_FULL +RUN mkdir -p /opt/intel && \ + wget https://storage.openvinotoolkit.org/repositories/openvino/packages/${OPENVINO_VERSION_MAJOR}/linux/openvino_toolkit_ubuntu24_${OPENVINO_VERSION_FULL}_x86_64.tgz && \ + tar -xf openvino_toolkit_ubuntu24_${OPENVINO_VERSION_FULL}_x86_64.tgz && \ + mv openvino_toolkit_ubuntu24_${OPENVINO_VERSION_FULL}_x86_64 /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} && \ + cd /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} && \ + echo "Y" | ./install_dependencies/install_openvino_dependencies.sh && \ + cd - && \ + ln -s /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} /opt/intel/openvino + +ENV OpenVINO_DIR=/opt/intel/openvino + +WORKDIR /app + +COPY . . + +# Build Stage +RUN bash -c "source ${OpenVINO_DIR}/setupvars.sh && \ + cmake -B build/ReleaseOV -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENVINO=ON && \ + cmake --build build/ReleaseOV -j$(nproc)" + +# Copy all necessary libraries +RUN mkdir -p /app/lib && \ + find build/ReleaseOV -name '*.so*' -exec cp {} /app/lib \; && \ + find ${OpenVINO_DIR}/runtime/lib/intel64 -name '*.so*' -exec cp -P {} /app/lib \; 2>/dev/null || \ + find ${OpenVINO_DIR}/lib/intel64 -name '*.so*' -exec cp -P {} /app/lib \; + +# Create runtime directories and copy binaries +RUN mkdir -p /app/full \ + && cp build/ReleaseOV/bin/* /app/full/ \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base Runtime Image +FROM ubuntu:${UBUNTU_VERSION} AS base + +# Pass proxy args to runtime stage +ARG http_proxy +ARG https_proxy + +RUN apt-get update \ + && apt-get install -y libgomp1 libtbb12 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app/ + +### Full (all binaries) +FROM base AS full + +ARG http_proxy +ARG https_proxy + +COPY --from=build /app/full /app/ + +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + git \ + python3 \ + python3-venv \ + python3-pip && \ + python3 -m venv /ov-venv && \ + /ov-venv/bin/pip install --no-cache-dir --upgrade pip setuptools wheel && \ + /ov-venv/bin/pip install --no-cache-dir -r requirements.txt && \ + apt-get autoremove -y && \ + apt-get clean && \ + rm -rf /tmp/* /var/tmp/* && \ + find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete && \ + find /var/cache -type f -delete + +ENTRYPOINT ["/bin/bash", "-c", "source /ov-venv/bin/activate && exec /app/tools.sh \"$@\"", "--"] + + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app/ + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app/ + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 424b4ba786610..7892591dd2644 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -629,6 +629,45 @@ jobs: -DGGML_SYCL_F16=ON cmake --build build --config Release -j $(nproc) + ubuntu-24-cmake-openvino: + runs-on: ubuntu-24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-24-cmake-openvino-no-preset-v1 + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + export OPENVINO_VERSION_MAJOR=2025.2 + export OPENVINO_VERSION_FULL=2025.2.0.19140.c01cd93e24d + sudo apt-get update + sudo apt-get install -y build-essential libcurl4-openssl-dev libtbb12 cmake ninja-build python3-pip curl wget tar + sudo mkdir -p /opt/intel + wget -O openvino_${OPENVINO_VERSION_MAJOR}.tgz https://storage.openvinotoolkit.org/repositories/openvino/packages/${OPENVINO_VERSION_MAJOR}/linux/openvino_toolkit_ubuntu24_${OPENVINO_VERSION_FULL}_x86_64.tgz + tar -xf openvino_${OPENVINO_VERSION_MAJOR}.tgz + sudo mv openvino_toolkit_ubuntu24_${OPENVINO_VERSION_FULL}_x86_64 /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} + rm openvino_${OPENVINO_VERSION_MAJOR}.tgz + cd /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} + echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh && cd - + sudo ln -s /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} /opt/intel/openvino + + - name: Build + id: cmake_build + run: | + source /opt/intel/openvino/setupvars.sh + cmake -B build/ReleaseOV -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENVINO=ON + cmake --build build/ReleaseOV --config Release -j $(nproc) + build-linux-cross: uses: ./.github/workflows/build-linux-cross.yml diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index bf2c8509ec14e..410562812671b 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -45,6 +45,7 @@ jobs: - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true } - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } - { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false } + - { tag: "openvino", dockerfile: ".devops/openvino.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } # Note: the rocm images are failing due to a compiler error and are disabled until this is fixed to allow the workflow to complete #- {tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: true } steps: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f461456edf008..93d8e5e6d8dba 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -241,6 +241,63 @@ jobs: path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip name: llama-bin-ubuntu-vulkan-x64.zip + ubuntu-24-openvino: + runs-on: ubuntu-24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-24-cmake-openvino-release-no-preset-v1 + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + export OPENVINO_VERSION_MAJOR=2025.2 + export OPENVINO_VERSION_FULL=2025.2.0.19140.c01cd93e24d + sudo apt-get update + sudo apt-get install -y build-essential libcurl4-openssl-dev libtbb12 cmake ninja-build python3-pip curl wget tar + sudo mkdir -p /opt/intel + wget -O openvino_${OPENVINO_VERSION_MAJOR}.tgz https://storage.openvinotoolkit.org/repositories/openvino/packages/${OPENVINO_VERSION_MAJOR}/linux/openvino_toolkit_ubuntu24_${OPENVINO_VERSION_FULL}_x86_64.tgz + tar -xf openvino_${OPENVINO_VERSION_MAJOR}.tgz + sudo mv openvino_toolkit_ubuntu24_${OPENVINO_VERSION_FULL}_x86_64 /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} + rm openvino_${OPENVINO_VERSION_MAJOR}.tgz + cd /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} + echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh && cd - + sudo ln -s /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} /opt/intel/openvino + + - name: Build + id: cmake_build + run: | + source /opt/intel/openvino/setupvars.sh + cmake -B build/ReleaseOV -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENVINO=ON + cmake --build build/ReleaseOV --config Release -j $(nproc) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/ReleaseOV/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-x64.zip ./build/ReleaseOV/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-x64.zip + name: llama-bin-ubuntu-openvino-x64.zip + windows-cpu: runs-on: windows-2025 diff --git a/IR.xml b/IR.xml new file mode 100644 index 0000000000000..f5b1df8740a66 --- /dev/null +++ b/IR.xml @@ -0,0 +1,462 @@ + + + + + + + + 2 + 128 + 64 + + + + + + + + 1 + 1 + 32 + + + + + + + + 1 + 1 + 2 + + + + + + + + + + + + + + 2 + 128 + 64 + + + + + + 2 + 128 + 32 + + + 2 + 128 + 32 + + + + + + + + 1 + 1 + 32 + + + + + + + + 1 + 1 + 32 + + + 1 + 1 + 32 + + + + + 1 + 1 + 32 + + + + + + + + 1 + 1 + 2 + + + + + 1 + 1 + 2 + + + + + + + + 3 + + + + + + + 1 + 1 + 2 + + + 3 + + + + + 2 + 1 + 1 + + + + + + + + 1 + 1 + 32 + + + 2 + 1 + 1 + + + + + 2 + 1 + 32 + + + + + + + + 1 + + + + + + + + 2 + 1 + 32 + + + 1 + + + + + 2 + 1 + 32 + + + + + + + 2 + 1 + 32 + + + + + 2 + 1 + 32 + + + + + + + + + + + + + + 2 + 1 + 32 + + + + + + 2 + 1 + 32 + + + + + + + + 2 + 128 + 32 + + + 2 + 1 + 32 + + + + + 2 + 128 + 32 + + + + + + + 2 + 1 + 32 + + + + + 2 + 1 + 32 + + + + + + + + 2 + 1 + 32 + + + + + + 2 + 1 + 32 + + + + + + + + 2 + 128 + 32 + + + 2 + 1 + 32 + + + + + 2 + 128 + 32 + + + + + + + + 2 + 128 + 32 + + + 2 + 128 + 32 + + + + + 2 + 128 + 32 + + + + + + + + 2 + 128 + 32 + + + 2 + 1 + 32 + + + + + 2 + 128 + 32 + + + + + + + + 2 + 128 + 32 + + + 2 + 1 + 32 + + + + + 2 + 128 + 32 + + + + + + + + 2 + 128 + 32 + + + 2 + 128 + 32 + + + + + 2 + 128 + 32 + + + + + + + + 2 + 128 + 32 + + + 2 + 128 + 32 + + + + + 2 + 128 + 64 + + + + + + + 2 + 128 + 64 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ci/run.sh b/ci/run.sh index 68cbfdf2f52aa..d7b7c27ee4d84 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -22,6 +22,9 @@ # # with MUSA support # GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt # +# # with OPENVINO support +# GG_BUILD_OPENVINO=1 GG_BUILD_LOW_PERF=1 GGML_OPENVINO_DEVICE=CPU bash ./ci/run.sh ./tmp/results ./tmp/mnt +# if [ -z "$2" ]; then echo "usage: $0 " @@ -114,6 +117,15 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then # arm 9 and newer enables sve by default, adjust these flags depending on the cpu used CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm" fi + +if [ ! -z ${GG_BUILD_OPENVINO} ]; then + if [ -z ${OpenVINO_DIR} ]; then + echo "OpenVINO_DIR not found, please install OpenVINO via archives and enable it by:" + echo "source /opt/intel/openvino/setupvars.sh" + exit 1 + fi + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF" +fi ## helpers # download a file if it does not exist or if it is outdated diff --git a/docs/build.md b/docs/build.md index dcbcce7549ad2..e2ef8b4e08b5b 100644 --- a/docs/build.md +++ b/docs/build.md @@ -13,6 +13,21 @@ cd llama.cpp The following sections describe how to build with different backends and options. +* [CPU Build](#cpu-build) +* [BLAS Build](#blas-build) +* [Metal Build](#metal-build) +* [SYCL](#sycl) +* [CUDA](#cuda) +* [MUSA](#musa) +* [HIP](#hip) +* [Vulkan](#vulkan) +* [CANN](#cann) +* [Arm® KleidiAI™](#arm-kleidiai) +* [OpenCL](#opencl) +* [Android](#android-1) +* [OpenVINO](#openvino) +* [Notes about GPU-accelerated backends](#notes-about-gpu-accelerated-backends) + ## CPU Build Build llama.cpp using `CMake`: @@ -575,6 +590,127 @@ Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/m To read documentation for how to build on IBM Z & LinuxONE, [click here](./build-s390x.md) +## OpenVINO + +[OpenVINO](https://docs.openvino.ai/2025/index.html) is an open-source toolkit for optimizing and deploying high-performance AI inference, specifically designed for Intel hardware, including CPUs, GPUs, and NPUs, in the cloud, on-premises, and on the edge. +The OpenVINO backend enhances performance by leveraging hardware-specific optimizations and can be enabled for use with llama.cpp. + +Follow the instructions below to install OpenVINO runtime and build llama.cpp with OpenVINO support. + +### Prerequisites + +- Linux or Windows system with Intel hardware (CPU, GPU, or NPU) +- **For Intel GPU or NPU Usage**: Install the appropriate hardware drivers for your Intel GPU or NPU. For detailed instructions, see: [Additional Configurations for Hardware Acceleration](https://docs.openvino.ai/2025/get-started/install-openvino/configurations.html). +- Git, CMake, and Ninja software tools are needed for building. +```bash + sudo apt-get update + sudo apt-get install -y build-essential libcurl4-openssl-dev libtbb12 cmake ninja-build python3-pip curl wget tar +``` + +### 1. Install OpenVINO Runtime + +- Follow the guide to install OpenVINO Runtime from an archive file: [Linux](https://docs.openvino.ai/2025/get-started/install-openvino/install-openvino-archive-linux.html) | [Windows](https://docs.openvino.ai/2025/get-started/install-openvino/install-openvino-archive-windows.html) + +
+📦 Click to expand OpenVINO 2025.2 installation commands on Linux +
+ +```bash +export OPENVINO_VERSION_MAJOR=2025.2 +export OPENVINO_VERSION_FULL=2025.2.0.19140.c01cd93e24d +sudo apt-get update +sudo apt-get install -y build-essential libcurl4-openssl-dev libtbb12 cmake ninja-build python3-pip curl wget tar +sudo mkdir -p /opt/intel +wget -O openvino_${OPENVINO_VERSION_MAJOR}.tgz https://storage.openvinotoolkit.org/repositories/openvino/packages/${OPENVINO_VERSION_MAJOR}/linux/openvino_toolkit_ubuntu24_${OPENVINO_VERSION_FULL}_x86_64.tgz +tar -xf openvino_${OPENVINO_VERSION_MAJOR}.tgz +sudo mv openvino_toolkit_ubuntu24_${OPENVINO_VERSION_FULL}_x86_64 /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} +rm openvino_${OPENVINO_VERSION_MAJOR}.tgz +cd /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} +echo "Y" | sudo -E ./install_dependencies/install_openvino_dependencies.sh && cd - +sudo ln -s /opt/intel/openvino_${OPENVINO_VERSION_MAJOR} /opt/intel/openvino +source /opt/intel/openvino/setupvars.sh +``` +
+ +- Verify OpenVINO is initialized properly +```bash +echo $OpenVINO_DIR +``` + +### 2. Build llama.cpp with OpenVINO Backend + +Clone the OpenVINO-enabled llama.cpp fork and build it: + +```bash +git clone https://github.com/ravi9/llama.cpp.git +cd llama.cpp +git switch dev_backend_openvino + +# Build with OpenVINO support +source /opt/intel/openvino/setupvars.sh +cmake -B build/ReleaseOV -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_OPENVINO=ON -DGGML_CPU_REPACK=OFF +cmake --build build/ReleaseOV --config Release -j $(nproc) +``` + +### 3. Download Sample Model + +Download models for testing: + +```bash +# Create models directory +mkdir -p ~/models/ + +# Download model file: Llama-3.2-1B-Instruct.fp16.gguf +wget https://huggingface.co/MaziyarPanahi/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct.fp16.gguf \ + -O ~/models/Llama-3.2-1B-Instruct.fp16.gguf + +# Download model file: Phi-3-mini-4k-instruct-fp16.gguf +wget https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-fp16.gguf \ + -O ~/models/Phi-3-mini-4k-instruct-fp16.gguf +``` + +### 4. Run inference with OpenVINO backend: + +When using the OpenVINO backend, the first inference token may have slightly higher latency due to on-the-fly conversion to the OpenVINO graph. Subsequent tokens and runs will be faster. + +```bash +export GGML_OPENVINO_CACHE_DIR=/tmp/ov_cache +# Default device is GPU. +# If not set, automatically selects the first available device in priority order: GPU, CPU, NPU. +export GGML_OPENVINO_DEVICE=GPU + +./build/ReleaseOV/bin/llama-simple -m ~/models/Llama-3.2-1B-Instruct.fp16.gguf -n 50 "The story of AI is " + +``` + +To run in chat mode: +```bash +export GGML_OPENVINO_CACHE_DIR=/tmp/ov_cache +./build/ReleaseOV/bin/llama-cli -m ~/models/Llama-3.2-1B-Instruct.fp16.gguf -n 50 "The story of AI is " + +``` + +### Configuration Options + +Control OpenVINO behavior using these environment variables: + +- **`GGML_OPENVINO_DEVICE`**: Specify the target device for OpenVINO inference. If not set, automatically selects the first available device in priority order: GPU, CPU, NPU. When set to `NPU` to use Intel NPUs, it enables static compilation mode for optimal performance. +- **`GGML_OPENVINO_CACHE_DIR`**: Directory for model caching (recommended: `/tmp/ov_cache`). If set, enables model caching in OpenVINO. Note: Not supported when using NPU devices yet. +- **`GGML_OPENVINO_PROFILING`**: Enable execution time profiling. +- **`GGML_OPENVINO_DUMP_CGRAPH`**: Save compute graph to `cgraph.txt`. +- **`GGML_OPENVINO_DUMP_IR`**: Export OpenVINO IR files with timestamps. +- **`GGML_OPENVINO_DEBUG_INPUT`**: Enable input debugging. +- **`GGML_OPENVINO_DEBUG_OUTPUT`**: Enable output debugging. + +### Example with Profiling + +```bash +export GGML_OPENVINO_CACHE_DIR=/tmp/ov_cache +export GGML_OPENVINO_PROFILING=1 + +./build/ReleaseOV/bin/llama-simple -m ~/models/Llama-3.2-1B-Instruct.fp16.gguf -n 50 "The story of AI is " +``` + ## Notes about GPU-accelerated backends The GPU may still be used to accelerate some parts of the computation even when using the `-ngl 0` option. You can fully disable GPU acceleration by using `--device none`. diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 1a0fdb676c449..5c29df642cfb0 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -243,6 +243,8 @@ set (GGML_SYCL_TARGET "INTEL" CACHE STRING set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING "ggml: sycl device architecture") +option(GGML_OPENVINO "ggml: use OPENVINO" OFF) + option(GGML_OPENCL "ggml: use OpenCL" OFF) option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF) option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) @@ -314,6 +316,7 @@ set(GGML_PUBLIC_HEADERS include/ggml-sycl.h include/ggml-vulkan.h include/ggml-webgpu.h + include/ggml-openvino.h include/gguf.h) set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") diff --git a/ggml/include/ggml-openvino.h b/ggml/include/ggml-openvino.h new file mode 100644 index 0000000000000..151c48d40d067 --- /dev/null +++ b/ggml/include/ggml-openvino.h @@ -0,0 +1,63 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_OPENVINO_NAME "OPENVINO" +#define GGML_OPENVINO_MAX_DEVICES 16 + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device); + +GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend); + +// device buffer +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device); + +// split tensor buffer that splits matrices by rows across multiple devices +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_split_buffer_type(const float * tensor_split); + +// pinned host buffer for use with the CPU backend for faster copies between CPU +// and GPU +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(void); + +GGML_BACKEND_API int ggml_backend_openvino_get_device_count(void); +// GGML_BACKEND_API void ggml_backend_openvino_get_device_description(int device, char * description, +// size_t description_size); +// GGML_BACKEND_API void ggml_backend_openvino_get_device_memory(int device, size_t * free, size_t * total); + +// GGML_BACKEND_API bool ggml_backend_openvino_register_host_buffer(void * buffer, size_t size); +// GGML_BACKEND_API void ggml_backend_openvino_unregister_host_buffer(void * buffer); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void); + +struct ggml_openvino_device_info { + int device_count; + + struct openvino_device_info { + int cc; // compute capability + int nsm; // number of streaming multiprocessors + size_t smpb; // max. shared memory per block + size_t smpbo; // max. shared memory per block (with opt-in) + bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory + size_t total_vram; + }; + + openvino_device_info devices[GGML_OPENVINO_MAX_DEVICES] = {}; + + std::array default_tensor_split = {}; +}; + +const ggml_openvino_device_info & ggml_openvino_info(); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index c8f3d8596427c..ca9e24313e996 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -387,6 +387,7 @@ ggml_add_backend(Vulkan) ggml_add_backend(WebGPU) ggml_add_backend(zDNN) ggml_add_backend(OpenCL) +ggml_add_backend(OPENVINO) foreach (target ggml-base ggml) target_include_directories(${target} PUBLIC $ $) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 7002cb07e0015..3d048cac3ffe5 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -69,6 +69,10 @@ #include "ggml-cann.h" #endif +#ifdef GGML_USE_OPENVINO +#include "ggml-openvino.h" +#endif + // disable C++17 deprecation warning for std::codecvt_utf8 #if defined(__clang__) # pragma clang diagnostic push @@ -199,6 +203,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_RPC register_backend(ggml_backend_rpc_reg()); #endif +#ifdef GGML_USE_OPENVINO + register_backend(ggml_backend_openvino_reg()); +#endif #ifdef GGML_USE_CPU register_backend(ggml_backend_cpu_reg()); #endif @@ -590,6 +597,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("vulkan", silent, dir_path); ggml_backend_load_best("opencl", silent, dir_path); ggml_backend_load_best("musa", silent, dir_path); + ggml_backend_load_best("openvino", silent, dir_path); ggml_backend_load_best("cpu", silent, dir_path); // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend const char * backend_path = std::getenv("GGML_BACKEND_PATH"); diff --git a/ggml/src/ggml-openvino/.clang-format b/ggml/src/ggml-openvino/.clang-format new file mode 100644 index 0000000000000..63dc2c472a95d --- /dev/null +++ b/ggml/src/ggml-openvino/.clang-format @@ -0,0 +1,143 @@ +--- +# Override root .clang-format +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +ReferenceAlignment: Left +PointerAlignment: Left +Cpp11BracedListStyle: true +AccessModifierOffset: -4 +BinPackArguments: false +BreakBeforeBraces: Attach +IndentCaseBlocks: false +IndentCaseLabels: false + +Language: Cpp +AlignAfterOpenBracket: Align +AlignArrayOfStructures: Left +AlignConsecutiveBitFields: AcrossComments +AlignConsecutiveMacros: AcrossComments +# AlignConsecutiveShortCaseStatements: AcrossComments +AlignEscapedNewlines: Left # LeftWithLastLine +AlignOperands: Align +AlignTrailingComments: + Kind: Always + OverEmptyLines: 1 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: Inline +AllowShortLoopsOnASingleLine: false +AlwaysBreakBeforeMultilineStrings: true +BinPackParameters: true +BitFieldColonSpacing: Both +# BreakAdjacentStringLiterals: true +BreakAfterAttributes: Never +BreakBeforeBinaryOperators: None +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: false +# BreakBinaryOperations: Never +BreakConstructorInitializers: AfterColon +# BreakFunctionDefinitionParameters: false +BreakInheritanceList: AfterComma +BreakStringLiterals: true +# BreakTemplateDeclarations: Yes +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +DerivePointerAlignment: false +DisableFormat: false +EmptyLineBeforeAccessModifier: Leave +EmptyLineAfterAccessModifier: Never +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +IncludeBlocks: Regroup +IncludeCategories: + - Regex: '^<.*\.h>' + Priority: 1 + SortPriority: 0 + - Regex: '^<.*' + Priority: 2 + SortPriority: 0 + - Regex: '.*' + Priority: 3 + SortPriority: 0 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentExternBlock: NoIndent +IndentGotoLabels: false +IndentPPDirectives: AfterHash +IndentWidth: 4 +IndentWrappedFunctionNames: false +InsertBraces: true # NOTE: may lead to incorrect formatting +InsertNewlineAtEOF: true +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +LambdaBodyIndentation: Signature +LineEnding: LF +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 4 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: true +PPIndentWidth: -1 +PackConstructorInitializers: CurrentLine +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +QualifierAlignment: Left +#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict'] +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' +ReflowComments: false # IndentOnly +SeparateDefinitionBlocks: Always +SortIncludes: CaseInsensitive +SortUsingDeclarations: LexicographicNumeric +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: Never +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParentheses: false +SpacesInSquareBrackets: false +SpaceBeforeSquareBrackets: false +Standard: c++17 +TabWidth: 4 +UseTab: Never +WhitespaceSensitiveMacros: ['STRINGIZE'] +... diff --git a/ggml/src/ggml-openvino/CMakeLists.txt b/ggml/src/ggml-openvino/CMakeLists.txt new file mode 100644 index 0000000000000..216aa756a7a96 --- /dev/null +++ b/ggml/src/ggml-openvino/CMakeLists.txt @@ -0,0 +1,19 @@ +find_package(OpenVINO REQUIRED) + +file(GLOB_RECURSE GGML_HEADERS_OPENVINO "*.h" "*.hpp") +file(GLOB_RECURSE GGML_SOURCES_OPENVINO "*.cpp") + +ggml_add_backend_library(ggml-openvino + ${GGML_SOURCES_OPENVINO} + ${GGML_HEADERS_OPENVINO} +) + +target_link_libraries(ggml-openvino PRIVATE openvino::runtime) + +if (GGML_OPENVINO) + if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") + else() + message(FATAL_ERROR "OpenVINO: OpenVINO toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}") + endif() +endif() diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp new file mode 100644 index 0000000000000..751fa192a4261 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -0,0 +1,818 @@ +#include "ggml-decoder.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-quants.hpp" + +GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token, + int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size, + const std::vector& swa_layers) : + m_cgraph(cgraph), + m_node(node), + m_op_name(std::string(node->name)), + m_context_size(context_size), + m_context_size_swa(context_size_swa), + m_swa_layers(swa_layers), + m_num_heads(num_heads), + m_num_heads_kv(num_heads_kv), + m_head_size(head_size), + m_is_static(is_static), + m_is_first_token(is_first_token) { + set_input_output(node); +} + +GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph, + std::map>& model_weights, bool is_static, + bool is_first_token) : + m_cgraph(cgraph), + m_op_name(m_node ? std::string(m_node->name) : ""), + m_model_weights(model_weights), + m_is_static(is_static), + m_is_first_token(is_first_token) { + if (is_first_token && getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS")) { + print_tensor_address_map(cgraph); + } + + set_llm_params(); + + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + auto* cur_node = cgraph->nodes[node_n]; + m_nodes.push_back(cur_node); + set_input_output(cur_node); + } + + add_extra_inputs(); +} + +GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph, + std::map>& model_weights) { + m_cgraph = cgraph; + m_model_weights = model_weights; + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + auto* cur_node = cgraph->nodes[node_n]; + if (cur_node->op == GGML_OP_NONE) { + continue; + } + m_nodes.push_back(cur_node); + set_input_output(cur_node, true); + } +} + +// Called in GgmlOvDecoder constructor. Two cases: 1. constructing a decoder for the whole graph; +// 2. constructing a decoder for a node; +// 3. constructing a decoder for the whole graph naively (op test case) +void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) { + std::string node_name; + if (node->op == GGML_OP_SET_ROWS) { + // SET_ROWS updates the tensor in place. For later ov op that uses the + // the view_src of SET_ROWS, we need to make sure they get the updated tensor + // by putting the view_src name in the tensor_map in + // /src/frontends/ggml/src/translate_session.cpp + node_name = std::string(node->view_src->name); + } else { + node_name = std::string(node->name); + } + + m_output_names.push_back(node_name); + m_outputs[node_name] = node; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto* src = node->src[i]; + if (src == nullptr) { + continue; + } + std::string src_name = std::string(src->name); + m_input_names.push_back(src_name); + m_inputs[src_name] = src; + m_op_node_name.emplace_back(src_name, ggml_op_name(node->op)); + + // Add model inputs and weights constants, if called for the whole graph + if (naive) { + if (m_model_weights.find(src_name) == m_model_weights.end()) { + auto param_node = std::make_shared(get_ov_type(src), get_graph_input_shape(src)); + param_node->set_friendly_name(src_name); + param_node->output(0).get_tensor().set_names({src_name}); + m_model_inputs[src_name] = param_node; + } + + } else if (!m_node && !src->view_src) { + ggml_backend_buffer* buffer = src->buffer; + + if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY || src->flags & GGML_TENSOR_FLAG_INPUT) { + // GGML_BACKEND_BUFFER_USAGE_ANY are kv caches + if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) { + assert(src_name.find("cache_k") == 0 || src_name.find("cache_v") == 0); + } + if (m_model_inputs.find(src_name) != m_model_inputs.end()) { + continue; + } + auto param_node = std::make_shared(get_ov_type(src), get_graph_input_shape(src)); + param_node->set_friendly_name(src_name); + param_node->output(0).get_tensor().set_names({src_name}); + m_model_inputs[src_name] = param_node; + } + } + } + + // Add model outputs, if called for the whole graph + if (naive) { + m_model_output_names.push_back(node_name); + } else if (!m_node) { + // Model outputs are tensors with GGML_TENSOR_FLAG_OUTPUT flag and kv_caches + static std::set debug_output_names = {}; + // Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph + if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT || node_name.find("result") == 0 || + debug_output_names.count(node_name)) { + if (node->op == GGML_OP_SET_ROWS) { + assert(node_name.find("cache_k") == 0 || node_name.find("cache_v") == 0); + if (auto it = std::find(m_kv_names.begin(), m_kv_names.end(), node_name); it == m_kv_names.end()) { + m_kv_names.push_back(node_name); + } + } + if (auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), node_name); + it == m_model_output_names.end()) { + m_model_output_names.push_back(node_name); + } + } + } + + if (m_node) { + switch (node->op) { + case GGML_OP_RESHAPE: { + if (node->src[0]->op == GGML_OP_RESHAPE && node->src[0]->src[0]->ne[0] == node->ne[0] && + node->src[0]->src[0]->ne[1] == node->ne[1]) { + m_op_case = 4; + } else if (node->ne[0] * node->ne[1] == node->src[0]->ne[0]) { + m_op_case = 1; + } else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[0]) { + m_op_case = 2; + } else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[1]) { + m_op_case = 3; + } + break; + } + case GGML_OP_CONT: { + if (node->src[0]->op == GGML_OP_PERMUTE) { + m_op_case = 1; + } else if (node->src[0]->op == GGML_OP_TRANSPOSE) { + m_op_case = 2; + } else if (node->src[0]->op == GGML_OP_VIEW) { + // The input comes from a VIEW which is subtensor + m_op_case = 3; + } + break; + } + case GGML_OP_PERMUTE: { + if (node->src[0]->op != GGML_OP_VIEW) { + m_op_case = 1; + } else if (ggml_is_contiguous(node->src[0])) { + std::string src_name(node->view_src->name); + if (src_name.find("cache") == std::string::npos) { + m_op_case = 1; + } else { + // Permute kv cache (view) + int layer = extract_layer_from_name(src_name); + if (!is_swa_layer(layer)) { + m_op_case = 2; + } else { + m_op_case = 3; + } + } + } + break; + } + case GGML_OP_MUL_MAT: { + if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) { + m_op_case = 2; + } else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) { + // test-backend-ops case + m_op_case = 3; + } + break; + } + case GGML_OP_GET_ROWS: { + if (node->src[1]->op == GGML_OP_VIEW) { + m_op_case = 2; + } + break; + } + case GGML_OP_ROPE: { + if (node->src[0]->op == GGML_OP_VIEW) { + m_op_case = 2; + } + break; + } + case GGML_OP_VIEW: { + if (node->src[0]->op == GGML_OP_VIEW) { + auto* src = node->src[0]; + auto* view_src = src->view_src; + if (view_src->ne[1] != src->ne[2]) { + throw std::runtime_error("Unsupported VIEW case"); + } + m_op_case = 2; + } + } + default: + break; + } + } +} + +int extract_layer_from_name(const std::string& name) { + size_t pos1 = name.find("_l"); + assert(pos1 != std::string::npos); + pos1 += 2; + size_t pos2 = name.find(' ', pos1); + if (pos2 == std::string::npos) { + pos2 = name.length(); + } + std::string layer_str = name.substr(pos1, pos2 - pos1); + int layer = std::stoi(layer_str); + return layer; +} + +void GgmlOvDecoder::set_llm_params() { + for (int i = 0; i < m_cgraph->n_nodes; i++) { + auto* node = m_cgraph->nodes[i]; + std::string name = std::string(node->name); + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + auto* cache_k = node->src[1]; + cache_k = cache_k->view_src ? cache_k->view_src : cache_k; + int layer = extract_layer_from_name(cache_k->name); + + if (std::string(node->src[3]->name).find("swa") != std::string::npos) { + m_swa_layers.push_back(layer); + m_context_size_swa = cache_k->ne[1]; + } else { + m_context_size = cache_k->ne[1]; + } + } else if (node->op == GGML_OP_ROPE && + (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0)) { + m_head_size = node->ne[0]; + m_num_heads = node->ne[1]; + m_rope_params = node->op_params; + } else if (node->op == GGML_OP_ROPE && + (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0)) { + m_num_heads_kv = node->ne[1]; + } + } +} + +ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) const { + auto name = std::string(src->name); + ov::PartialShape input_shape; + if (name == "inp_tokens" || name == "inp_pos") { + if (m_is_static) { + if (m_is_first_token) { + input_shape = ov::PartialShape{1, 1, m_context_size}; + } else { + input_shape = ov::PartialShape{1, 1, 1}; + } + } else { + input_shape = ov::PartialShape{1, 1, -1}; + } + } else if (name == "inp_out_ids" && !m_is_static) { + input_shape = ov::PartialShape{1, 1, -1}; + } else if (name.find("KQ_mask") == 0) { + if (m_is_static) { + if (m_is_first_token) { + input_shape = ov::PartialShape{1, m_context_size, m_context_size}; + } else { + input_shape = ov::PartialShape{1, 1, m_context_size}; + } + } else { + input_shape = ov::PartialShape{1, -1, -1}; + } + } else if (name.find("cache_") == 0) { + int layer = extract_layer_from_name(name); + bool is_swa = is_swa_layer(layer); + input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size}; + } else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) { + input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1}; + } else if (src->op == GGML_OP_VIEW) { + // This case is added to make test-backend-ops work + input_shape = ov::PartialShape{get_shape(src->view_src)}; + } else { + input_shape = ov::PartialShape{get_shape(src)}; + } + return input_shape; +} + +void GgmlOvDecoder::add_extra_inputs() { + // Extra inputs: + // 1. `attention_size`, used in matmul's in the attention block. The shape of those matmul's are 32 aligned, + // see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding. + // Not used for NPU + int64_t attention_size = -1; + int64_t attention_size_swa = -1; + for (const auto& node : m_nodes) { + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + auto* mask = node->src[3]; + std::string mask_name(mask->name); + if (mask_name.find("KQ_mask") != 0) { + throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name)); + } + if (mask_name.find("swa") != std::string::npos) { + attention_size_swa = mask->ne[0]; + } else { + attention_size = mask->ne[0]; + } + } + } + + auto create_attention_size_input = [this](const std::string& name, int64_t size) { + auto param_node = std::make_shared(ov::element::i64, ov::Shape{1}); + param_node->set_friendly_name(name); + param_node->output(0).get_tensor().set_names({name}); + m_model_extra_inputs[name] = param_node; + + auto tensor = std::make_shared(ov::element::i64, ov::Shape{1}); + *tensor->data() = size; + m_model_extra_input_values[name] = tensor; + }; + + create_attention_size_input("attention_size", attention_size); + if (attention_size_swa != -1) { + create_attention_size_input("attention_size_swa", attention_size_swa); + } +} + +const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor) const { + if (tensor == nullptr) { + return nullptr; + } + for (int i = 0; i < m_cgraph->n_nodes; i++) { + const auto* node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] == tensor) { + return node; + } + } + } + return nullptr; +} + +const ggml_tensor* GgmlOvDecoder::get_tensor_from_name(const std::string& name) const { + for (int i = 0; i < m_cgraph->n_nodes; i++) { + const auto* node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + const auto* src = node->src[j]; + if (src == nullptr) { + break; + } + if (std::string(src->name) == name) { + return src; + } + } + } + return nullptr; +} + +std::map GgmlOvDecoder::get_kv_param_res_names() const { + std::map kv_param_res_names; + for (const auto& name : m_kv_names) { + if (name.find("cache_k") == 0 || name.find("cache_v") == 0) { + kv_param_res_names[name] = name; + } + } + return kv_param_res_names; +} + +std::map> GgmlOvDecoder::create_weight_nodes( + struct ggml_cgraph* cgraph, std::map types_to_requantize) { + std::map> model_weights; + static std::mutex weights_mutex; + auto* nodes = cgraph->nodes; + auto n_nodes = cgraph->n_nodes; + std::for_each(std::execution::par, nodes, nodes + n_nodes, [&](ggml_tensor* node) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto* src = node->src[i]; + if (src == nullptr) { + continue; + } + + std::string src_name(src->name); + if (!src->view_src) { + ggml_backend_buffer* buffer = src->buffer; + if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS || ggml_is_quantized(src->type)) { + bool should_create = false; + { + std::lock_guard lock(weights_mutex); + if (model_weights.find(src_name) == model_weights.end()) { + model_weights[src_name] = nullptr; + should_create = true; + } + } + if (should_create) { + auto requant_type = types_to_requantize.count(src->type) ? + std::optional(types_to_requantize.at(src->type)) : + std::nullopt; + auto weight_node = create_weight_node(src, requant_type); + weight_node->set_friendly_name(src_name); + { + std::lock_guard lock(weights_mutex); + model_weights[src_name] = weight_node; + } + } + } + } + } + }); + return model_weights; +} + +std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor* tensor, + std::optional requant_type) { + std::set weight_types = {GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_Q8_0, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + GGML_TYPE_Q4_K, + GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K}; + if (weight_types.find(tensor->type) == weight_types.end()) { + throw std::runtime_error("Unexpected weight tensor type: " + std::string(tensor->name) + " with type " + + ggml_type_name(tensor->type)); + } + + auto node_type = get_ov_type(tensor); + auto node_shape = get_shape(tensor); + auto ne_total = ggml_nelements(tensor); + + OPENVINO_ASSERT(node_shape[0] == 1, "Got 3D weights, expect all weights to be 2D: ", tensor->name); + node_shape.erase(node_shape.begin()); + + // F16 and F32 case + if (node_type != ov::element::dynamic) { + ov::Tensor weights(node_type, node_shape); + memcpy(weights.data(), tensor->data, ne_total * node_type.size()); + std::shared_ptr weight_node = std::make_shared(weights); + // Disabled because it triggers a bug in NPUW, no performance impact on CPU GPU + // if (node_type == ov::element::f16) { + // weight_node = std::make_shared(weight_node, ov::element::f32); + // } + weight_node->set_friendly_name(tensor->name); + return weight_node; + } + + // Quantized case + OPENVINO_ASSERT( + tensor->extra == nullptr, + "Unsupported weight tensor: " + std::string(tensor->name) + " Possibly this is a repacked quantized weights"); + + if (requant_type.has_value()) { + return requantize(tensor, requant_type.value()); + } + + ov::element::Type weight_type; + if (tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_1 || tensor->type == GGML_TYPE_Q4_K) { + weight_type = ov::element::u4; + } else { // tensor.type == GGUF_TYPE_Q8_0 || tensor.type == GGUF_TYPE_Q6_K || tensor.type == GGUF_TYPE_Q5_K + weight_type = ov::element::u8; + } + + uint64_t weights_per_block; + // here we only consider sub block, q6k:16 q4k:32 q5k:32 + if (tensor->type == GGML_TYPE_Q6_K) { + weights_per_block = 16; + } else { + weights_per_block = 32; + } + + OPENVINO_ASSERT(node_shape.back() % weights_per_block == 0, + "[load_gguf] tensor ", + tensor->name, + " has incompatible last dim shape: ", + node_shape.back()); + + ov::Tensor weights(weight_type, node_shape); + // For scales and biases + node_shape[node_shape.size() - 1] = node_shape[node_shape.size() - 1] / weights_per_block; + ov::Tensor scales(ov::element::f16, node_shape); + ov::Tensor biases(ov::element::f16, node_shape); + + ov::Output weight_node; + if (tensor->type == GGML_TYPE_Q4_0) { + extract_q4_0_data(tensor, weights, scales, biases); + weight_node = make_int4_weights(weights, scales, biases, weights_per_block); + } else if (tensor->type == GGML_TYPE_Q4_1) { + extract_q4_1_data(tensor, weights, scales, biases); + weight_node = make_int4_weights(weights, scales, biases, weights_per_block); + } else if (tensor->type == GGML_TYPE_Q8_0) { + extract_q8_0_data(tensor, weights, scales, biases); + weight_node = make_int8_weights(weights, scales, biases, weights_per_block); + } else if (tensor->type == GGML_TYPE_Q6_K) { + extract_q6_k_data(tensor, weights, scales, biases); + weight_node = make_int8_weights(weights, scales, biases, weights_per_block); + } else if (tensor->type == GGML_TYPE_Q4_K) { + extract_q4_k_data(tensor, weights, scales, biases); + weight_node = make_int4_weights(weights, scales, biases, weights_per_block); + } else if (tensor->type == GGML_TYPE_Q5_K) { + extract_q5_k_data(tensor, weights, scales, biases); + weight_node = make_int8_weights(weights, scales, biases, weights_per_block); + } + + OPENVINO_ASSERT(weight_node.get_shape().size() == 2, "Weight should be 2D"); + + weight_node.get_node_shared_ptr()->set_friendly_name(tensor->name); + return weight_node.get_node_shared_ptr(); +} + +void GgmlOvDecoder::dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename) { + std::ofstream file(filename); + if (!file.is_open()) { + std::cerr << "Failed to open file" << std::endl; + return; + } + + file << "=== GRAPH ===\n"; + + // clang-format off + file << "n_nodes = " << cgraph->n_nodes << "\n"; + file << " " << std::setw(3) << "nodes" + << std::setw(15) << "shape" + << std::setw(20) << "op" + << std::setw(20) << "name" + << std::setw(3) << " " + << std::setw(50) << "stride" + << "\n"; + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + file << " - " << std::setw(3) << i << ": [ " + << std::setw(5) << node->ne[0] << ", " + << std::setw(5) << node->ne[1] << ", " + << std::setw(5) << node->ne[2] << ", " + << std::setw(5) << node->ne[3] << "] " + << std::left << std::setw(20) << ggml_op_name(node->op) << std::right << " " + << std::left << std::setw(45) << node->name << std::right + << std::setw(2) << "[ " + << std::setw(0) << node->nb[0] << ", " + << std::setw(5) << node->nb[1] << ", " + << std::setw(5) << node->nb[2] << ", " + << std::setw(5) << node->nb[3] << "] " + << "\n"; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (auto* src = node->src[i]) { + file << std::setw(10) << " [ " + << std::setw(5) << src->ne[0] << ", " + << std::setw(5) << src->ne[1] << ", " + << std::setw(5) << src->ne[2] << ", " + << std::setw(5) << src->ne[3] << "] " + << std::setw(12) + << i << ": " << std::left << std::setw(12) << ggml_op_name(src->op) << std::right; + file << std::left << std::setw(30) << src->name << std::right + << std::setw(16) << "[ " + << std::setw(0) << src->nb[0] << ", " + << std::setw(5) << src->nb[1] << ", " + << std::setw(5) << src->nb[2] << ", " + << std::setw(5) << src->nb[3] << "] " + << "\n"; + } + } + } + + file << "n_leafs = " << cgraph->n_leafs << "\n"; + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * node = cgraph->leafs[i]; + + file << " - " << std::setw(3) << i << ": [ " + << std::setw(5) << node->ne[0] << ", " + << std::setw(5) << node->ne[1] << "] " + << std::setw(8) << ggml_op_name(node->op) << " " + << std::setw(16) << ggml_get_name(node) << "\n"; + } + // clang-format on + file << "========================================\n"; + + file.close(); +} + +void print_tensor_address_map(const struct ggml_cgraph* cgraph) { + std::map> address_map; + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + auto* node = cgraph->nodes[node_n]; + if (node->data) { + auto it = address_map.find(node->data); + if (it == address_map.end()) { + address_map[node->data] = std::vector(); + } + address_map[node->data].push_back(node->name); + } + } + for (const auto& pair : address_map) { + std::cout << "Address: " << pair.first << std::endl; + for (const auto& name : pair.second) { + std::cout << name << " ; "; + } + std::cout << std::endl << std::endl; + } +} + +std::vector GgmlOvDecoder::get_shape(const ggml_tensor* tensor) { + std::vector shape; + for (int i = GGML_MAX_DIMS - 2; i >= 0; --i) { + shape.push_back(static_cast(tensor->ne[i])); + } + return shape; +} + +std::vector GgmlOvDecoder::get_stride(const ggml_tensor* tensor) { + std::vector stride; + for (int i = GGML_MAX_DIMS - 2; i >= 0; --i) { + stride.push_back(static_cast(tensor->nb[i])); + } + return stride; +} + +ov::element::Type GgmlOvDecoder::get_ov_type(const ggml_tensor* tensor) { + switch (tensor->type) { + case GGML_TYPE_F64: + return ov::element::f64; + case GGML_TYPE_F32: + return ov::element::f32; + case GGML_TYPE_F16: + return ov::element::f16; + case GGML_TYPE_BF16: + return ov::element::bf16; + case GGML_TYPE_I8: + return ov::element::i8; + case GGML_TYPE_I16: + return ov::element::i16; + case GGML_TYPE_I32: + return ov::element::i32; + case GGML_TYPE_I64: + return ov::element::i64; + default: + return ov::element::dynamic; + } +} + +ov::PartialShape GgmlOvDecoder::get_input_shape(const std::string& name) const { + return ov::PartialShape(get_shape(m_inputs.at(name))); +} + +std::vector GgmlOvDecoder::get_input_stride(const std::string& name) const { + return get_stride(m_inputs.at(name)); +} + +ov::element::Type GgmlOvDecoder::get_input_type(const std::string& name) const { + return get_ov_type(m_inputs.at(name)); +} + +size_t GgmlOvDecoder::get_input_size() const { + return m_input_names.size(); +} + +std::string& GgmlOvDecoder::get_input_name(size_t index) const { + m_name = m_input_names[index]; + return m_name; +} + +std::vector GgmlOvDecoder::get_input_names() const { + return m_input_names; +} + +std::vector GgmlOvDecoder::get_output_stride(const std::string& name) const { + return get_stride(m_outputs.at(name)); +} + +ov::PartialShape GgmlOvDecoder::get_output_shape(const std::string& name) const { + return ov::PartialShape(get_shape(m_outputs.at(name))); +} + +ov::element::Type GgmlOvDecoder::get_output_type(const std::string& name) const { + return get_ov_type(m_outputs.at(name)); +} + +std::string& GgmlOvDecoder::get_output_name(size_t index) const { + m_name = std::string(m_output_names[index]); + return m_name; +} + +std::vector GgmlOvDecoder::get_output_names() const { + return m_output_names; +} + +const std::string& GgmlOvDecoder::get_op_name() const { + return m_op_name; +} + +int32_t* GgmlOvDecoder::get_input_op_params(const std::string& name) const { + return m_inputs.at(name)->op_params; +} + +int32_t* GgmlOvDecoder::get_output_op_params(const std::string& name) const { + return m_outputs.at(name)->op_params; +} + +void GgmlOvDecoder::visit_subgraph(std::function)> node_visitor) const { + for (const auto& node : m_nodes) { + auto decoder = std::make_shared(node, + m_cgraph, + m_is_static, + m_is_first_token, + m_context_size, + m_context_size_swa, + m_num_heads, + m_num_heads_kv, + m_head_size, + m_swa_layers); + node_visitor(decoder); + } +} + +const std::string& GgmlOvDecoder::get_op_type() const { + static const std::map ops = { + {GGML_OP_NONE, "GGML_OP_NONE" }, + {GGML_OP_ACC, "GGML_OP_ACC" }, + {GGML_OP_ADD, "GGML_OP_ADD" }, + {GGML_OP_ADD1, "GGML_OP_ADD1" }, + {GGML_OP_CONT, "GGML_OP_CONT" }, + {GGML_OP_DIV, "GGML_OP_DIV" }, + {GGML_OP_DUP, "GGML_OP_DUP" }, + {GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" }, + {GGML_OP_MUL, "GGML_OP_MUL" }, + {GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT" }, + {GGML_OP_PERMUTE, "GGML_OP_PERMUTE" }, + {GGML_OP_RESHAPE, "GGML_OP_RESHAPE" }, + {GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM" }, + {GGML_OP_ROPE, "GGML_OP_ROPE" }, + {GGML_OP_SCALE, "GGML_OP_SCALE" }, + {GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" }, + {GGML_OP_SUB, "GGML_OP_SUB" }, + {GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE" }, + {GGML_OP_VIEW, "GGML_OP_VIEW" }, + {GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" }, + {GGML_OP_CPY, "GGML_OP_CPY" }, + {GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT"}, + }; + static const std::map unary_ops = { + {GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" }, + {GGML_UNARY_OP_SGN, "GGML_UNARY_OP_SGN" }, + {GGML_UNARY_OP_NEG, "GGML_UNARY_OP_NEG" }, + {GGML_UNARY_OP_STEP, "GGML_UNARY_OP_STEP" }, + {GGML_UNARY_OP_TANH, "GGML_UNARY_OP_TANH" }, + {GGML_UNARY_OP_ELU, "GGML_UNARY_OP_ELU" }, + {GGML_UNARY_OP_RELU, "GGML_UNARY_OP_RELU" }, + {GGML_UNARY_OP_SIGMOID, "GGML_UNARY_OP_SIGMOID" }, + {GGML_UNARY_OP_GELU, "GGML_UNARY_OP_GELU" }, + {GGML_UNARY_OP_GELU_QUICK, "GGML_UNARY_OP_GELU_QUICK" }, + {GGML_UNARY_OP_SILU, "GGML_UNARY_OP_SILU" }, + {GGML_UNARY_OP_HARDSWISH, "GGML_UNARY_OP_HARDSWISH" }, + {GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"}, + {GGML_UNARY_OP_EXP, "GGML_UNARY_OP_EXP" }, + {GGML_UNARY_OP_COUNT, "GGML_UNARY_OP_COUNT" } + }; + static const std::map glu_ops = { + {GGML_GLU_OP_SWIGLU, "GGML_GLU_OP_SWIGLU"}, + {GGML_GLU_OP_GEGLU, "GGML_GLU_OP_GEGLU" }, + {GGML_GLU_OP_REGLU, "GGML_GLU_OP_REGLU" } + }; + + switch (m_node->op) { + case GGML_OP_UNARY: + return unary_ops.at(ggml_get_unary_op(m_node)); + case GGML_OP_GLU: + return glu_ops.at(ggml_get_glu_op(m_node)); + default: + return ops.at(m_node->op); + } + static const std::string unknown_op = "UNKNOWN_GGML_OP"; + return unknown_op; +} diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h new file mode 100644 index 0000000000000..35e79ecefc724 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -0,0 +1,179 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ggml-quants.hpp" +#include "ggml.h" +#include "openvino/decoder.hpp" + +class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { +public: + // Graph decoder + GgmlOvDecoder(struct ggml_cgraph* cgraph, std::map>& model_weights, + bool is_static, bool is_first_token); + + // Node decoder, called in GgmlOvDecoder::visit_subgraph + GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token, + int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size, + const std::vector& swa_layers); + + // Naive graph decoder + GgmlOvDecoder(struct ggml_cgraph* cgraph, std::map>& model_weights); + + virtual ov::Any get_attribute(const std::string& name) const override { + return nullptr; + GGML_UNUSED(name); + } + + virtual ov::PartialShape get_input_shape(const std::string& name) const override; + + virtual std::vector get_input_stride(const std::string& name) const override; + + virtual ov::element::Type get_input_type(const std::string& name) const override; + + virtual size_t get_input_size() const override; + + virtual void get_input_node(size_t input_port_idx, + std::string& producer_name, + std::string& producer_output_port_name, + size_t& producer_output_port_index) const override { + GGML_UNUSED(input_port_idx); + GGML_UNUSED(producer_name); + GGML_UNUSED(producer_output_port_name); + GGML_UNUSED(producer_output_port_index); + } + + virtual std::string& get_input_name(size_t index) const override; + + virtual std::vector get_input_names() const override; + + virtual ov::PartialShape get_output_shape(const std::string& name) const override; + + virtual std::vector get_output_stride(const std::string& name) const override; + + virtual ov::element::Type get_output_type(const std::string& name) const override; + + virtual int32_t* get_input_op_params(const std::string& name) const override; + + virtual int32_t* get_output_op_params(const std::string& name) const override; + + virtual std::string& get_output_name(size_t index) const override; + + virtual std::vector get_output_names() const override; + + virtual const std::string& get_op_type() const override; + + virtual const std::string& get_op_name() const override; + + virtual void visit_subgraph(std::function)> node_visitor) const override; + + const ggml_tensor* get_input_ggml_tensor(const std::string& name) const { + return m_inputs.at(name); + } + + const ggml_tensor* get_output_ggml_tensor(const std::string& name) const { + return m_outputs.at(name); + } + + virtual int get_op_case() const override { + return m_op_case; + } + + virtual const std::map>& get_model_inputs() const override { + return m_model_inputs; + } + virtual const std::map>& get_model_extra_inputs() const override { + return m_model_extra_inputs; + } + virtual const std::map>& get_model_extra_input_values() const { + return m_model_extra_input_values; + } + virtual const std::map>& get_model_weights() const override { + return m_model_weights; + } + virtual const std::vector& get_model_output_names() const override { + return m_model_output_names; + } + + virtual int get_context_size() const override { return m_context_size; } + + virtual int get_context_size_swa() const override { return m_context_size_swa; } + + virtual int is_swa_layer(int layer) const override { + return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end(); + } + + virtual int get_num_heads() const override { return m_num_heads; } + + virtual int get_num_heads_kv() const override { return m_num_heads_kv; } + + virtual int get_head_size() const override { return m_head_size; } + + virtual int32_t* get_rope_params() const override { return m_rope_params; } + + virtual std::map get_kv_param_res_names() const override; + + virtual bool is_static() const override { return m_is_static; } + + virtual bool is_first_token() const override { return m_is_first_token; } + + ov::PartialShape get_graph_input_shape(const ggml_tensor* src) const; + + static void dump_cgraph(const struct ggml_cgraph* cgraph, std::string& filename); + + static std::shared_ptr create_weight_node(ggml_tensor* tensor, + std::optional requant_type = std::nullopt); + static std::map> create_weight_nodes( + struct ggml_cgraph* cgraph, std::map types_to_requantize = {}); + + const ggml_tensor* get_tensor_used_op(const ggml_tensor* tensor) const; + const ggml_tensor* get_tensor_from_name(const std::string& name) const; + + void clear_model_weights() { m_model_weights.clear(); } + +private: + void set_input_output(ggml_tensor* node, bool naive = false); + void add_extra_inputs(); + static std::vector get_shape(const ggml_tensor* tensor); + static std::vector get_stride(const ggml_tensor* tensor); + static ov::element::Type get_ov_type(const ggml_tensor* tensor); + + // set context_size, num_heads, etc + void set_llm_params(); + + struct ggml_cgraph* m_cgraph = nullptr; + ggml_tensor* m_node = nullptr; + std::vector m_nodes; + std::map m_inputs; + std::vector m_input_names; + std::map m_outputs; + std::vector m_output_names; + std::string m_op_name; + mutable std::string m_name; + int m_op_case = 0; + std::vector> m_op_node_name; + std::map> m_model_inputs; + std::map> m_model_extra_inputs; + std::map> m_model_extra_input_values; + std::map> m_model_weights; + std::vector m_model_output_names; + int m_context_size; + int m_context_size_swa; + std::vector m_swa_layers; + int m_num_heads; + int m_num_heads_kv; + int m_head_size; + int32_t* m_rope_params; + std::vector m_kv_names; + bool m_is_static = false; + bool m_is_first_token; +}; + +void print_tensor_address_map(const struct ggml_cgraph* cgraph); + +int extract_layer_from_name(const std::string& name); diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp new file mode 100644 index 0000000000000..648acb4e35ede --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -0,0 +1,578 @@ +#include "ggml-openvino.h" + +#include +#include +#include +#include +#include +#include + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "ggml-openvino/utils.h" +#include "ggml.h" + +#define GGML_OPENVINO_MAX_STREAMS 8 + +struct ggml_backend_openvino_context { + int device; // the device ID currently in use + std::string name; // context Name + std::string description; // context description + + // OpenVINO core components + ov::Core core; // OpenVINO core interface + std::shared_ptr model; // compiled Model + ov::InferRequest infer_request; // inference Request + + // OpenVINO Multi-stream support + static const int MAX_STREAMS = 8; // define the maximum number of flows + std::vector streams; // used to support multi-stream reasoning + int current_stream; // the currently active stream index + + // state Management + bool is_initialized; // initialize + + ggml_backend_openvino_context() + : device(0), name("OpenVINO"), description("OpenVINO Backend Context"), + current_stream(0), is_initialized(false) {} +}; + +static void ggml_backend_openvino_free(ggml_backend_t backend) { + ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *)backend->context; + delete ctx; + delete backend; +} + +static const char * ggml_backend_openvino_get_name(ggml_backend_t backend) { + return GGML_OPENVINO_NAME; + GGML_UNUSED(backend); +} + +static enum ggml_status +ggml_backend_openvino_graph_compute(ggml_backend_t backend, struct ggml_cgraph *cgraph) { + openvino_frontend_compute(backend, cgraph); + + return GGML_STATUS_SUCCESS; +} + +static const ggml_backend_i ggml_backend_openvino_interface = { + /* .get_name = */ ggml_backend_openvino_get_name, + /* .free = */ ggml_backend_openvino_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_openvino_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +int ggml_backend_openvino_get_device_count() { + return ggml_openvino_info().device_count; +} + +static ggml_guid_t ggml_backend_openvino_guid(void) { + static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d }; + return &guid; +} + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) { + if (device < 0 || device >= ggml_backend_openvino_get_device_count()) { + GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device); + return nullptr; + } + + ggml_backend_openvino_context * ctx = new ggml_backend_openvino_context; + if (ctx == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate context\n", __func__); + return nullptr; + } + + ggml_backend_t openvino_backend = new ggml_backend { + /* .guid = */ ggml_backend_openvino_guid(), + /* .interface = */ ggml_backend_openvino_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), device), + /* .context = */ ctx, + }; + + return openvino_backend; +} + +GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_openvino_guid()); +} + +// device buffer +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device) { + GGML_ASSERT(device >= 0); + return ggml_backend_cpu_buffer_type(); + GGML_UNUSED(device); +} + +// split tensor buffer that splits matrices by rows across multiple devices +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_split_buffer_type(const float * tensor_split) { + GGML_ASSERT(tensor_split != nullptr); + return nullptr; +} + +// pinned host buffer for use with the CPU backend for faster copies between CPU +// and GPU +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(void) { + return nullptr; +} + +struct ggml_backend_openvino_buffer_type_context { + int device; + std::string name; +}; + +static const char * ggml_backend_openvino_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} +static bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name; +} + + +static const char * ggml_backend_openvino_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return GGML_OPENVINO_NAME "_Split"; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_buft_is_openvino_split(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_openvino_split_buffer_type_get_name; +} + +struct ggml_backend_openvino_device_context { + int device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_openvino_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_openvino_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context; + return ctx->description.c_str(); +} + +// TODO +static void ggml_backend_openvino_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + GGML_ASSERT(dev->context != nullptr); + GGML_ASSERT(free != nullptr); + GGML_ASSERT(total != nullptr); + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context; + GGML_ASSERT(ctx->device >= 0); + // ggml_openvino_set_device(ctx->device); + *total = 1; + *free = 1; +} + +static enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_openvino_device_get_name(dev); + props->description = ggml_backend_openvino_device_get_description(dev); + props->type = ggml_backend_openvino_device_get_type(dev); + ggml_backend_openvino_device_get_memory(dev, &props->memory_free, &props->memory_total); + + bool host_buffer = getenv("GGML_OPENVINO_NO_PINNED") == nullptr; +#ifdef GGML_OPENVINO_NO_PEER_COPY + bool events = false; +#else + bool events = true; +#endif + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ host_buffer, + /* .buffer_from_host_ptr = */ false, + /* .events = */ events, + }; +} + +static ggml_backend_t ggml_backend_openvino_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context; + return ggml_backend_openvino_init(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context; + return ggml_backend_openvino_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return ggml_backend_openvino_host_buffer_type(); +} + +static ggml_backend_buffer_t ggml_backend_openvino_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + GGML_UNUSED(dev); + GGML_UNUSED(ptr); + GGML_UNUSED(size); + GGML_UNUSED(max_tensor_size); + return nullptr; +} + +static ggml_backend_buffer_t ggml_backend_openvino_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + GGML_UNUSED(dev); + GGML_UNUSED(ptr); + GGML_UNUSED(size); + GGML_UNUSED(max_tensor_size); + return nullptr; +} + +static bool is_op_unsupported_case(const ggml_tensor* op) { + if (op->op == GGML_OP_SOFT_MAX) { + if (op->src[2] != nullptr) { + GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with sinks\n"); + return true; + } + float scale = 1.0f; + float max_bias = 0.0f; + const auto* op_params = op->op_params; + memcpy(&scale, (const float*) op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float*) op_params + 1, sizeof(float)); + if (max_bias > 0) { + GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n"); + return true; + } + } + + if (op->op == GGML_OP_FLASH_ATTN_EXT) { + if (op->src[4] != nullptr) { + GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n"); + return true; + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + const auto* op_params = op->op_params; + memcpy(&scale, (const float*) op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float*) op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float*) op_params + 2, sizeof(float)); + if (max_bias > 0) { + GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n"); + return true; + } + if (logit_softcap != 0) { + GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n"); + return true; + } + } + + if (op->op == GGML_OP_PERMUTE) { + if (op->type == GGML_TYPE_BF16) { + // err msg: [GPU] Could not find a suitable kernel for transpose + GGML_LOG_WARN("OpenVINO backend does not support PERMUTE with BF16 type\n"); + return true; + } + } + + if (op->op == GGML_OP_CPY) { + if (op->src[1] != op) { + GGML_LOG_WARN("OpenVINO backend only supports CPY that is a cast\n"); + return true; + } + } + + if (op->op == GGML_OP_MUL_MAT) { + if (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16) { + // Has accuracy issue, try enabling this and see `test-backend-ops -o "MUL_MAT"` + GGML_LOG_WARN("OpenVINO backend does not support MUL_MAT with two F16 tensors\n"); + return true; + } + } + + if (op->op == GGML_OP_ROPE) { + const int32_t* op_params = op->op_params; + const int n_dims = op_params[1]; + const int mode = op_params[2]; + if (mode == GGML_ROPE_TYPE_MROPE || mode == GGML_ROPE_TYPE_VISION) { + GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); + return true; + } + if (n_dims != 0.0f && n_dims != op->src[0]->ne[0]) { + GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n", + n_dims, + op->src[0]->ne[0]); + return true; + } + if (op->type != GGML_TYPE_F32) { + GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); + return true; + } + float freq_scale; + float ext_factor; + memcpy(&freq_scale, op_params + 6, sizeof(float)); + memcpy(&ext_factor, op_params + 7, sizeof(float)); + if (ext_factor != 0.0f) { + GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor); + return true; + } + if (op->src[0]->op == GGML_OP_VIEW) { + if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) { + GGML_LOG_WARN( + "OpenVINO backend does not support ROPE with src[0]->view_src->ne[1] %ld != src[0]->ne[2] %ld\n", + op->src[0]->view_src->ne[1], + op->src[0]->ne[2]); + return true; + } + } + } + return false; +} + +static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor* op) { + GGML_ASSERT(dev->reg != nullptr); + + static std::set supported_types{GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_I64, + GGML_TYPE_I32, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + GGML_TYPE_Q4_K, + GGML_TYPE_Q5_K, + GGML_TYPE_Q8_0, + GGML_TYPE_Q6_K}; + + static const std::set supported_ops{GGML_OP_NONE, + GGML_OP_ADD, + GGML_OP_MUL, + GGML_OP_MUL_MAT, + GGML_OP_VIEW, + GGML_OP_CONT, + GGML_OP_RESHAPE, + GGML_OP_PERMUTE, + GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, + GGML_OP_ROPE, + GGML_OP_RMS_NORM, + GGML_OP_SCALE, + // softmax is not updated due to replaced by flash_attn_ext + // GGML_OP_SOFT_MAX, + GGML_OP_SET_ROWS, + GGML_OP_FLASH_ATTN_EXT, + GGML_OP_CPY}; + static const std::set supported_unary_ops{ + GGML_UNARY_OP_SILU, + }; + static const std::set supported_glu_ops{ + GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_GEGLU, + }; + + switch (op->op) { + case GGML_OP_UNARY: { + auto supported = supported_unary_ops.find(ggml_get_unary_op(op)) != supported_unary_ops.end(); + if (!supported) { + GGML_LOG_WARN("OpenVINO backend does not support unary op %s\n", ggml_unary_op_name(ggml_get_unary_op(op))); + return false; + } + break; + } + case GGML_OP_GLU: { + auto supported = supported_glu_ops.find(ggml_get_glu_op(op)) != supported_glu_ops.end(); + if (!supported) { + GGML_LOG_WARN("OpenVINO backend does not support GLU op %s\n", ggml_glu_op_name(ggml_get_glu_op(op))); + return false; + } + break; + } + default: { + auto supported = supported_ops.find(op->op) != supported_ops.end(); + if (!supported) { + GGML_LOG_WARN("OpenVINO backend does not support op %s\n", ggml_op_name(op->op)); + return false; + } + } + } + + if (supported_types.find(op->type) == supported_types.end()) { + GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(op->type)); + return false; + } + if (op->ne[3] != 1) { + GGML_LOG_WARN("OpenVINO backend does not support tensors with ne[3] != 1\n"); + return false; + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto* src = op->src[i]; + if (src == nullptr) { + break; + } + if (supported_types.find(src->type) == supported_types.end()) { + GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(src->type)); + return false; + } + if (src->ne[3] != 1) { + GGML_LOG_WARN("OpenVINO backend does not support tensors with ne[3] != 1\n"); + return false; + } + if (ggml_is_quantized(src->type) && src->ne[2] != 1) { + GGML_LOG_WARN("OpenVINO backend does not support 3D quantized tensors\n"); + return false; + } + } + + if (is_op_unsupported_case(op)) { + return false; + } + return true; +} + +static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_is_host(buft); + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_openvino_device_interface = { + /* .get_name = */ ggml_backend_openvino_device_get_name, + /* .get_description = */ ggml_backend_openvino_device_get_description, + /* .get_memory = */ ggml_backend_openvino_device_get_memory, + /* .get_type = */ ggml_backend_openvino_device_get_type, + /* .get_props = */ ggml_backend_openvino_device_get_props, + /* .init_backend = */ ggml_backend_openvino_device_init, + /* .get_buffer_type = */ ggml_backend_openvino_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_openvino_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_openvino_device_supports_op, + /* .supports_buft = */ ggml_backend_openvino_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +struct ggml_backend_openvino_reg_context { + std::vector devices; +}; + +static const char * ggml_backend_openvino_reg_get_name(ggml_backend_reg_t reg) { + return GGML_OPENVINO_NAME; + GGML_UNUSED(reg); +} + +static size_t ggml_backend_openvino_reg_get_device_count(ggml_backend_reg_t reg) { + return ggml_openvino_info().device_count; + GGML_UNUSED(reg); + + // TODO + ggml_backend_openvino_reg_context * ctx = (ggml_backend_openvino_reg_context *)reg->context; + + return ctx->devices.size(); +} + +static ggml_backend_dev_t ggml_backend_openvino_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_openvino_reg_context * ctx = (ggml_backend_openvino_reg_context *)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; + // GGML_ASSERT(index == 0); + + // static ggml_backend_device ggml_backend_openvino_device = { + // /* .iface = */ ggml_backend_openvino_device_interface, + // /* .reg = */ reg, + // /* .context = */ nullptr, + // }; + + // return &ggml_backend_openvino_device; + + // GGML_UNUSED(reg); + // GGML_UNUSED(index); +} + +static void * ggml_backend_openvino_get_proc_address(ggml_backend_reg_t reg, const char * name) { + GGML_UNUSED(reg); + if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { + return (void *)ggml_backend_openvino_split_buffer_type; + } + // if (strcmp(name, "ggml_backend_register_host_buffer") == 0) { + // return (void *)ggml_backend_openvino_register_host_buffer; + // } + // if (strcmp(name, "ggml_backend_unregister_host_buffer") == 0) { + // return (void *)ggml_backend_openvino_unregister_host_buffer; + // } + return nullptr; +} + +static const struct ggml_backend_reg_i ggml_backend_openvino_reg_interface = { + /* .get_name = */ ggml_backend_openvino_reg_get_name, + /* .get_device_count = */ ggml_backend_openvino_reg_get_device_count, + /* .get_device = */ ggml_backend_openvino_reg_get_device, + /* .get_proc_address = */ ggml_backend_openvino_get_proc_address, +}; + +static int get_openvino_device_count() { + ov::Core core; + auto devices = core.get_available_devices(); + // return devices.size(); + return 1; +} + +static ggml_openvino_device_info ggml_openvino_init() { + ggml_openvino_device_info info = {}; + // TODO + info.device_count = get_openvino_device_count(); + return info; +} + +const ggml_openvino_device_info & ggml_openvino_info() { + static ggml_openvino_device_info info = ggml_openvino_init(); + return info; +} + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void) { + static ggml_backend_reg reg; + + static bool initialized = false; + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + ggml_backend_openvino_reg_context * ctx = new ggml_backend_openvino_reg_context; + + // GGML_LOG_DEBUG("ggml_openvino_info().device_count = %d \n", ggml_openvino_info().device_count); + for (int i = 0; i < ggml_openvino_info().device_count; i++) { + ggml_backend_openvino_device_context * dev_ctx = new ggml_backend_openvino_device_context; + dev_ctx->device = i; + dev_ctx->name = GGML_OPENVINO_NAME + std::to_string(i); + + // ggml_openvino_set_device(i); + dev_ctx->description = ov::get_openvino_version().description; + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .interface = */ ggml_backend_openvino_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx + }; + ctx->devices.push_back(dev); + } + + reg = ggml_backend_reg{ /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_openvino_reg_interface, + /* .context = */ ctx }; + } + + initialized = true; + } + + return ® +} diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp new file mode 100644 index 0000000000000..1538a8207ca9d --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -0,0 +1,554 @@ +#include "ggml-quants.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-common.h" +#include "ggml-impl.h" +#include "ggml.h" + +void unpack_32_4(const uint8_t* data, uint8_t* dst) { + std::fill_n(dst, 16, 0); + for (int j = 0; j < 16; ++j) { + uint8_t x = (data[j] & 0x0F); + uint8_t y = (data[j] >> 4); + if (j % 2 != 0) { + x <<= 4; + y <<= 4; + } + dst[j / 2] |= x; + dst[8 + j / 2] |= y; // Last 16 weights are in the higher bits + } +} + +// Extracts (weight, scales, biases) from Q4_0 tensors. +// Data layout is: |16 bit scale|32 x 4bit weights|. +void extract_q4_0_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights + auto* data = static_cast(tensor->data); + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block))); + biases[i] = ov::float16(-8.f * static_cast(scales[i])); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + }); +} + +// Extracts (weight, scales, biases) from Q4_1 tensors. +// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|. +void extract_q4_1_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights + auto* data = static_cast(tensor->data); + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block))); + biases[i] = ov::float16::from_bits(*((uint16_t*)(data + i * bytes_per_block + 2))); + unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16); + }); +} + +// Extracts (weight, scales, biases) from Q8_0 tensors. +// Data layout is: |16 bit scale|32 x 8bit weights|. +void extract_q8_0_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t weights_per_block = 32; + const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights + auto* data = static_cast(tensor->data); + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t* block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t*) block_data); + biases[i] = ov::float16(-128.f * static_cast(scales[i])); + for (size_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. + // Original data is in int8_t, so we add a bias of -128 and invert the first bit. + x ^= 1 << 7; + weights[i * weights_per_block + j] = x; + } + }); +} + +void unpack_256_4(const uint8_t* data, uint8_t* dst) { + // Initialize the output array with zeros + std::fill_n(dst, 128, 0); + + for (size_t i = 0; i < 4; ++i) { + for (int j = 0; j < 32; ++j) { + uint8_t x = (data[i * 32 + j] & 0x0F); + uint8_t y = (data[i * 32 + j] >> 4); + if (j % 2 != 0) { + x <<= 4; + y <<= 4; + } + dst[i * 32 + j / 2] |= x; + dst[i * 32 + 16 + j / 2] |= y; // Last 16 weights are in the higher bits + } + } +} + +void extract_q4_k_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t bytes_per_block = 2 + 2 + 12 + 128; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + auto* data = static_cast(tensor->data); + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t* block_data = data + i * bytes_per_block; + + // Extract scale factors and offsets + float scale_scales = static_cast(ov::float16::from_bits(*((uint16_t*)block_data))); + float scale_biases = static_cast(ov::float16::from_bits(*((uint16_t*)block_data + 1))); + + // Extract qs1 and qs2 + uint8_t* qs1 = block_data + 4; + // uint8_t* qs2 = block_data + 16; + + scales[i * 8] = ov::float16(scale_scales * static_cast((*(qs1) & 0b111111))); + scales[i * 8 + 1] = ov::float16(scale_scales * static_cast((*(qs1 + 1) & 0b111111))); + scales[i * 8 + 2] = ov::float16(scale_scales * static_cast((*(qs1 + 2) & 0b111111))); + scales[i * 8 + 3] = ov::float16(scale_scales * static_cast((*(qs1 + 3) & 0b111111))); + scales[i * 8 + 4] = + ov::float16(scale_scales * static_cast((*(qs1 + 8) & 0b00001111) | ((*(qs1) >> 6) << 4))); + scales[i * 8 + 5] = + ov::float16(scale_scales * static_cast((*(qs1 + 9) & 0b00001111) | ((*(qs1 + 1) >> 6) << 4))); + scales[i * 8 + 6] = + ov::float16(scale_scales * static_cast((*(qs1 + 10) & 0b00001111) | ((*(qs1 + 2) >> 6) << 4))); + scales[i * 8 + 7] = + ov::float16(scale_scales * static_cast((*(qs1 + 11) & 0b00001111) | ((*(qs1 + 3) >> 6) << 4))); + + biases[i * 8] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 4) & 0b111111))); + biases[i * 8 + 1] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 5) & 0b111111))); + biases[i * 8 + 2] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 6) & 0b111111))); + biases[i * 8 + 3] = ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 7) & 0b111111))); + biases[i * 8 + 4] = + ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 8) >> 4) | ((*(qs1 + 4) >> 6) << 4))); + biases[i * 8 + 5] = + ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 9) >> 4) | ((*(qs1 + 5) >> 6) << 4))); + biases[i * 8 + 6] = + ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 10) >> 4) | ((*(qs1 + 6) >> 6) << 4))); + biases[i * 8 + 7] = + ov::float16(-1.f * scale_biases * static_cast((*(qs1 + 11) >> 4) | ((*(qs1 + 7) >> 6) << 4))); + unpack_256_4(block_data + 16, weights + i * 128); + }); +} + +void extract_q6_k_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t bytes_per_block = 128 + 64 + 16 + 2; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + auto* data = static_cast(tensor->data); + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t* block_data = data + i * bytes_per_block; + + float scale_factor = + static_cast(ov::float16::from_bits(*((uint16_t*) block_data + 104))); // (128+64+16)/2 + + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t*) (block_data + 128 + 64 + j)))); + biases[j + i * 16] = ov::float16(-32.f * static_cast(scales[j + i * 16])); + } + + uint8_t* ql = block_data; + uint8_t* qh = block_data + 128; + + for (int64_t j = 0; j < 32; ++j) { + weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); + weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); + weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); + weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); + weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); + weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); + weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); + weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); + } + }); +} + +static inline void get_scale_min_k4(int j, const uint8_t* q, uint8_t* d, uint8_t* m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +void extract_q5_k_data(const ggml_tensor* tensor, ov::Tensor& weights_arr, ov::Tensor& scales_arr, + ov::Tensor& biases_arr) { + const uint64_t bytes_per_block = 4 + 12 + 32 + 128; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + auto* data = static_cast(tensor->data); + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t* block_data = data + i * bytes_per_block; + + const float d = static_cast(ov::float16::from_bits(*((uint16_t*) block_data))); + const float min = static_cast(ov::float16::from_bits(*((uint16_t*) block_data + 1))); + + const uint8_t* scales_data = block_data + 4; // 12 bytes of scales + const uint8_t* qh = block_data + 4 + 12; // 32 bytes of high bits + const uint8_t* ql = block_data + 4 + 12 + 32; // 128 bytes of low bits + + int is = 0; + uint8_t u1 = 1; + uint8_t u2 = 2; + + // Process 2 blocks in one iteration + for (int j = 0; j < 256; j += 64) { // 256 = QK_K, so 4 iterations of 64 + uint8_t sc; + uint8_t m; + + // Get scale and min for first 32 elements + get_scale_min_k4(is + 0, scales_data, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + // Get scale and min for second 32 elements + get_scale_min_k4(is + 1, scales_data, &sc, &m); + const float d2 = d * sc; + const float m2 = min * m; + + scales[i * 8 + is] = ov::float16(d1); + biases[i * 8 + is] = ov::float16(-m1); + scales[i * 8 + is + 1] = ov::float16(d2); + biases[i * 8 + is + 1] = ov::float16(-m2); + + // Extract weights for first 32 elements (matching deq formula exactly) + for (int l = 0; l < 32; ++l) { + weights[i * 256 + j + l] = (ql[l] & 0xF) + ((qh[l] & u1) ? 16 : 0); + } + + // Extract weights for second 32 elements + for (int l = 0; l < 32; ++l) { + weights[i * 256 + j + l + 32] = (ql[l] >> 4) + ((qh[l] & u2) ? 16 : 0); + } + + ql += 32; + is += 2; + u1 <<= 2; + u2 <<= 2; + } + }); +} + +// TODO Reorder for make_intX_weights + +ov::Output make_int8_weights(ov::Tensor& weight, ov::Tensor& scales, ov::Tensor& biases, size_t group_size) { + ov::Shape orig_shape = weight.get_shape(); + + // Expand dimensions for scales and biases + auto scale_shape = scales.get_shape(); + + ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; + + if (packed_shape[1] == 1) { + packed_shape.erase(packed_shape.begin() + 1); + } else { + scale_shape.push_back(1); + scales.set_shape(scale_shape); + biases.set_shape(scale_shape); + } + + // Create graph nodes + auto weights_node = std::make_shared( + ov::element::u8, packed_shape, static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto scales_f16 = std::make_shared(scales); + ov::Tensor biases_u8(ov::element::u8, scale_shape); + + // Calculate zero point + const ov::float16* bias_data = biases.data::value_type>(); + const ov::float16* scale_data = scales.data::value_type>(); + uint8_t* bias_u8_data = biases_u8.data(); + for (size_t i = 0; i < biases_u8.get_size(); ++i) { + bias_u8_data[i] = (uint8_t)std::round(-1.f * static_cast(bias_data[i]) / static_cast(scale_data[i])); + } + + auto zero_point = std::make_shared(biases_u8); + float zp_value; + if (ov::op::util::get_single_value(zero_point, zp_value)) { + zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + } + + // Quantization operations + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); + + auto w_zp = std::make_shared( + weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY + ); + ov::Output w_zp_s = + std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + + if (packed_shape.size() != 2) { + // If not requantized channel-wise case, reshape back to original shape + auto final_shape = + std::make_shared(ov::element::i64, ov::Shape{orig_shape.size()}, orig_shape); + w_zp_s = std::make_shared(w_zp_s, final_shape, false); + } + + return std::make_shared(w_zp_s, ov::element::f32); +} + +ov::Output make_int4_weights(ov::Tensor& weight, ov::Tensor& scales, ov::Tensor& biases, size_t group_size) { + ov::Shape orig_weight_shape = weight.get_shape(); + + // Expand dimensions for scales and biases + ov::Shape scale_bias_shape = scales.get_shape(); + + // Create INT4 weight tensor + ov::Shape packed_shape = { + orig_weight_shape[0], + orig_weight_shape[1] / group_size, + group_size + }; + + // Requantized channel-wise case + if (packed_shape[1] == 1) { + packed_shape.erase(packed_shape.begin() + 1); + } else { + scale_bias_shape.push_back(1); + scales.set_shape(scale_bias_shape); + biases.set_shape(scale_bias_shape); + } + + auto weights_node = std::make_shared(ov::element::u4, packed_shape, static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + // Pack zero points: two subsequent values into one + const ov::float16* bias_data = biases.data::value_type>(); + const ov::float16* scale_data = scales.data::value_type>(); + ov::Tensor zero_point_tensor(ov::element::u4, scale_bias_shape); + uint8_t* zero_point_data = static_cast(zero_point_tensor.data()); + for (size_t i = 0; i < zero_point_tensor.get_byte_size(); ++i) { + uint8_t bias1 = (uint8_t)std::round(-1.f * static_cast(bias_data[i * 2]) / static_cast(scale_data[i * 2])); + uint8_t bias2 = (uint8_t)std::round(-1.f * static_cast(bias_data[i * 2 + 1]) / static_cast(scale_data[i * 2 + 1])); + zero_point_data[i] = (bias2 << 4) | (bias1 & 0x0F); + } + + auto zero_points_node = std::make_shared(zero_point_tensor); + float zp_value; + if (ov::op::util::get_single_value(zero_points_node, zp_value)) { + zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + } + auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); + + auto scales_f16 = std::make_shared(scales); + + // Perform dequantization + auto w_zp = std::make_shared( + weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); + + ov::Output w_zp_s = + std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + + if (packed_shape.size() != 2) { + // If not requantized channel-wise case, reshape back to original shape + auto final_shape = std::make_shared( + ov::element::i64, ov::Shape{orig_weight_shape.size()}, orig_weight_shape); + + w_zp_s = std::make_shared(w_zp_s, final_shape, false); + } + + return std::make_shared(w_zp_s, ov::element::f32); +} + +std::shared_ptr requantize(const ggml_tensor* tensor, ExtraQuantType requant_type) { + std::vector weights_f32(tensor->ne[0] * tensor->ne[1]); + ggml_get_type_traits(tensor->type)->to_float(tensor->data, weights_f32.data(), ggml_nelements(tensor)); + + std::shared_ptr weight_node; + ov::Shape node_shape = {(uint64_t) (tensor->ne[1]), (uint64_t) (tensor->ne[0])}; + + if (requant_type == ExtraQuantType::F16) { + ov::Tensor weights(ov::element::f16, node_shape); + ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), ggml_nelements(tensor)); + std::shared_ptr weight_node = std::make_shared(weights); + weight_node->set_friendly_name(tensor->name); + return weight_node; + } + + int64_t block_size = node_shape[1]; + if (requant_type == ExtraQuantType::Q4_0_128) { + block_size = 128; + } else if (requant_type == ExtraQuantType::Q8_0_32) { + block_size = 32; + } + auto scales_shape = ov::Shape{node_shape[0], node_shape[1] / block_size}; + + ov::Tensor weights; + ov::Tensor scales(ov::element::f16, scales_shape); + ov::Tensor bias(ov::element::f16, scales_shape); + + if (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128) { + weights = ov::Tensor(ov::element::u4, node_shape); + quantize_q4_0(weights_f32.data(), weights, scales, bias, weights.get_size(), block_size); + weight_node = make_int4_weights(weights, scales, bias, block_size).get_node_shared_ptr(); + } else if (requant_type == ExtraQuantType::Q8_1_C) { + weights = ov::Tensor(ov::element::u8, node_shape); + quantize_q8_1(weights_f32.data(), weights, scales, bias, weights.get_size(), block_size); + weight_node = make_int8_weights(weights, scales, bias, block_size).get_node_shared_ptr(); + } else if (requant_type == ExtraQuantType::Q8_0_C || requant_type == ExtraQuantType::Q8_0_32) { + weights = ov::Tensor(ov::element::u8, node_shape); + quantize_q8_0(weights_f32.data(), weights, scales, bias, weights.get_size(), block_size); + weight_node = make_int8_weights(weights, scales, bias, block_size).get_node_shared_ptr(); + } + + weight_node->set_friendly_name(tensor->name); + return weight_node; +} + +void quantize_q4_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + biases[i] = ov::float16(-8.f * d); + + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); + weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } + } +} + +void quantize_q8_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + } + } + + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + biases[i] = ov::float16(-128.0f * d); + + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + const int8_t xi0 = roundf(x0); + weights[i * qk + j] = (uint8_t) (xi0 + 128); + } + } +} + +void quantize_q8_1(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto* weights = static_cast(weights_arr.data()); + auto* scales = scales_arr.data::value_type>(); + auto* biases = biases_arr.data::value_type>(); + for (int i = 0; i < nb; i++) { + float min = std::numeric_limits::max(); + float max = std::numeric_limits::lowest(); + + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + + const float d = (max - min) / ((1 << 8) - 1); + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + biases[i] = ov::float16(min); + + for (int j = 0; j < qk; ++j) { + const float x0 = (x[i * qk + j] - min) * id; + const uint8_t xi0 = roundf(x0); + weights[i * qk + j] = xi0; + } + } +} diff --git a/ggml/src/ggml-openvino/ggml-quants.hpp b/ggml/src/ggml-openvino/ggml-quants.hpp new file mode 100644 index 0000000000000..71ae317a39e90 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-quants.hpp @@ -0,0 +1,74 @@ +#pragma once +#include +#include +#include + +#include "ggml.h" + +void unpack_32_4(const uint8_t* data, uint8_t* dst); + +void extract_q4_0_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +void extract_q4_1_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +void extract_q8_0_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +void unpack_256_4(const uint8_t* data, uint8_t* dst); + +void extract_q4_k_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +void extract_q5_k_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +void extract_q6_k_data(const ggml_tensor* tensor, + ov::Tensor& weights_arr, + ov::Tensor& scales_arr, + ov::Tensor& biases_arr); + +static constexpr size_t GGML_QUANTIZATION_GROUP_SIZE = 32; + +ov::Output make_int8_weights(ov::Tensor& weight, + ov::Tensor& scales, + ov::Tensor& biases, + size_t group_size = GGML_QUANTIZATION_GROUP_SIZE); + +ov::Output make_int4_weights(ov::Tensor& weight, + ov::Tensor& scales, + ov::Tensor& biases, + size_t group_size = GGML_QUANTIZATION_GROUP_SIZE); + +enum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128, Q8_0_C, Q8_0_32 }; + +std::shared_ptr requantize(const ggml_tensor* tensor, ExtraQuantType requant_type); + +void quantize_q4_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, + int64_t qk); +void quantize_q8_1(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, + int64_t qk); +void quantize_q8_0(const float* x, ov::Tensor& weights_arr, ov::Tensor& scales_arr, ov::Tensor& biases_arr, int64_t k, + int64_t qk); + +namespace ov { +namespace op { +namespace util { +// From /src/common/transformations/include/transformations/utils/utils.hpp +bool get_single_value(const std::shared_ptr& const_node, + float& value, + bool check_value_range = true); +} // namespace util +} // namespace op +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/decoder.hpp b/ggml/src/ggml-openvino/openvino/decoder.hpp new file mode 100644 index 0000000000000..6f11ff1283e37 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/decoder.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { + +class GgmlDecoder : public DecoderBase { +public: + virtual ov::Any get_attribute(const std::string& name) const = 0; + + virtual PartialShape get_input_shape(const std::string& name) const = 0; + + virtual std::vector get_input_stride(const std::string& name) const = 0; + + virtual element::Type get_input_type(const std::string& name) const = 0; + + virtual size_t get_input_size() const = 0; + + virtual void get_input_node(size_t input_port_idx, + std::string& producer_name, + std::string& producer_output_port_name, + size_t& producer_output_port_index) const = 0; + + virtual std::string& get_input_name(size_t index) const = 0; + + virtual std::vector get_input_names() const = 0; + + virtual PartialShape get_output_shape(const std::string& name) const = 0; + + virtual std::vector get_output_stride(const std::string& name) const = 0; + + virtual element::Type get_output_type(const std::string& name) const = 0; + + virtual int32_t* get_input_op_params(const std::string& name) const = 0; + + virtual int32_t* get_output_op_params(const std::string& name) const = 0; + + virtual std::string& get_output_name(size_t index) const = 0; + + virtual std::vector get_output_names() const = 0; + + virtual const std::string& get_op_type() const = 0; + + virtual const std::string& get_op_name() const = 0; + + virtual void visit_subgraph(std::function)> node_visitor) const = 0; + + virtual int get_op_case() const = 0; + + virtual const std::map>& get_model_inputs() const = 0; + virtual const std::map>& get_model_extra_inputs() const = 0; + virtual const std::map>& get_model_weights() const = 0; + virtual const std::vector& get_model_output_names() const = 0; + + virtual int get_num_heads() const = 0; + virtual int get_num_heads_kv() const = 0; + virtual int get_head_size() const = 0; + virtual int32_t* get_rope_params() const = 0; + virtual std::map get_kv_param_res_names() const = 0; + + virtual bool is_static() const = 0; + virtual bool is_first_token() const = 0; + virtual int get_context_size() const = 0; + virtual int get_context_size_swa() const = 0; + virtual int is_swa_layer(int layer) const = 0; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/frontend.cpp b/ggml/src/ggml-openvino/openvino/frontend.cpp new file mode 100644 index 0000000000000..dbdae1ed45ca1 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/frontend.cpp @@ -0,0 +1,27 @@ +#include "frontend.hpp" + +#include "input_model.hpp" +#include "op_table.hpp" +#include "translate_session.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +FrontEnd::FrontEnd() {} + +std::shared_ptr FrontEnd::convert(const InputModel::Ptr& model, bool naive) { + auto ggml_model = std::dynamic_pointer_cast(model); + FRONT_END_GENERAL_CHECK(ggml_model, "Invalid input model"); + std::shared_ptr converted_model; + const auto& supported_ops = get_supported_ops(); + { + TranslateSession translate_session(model, supported_ops, naive); + converted_model = translate_session.get_converted_model(); + } + return converted_model; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/frontend.hpp b/ggml/src/ggml-openvino/openvino/frontend.hpp new file mode 100644 index 0000000000000..f1c6f0c3e3ce3 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/frontend.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace frontend { +namespace ggml { + +class FrontEnd { +public: + using Ptr = std::shared_ptr; + FrontEnd(); + + static std::shared_ptr convert(const InputModel::Ptr& model, bool naive = false); +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/input_model.cpp b/ggml/src/ggml-openvino/openvino/input_model.cpp new file mode 100644 index 0000000000000..5fb16ea2db87d --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/input_model.cpp @@ -0,0 +1,17 @@ +#include "input_model.hpp" + +#include "decoder.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +InputModel::InputModel(const std::shared_ptr& gdecoder) : m_decoder(gdecoder) {} + +const std::shared_ptr& InputModel::get_model_decoder() const { + return m_decoder; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/input_model.hpp b/ggml/src/ggml-openvino/openvino/input_model.hpp new file mode 100644 index 0000000000000..9bc9a28e9aeca --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/input_model.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "decoder.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +class FrontEnd; +class GgmlDecoder; +using ov::frontend::ggml::GgmlDecoder; + +class InputModel : public ov::frontend::InputModel { + friend class ::ov::frontend::ggml::FrontEnd; + +public: + explicit InputModel(const std::shared_ptr& gdecoder); + + const std::shared_ptr& get_model_decoder() const; + +private: + std::shared_ptr m_decoder; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/node_context.hpp b/ggml/src/ggml-openvino/openvino/node_context.hpp new file mode 100644 index 0000000000000..a64ae098ab3e9 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/node_context.hpp @@ -0,0 +1,119 @@ +#pragma once + +#include +#include +#include + +#include "decoder.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +class TranslateSession; + +typedef std::map> TensorMap; + +class NodeContext : public frontend::NodeContext { +public: + NodeContext(const std::shared_ptr& decoder, + std::shared_ptr& tensor_map, + TranslateSession* translate_session = nullptr) + : ov::frontend::NodeContext(decoder->get_op_type()), + m_decoder(decoder), + m_tensor_map(tensor_map), + m_translate_session(translate_session) { + m_input_names = decoder->get_input_names(); + m_output_names = decoder->get_output_names(); + } + + TranslateSession* get_translate_session() const { + return m_translate_session; + } + + const std::vector& get_input_names() const { return m_input_names; } + + size_t get_input_size() const override { + return m_decoder->get_input_size(); + } + + ov::element::Type get_input_type(size_t index) const { + return m_decoder->get_input_type(m_input_names[index]); + } + + PartialShape get_input_shape(size_t index) const { + return m_decoder->get_input_shape(m_input_names[index]); + } + + std::vector get_input_stride(size_t index) const { + return m_decoder->get_input_stride(m_input_names[index]); + } + + std::string get_output_name() const { return m_output_names[0]; } + + PartialShape get_output_shape(size_t index) const { + return m_decoder->get_output_shape(m_output_names[index]); + } + + std::vector get_output_stride(size_t index) const { + return m_decoder->get_output_stride(m_output_names[index]); + } + + int32_t* get_input_op_params(size_t index) const { + return m_decoder->get_input_op_params(m_input_names[index]); + } + + int32_t* get_output_op_params(size_t index) const { + return m_decoder->get_output_op_params(m_output_names[index]); + } + + ov::element::Type get_output_type(size_t index) const { + return m_decoder->get_output_type(m_output_names[index]); + } + + Output get_input(int idx) const override { + return m_tensor_map->at(m_decoder->get_input_name(idx)); + } + + Output get_input(const std::string& name) const override { + if (m_tensor_map->find(name) == m_tensor_map->end()) { + throw std::runtime_error("'" + name + "' not found in tensor map."); + } + return m_tensor_map->at(name); + } + + bool has_input(const std::string& name) const { + return m_tensor_map->find(name) != m_tensor_map->end(); + } + + const std::string& get_name() const override { + return m_decoder->get_op_name(); + } + + ov::Any get_attribute_as_any(const std::string& name) const override { + return m_decoder->get_attribute(name); + } + + int get_op_case() const { + return m_decoder->get_op_case(); + } + bool is_static() const { + return m_decoder->is_static(); + } + bool is_first_token() const { + return m_decoder->is_first_token(); + } + +private: + std::shared_ptr m_decoder; + std::shared_ptr& m_tensor_map; + TranslateSession* m_translate_session; + std::vector m_input_names; + std::vector m_output_names; +}; + +using CreatorFunction = std::function; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/cont.cpp b/ggml/src/ggml-openvino/openvino/op/cont.cpp new file mode 100644 index 0000000000000..9ae0f420ccb2f --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/cont.cpp @@ -0,0 +1,49 @@ + +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_cont(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported CONT case"); + + auto src_shape = context.get_input_shape(0).to_shape(); + auto dst_shape = context.get_output_shape(0).to_shape(); + ov::Output res; + + if (op_case == 1) { + // The input comes from a PERMUTE + dst_shape[1] = -1; + res = std::make_shared( + context.get_input(0), + ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), + false); + } else if (op_case == 2) { + // The input comes from a TRANSPOSE + return {context.get_input(0)}; + } else { + // The input comes from a VIEW + res = process_view_input(context, 0); + } + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/cpy.cpp b/ggml/src/ggml-openvino/openvino/op/cpy.cpp new file mode 100644 index 0000000000000..54b49018a9699 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/cpy.cpp @@ -0,0 +1,20 @@ +#include +#include +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_cpy(const NodeContext& context) { + auto res = std::make_shared(context.get_input(0), context.get_output_type(0)); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp new file mode 100644 index 0000000000000..8b67778fb9373 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp @@ -0,0 +1,90 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_flash_attn_ext(const NodeContext& context) { + num_inputs_check(context, 4, 4); + auto q_f32 = context.get_input(0); + auto k = context.get_input(1); + auto v = context.get_input(2); + auto mask = context.get_input(3); + + float* params = reinterpret_cast(context.get_output_op_params(0)); + float scale = params[0]; + // float max_bias = params[1]; + // float logit_softcap = params[2]; + + auto q = std::make_shared(q_f32, ov::element::f16); + auto scale_node = std::make_shared(ov::element::f16, ov::Shape{}, std::vector{scale}); + + ov::Output mask_sliced; + std::string mask_name = "KQ_mask_sliced"; + if (context.get_input_names()[3].find("swa") != std::string::npos) { + mask_name = "KQ_mask_swa_sliced"; + } + if (context.has_input(mask_name)) { + mask_sliced = context.get_input(mask_name); + } else { + auto token_len = get_dimensions(q, {1}); + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + mask_sliced = std::make_shared(mask, zero, token_len, one, one); + } + + if (mask_sliced.get_element_type() != ov::element::f16) { + mask_sliced = std::make_shared(mask_sliced, ov::element::f16); + } + + auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output kv) { + int64_t factor = q_batch / kv_batch; + if (factor > 1) { + auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{q_batch}); + auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{kv_batch}); + auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{factor}); + + auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1}); + auto kv_unsqueezed = std::make_shared(kv, unsqueeze_axes); + + auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2}); + auto kv_broadcast_shape = + std::make_shared(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0); + kv = std::make_shared(kv_unsqueezed, kv_broadcast_shape); + + auto new_kv_shape = + std::make_shared(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0); + kv = std::make_shared(kv, new_kv_shape, false); + } + return kv; + }; + + auto q_shape = context.get_input_shape(0).to_shape(); + auto k_shape = context.get_input_shape(1).to_shape(); + k = tile_kv(q_shape[0], k_shape[0], k); + v = tile_kv(q_shape[0], k_shape[0], v); + + auto sdpa = std::make_shared(q, k, v, mask_sliced, scale_node, false); + auto sdpa_f32 = std::make_shared(sdpa, ov::element::f32); + auto res = std::make_shared(sdpa_f32, + ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp new file mode 100644 index 0000000000000..5e4c7d901ac32 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_get_rows(const NodeContext& context) { + num_inputs_check(context, 2, 2); + + int op_case = context.get_op_case(); + + Output res; + auto data = context.get_input(0); + auto indices = context.get_input(1); + + if (op_case == 2) { + // The input comes from a VIEW + indices = process_view_input(context, 1); + } + + // data[b,x,y] ind[1,b,x'] test-backend-ops case + // data[x,y] ind[1,1,x'] normal case + indices = std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + if (data.get_partial_shape().rank() == 3) { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); + res = std::make_shared(data, indices, axis, 1); + } else { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + res = std::make_shared(data, indices, axis); + } + + if (res.get_element_type() != context.get_output_type(0)) { + res = std::make_shared(res, context.get_output_type(0)); + } + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp new file mode 100644 index 0000000000000..4295bf7517c3c --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_glu_geglu(const NodeContext& context) { + num_inputs_check(context, 1, 2); + + ov::Output src0; + ov::Output src1; + if (context.get_input_size() == 2) { + src0 = context.get_input(0); + src1 = context.get_input(1); + } else { + auto combined = context.get_input(0); + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2}); + auto split = std::make_shared(combined, split_axis, 2); + src0 = split->output(0); + src1 = split->output(1); + } + + int32_t* params = context.get_output_op_params(0); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + + auto gelu = std::make_shared(src0); + auto res = std::make_shared(gelu, src1); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp new file mode 100644 index 0000000000000..bef42fe4b70c0 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_glu_swiglu(const NodeContext& context) { + num_inputs_check(context, 1, 2); + + ov::Output src0; + ov::Output src1; + if (context.get_input_size() == 2) { + src0 = context.get_input(0); + src1 = context.get_input(1); + } else { + auto combined = context.get_input(0); + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2}); + auto split = std::make_shared(combined, split_axis, 2); + src0 = split->output(0); + src1 = split->output(1); + } + + int32_t* params = context.get_output_op_params(0); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + + auto sigmoid = std::make_shared(src0); + auto silu = std::make_shared(src0, sigmoid); + auto res = std::make_shared(silu, src1); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp new file mode 100644 index 0000000000000..b4103378ebb1b --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp @@ -0,0 +1,88 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_mulmat(const NodeContext& context) { + num_inputs_check(context, 2, 2); + + int op_case = context.get_op_case(); + + ov::Output res; + ov::Output B = context.get_input(0); + ov::Output A = context.get_input(1); + + bool transpose_b = true; + if (op_case == 2) { + B = B.get_node_shared_ptr()->input_value(0); + transpose_b = false; + } else if (op_case == 3) { + B = process_view_input(context, 0); + A = process_view_input(context, 1); + } + if (A.get_element_type() != B.get_element_type()) { + B = std::make_shared(context.get_input(0), context.get_input_type(1)); + } + + auto B_shape = context.get_input_shape(0).to_shape(); + auto A_shape = context.get_input_shape(1).to_shape(); + int64_t A_batch = A_shape[0]; + int64_t B_batch = B_shape[0]; + auto A_batch_larger = A_batch > B_batch; + Output Z = A_batch_larger ? B : A; + int64_t factor = A_batch_larger ? A_batch / B_batch : B_batch / A_batch; + if (factor > 1) { + auto A_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{A_batch}); + auto B_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{B_batch}); + auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{factor}); + + auto Z_last_two_dims = get_dimensions(Z.get_node_shared_ptr(), {1, 2}); + + auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1}); + auto Z_unsqueezed = std::make_shared(Z, unsqueeze_axes); + + Output batch_small = A_batch_larger ? B_batch_node : A_batch_node; + Output batch_large = A_batch_larger ? A_batch_node : B_batch_node; + auto broadcast_shape = + std::make_shared(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0); + auto Z_broadcasted = std::make_shared(Z_unsqueezed, broadcast_shape); + + auto new_Z_shape = std::make_shared(ov::OutputVector{batch_large, Z_last_two_dims}, 0); + Z = std::make_shared(Z_broadcasted, new_Z_shape, false); + } + if (A_batch_larger) { + B = Z; + } else { + A = Z; + } + + res = std::make_shared(A, B, false, transpose_b); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp new file mode 100644 index 0000000000000..086b1e4cdb172 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp @@ -0,0 +1,62 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_permute(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported PERMUTE case"); + ov::Output res; + + if (op_case == 1) { + res = std::make_shared(context.get_input(0), + ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); + } else { + auto src = context.get_input(0); + Output attention_size; + if (context.is_static()) { + attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX}); + } else if (op_case == 2) { + attention_size = context.get_input("attention_size"); + } else { + attention_size = context.get_input("attention_size_swa"); + } + + auto src_shape_ = context.get_input_shape(0).to_shape(); + std::vector src_shape(src_shape_.begin(), src_shape_.end()); + + auto src_reshaped = std::make_shared( + src, + ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{-1, src_shape[1], src_shape[2]}), + false); + + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto src_slice = std::make_shared(src_reshaped, zero, attention_size, one, zero); + + res = std::make_shared(src_slice, + ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); + } + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/reshape.cpp b/ggml/src/ggml-openvino/openvino/op/reshape.cpp new file mode 100644 index 0000000000000..1ed6f4b880b0a --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/reshape.cpp @@ -0,0 +1,54 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_reshape(const NodeContext& context) { + num_inputs_check(context, 1, 1); + if (context.get_input_shape(0) == context.get_output_shape(0)) { + return {context.get_input(0)}; + } + + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4, + "Unsupported RESHAPE case"); + + auto output_shape = context.get_output_shape(0).to_shape(); + std::shared_ptr new_shape_node; + if (op_case == 1) { + new_shape_node = + ov::op::v0::Constant::create(ov::element::i64, + {3}, + std::vector{-1, (int64_t)output_shape[1], (int64_t)output_shape[2]}); + } else if (op_case == 2) { + new_shape_node = + ov::op::v0::Constant::create(ov::element::i64, + {3}, + std::vector{(int64_t)output_shape[0], -1, (int64_t)output_shape[2]}); + } else if (op_case == 3) { + new_shape_node = + ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{(int64_t) output_shape[0], -1, 1}); + } else if (op_case == 4) { + return {context.get_input(0).get_node_shared_ptr()->input_value(0)}; + } + auto res = std::make_shared(context.get_input(0), new_shape_node, false); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp new file mode 100644 index 0000000000000..c9df4c42f3e0d --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp @@ -0,0 +1,46 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_rms_norm(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + auto input_node = context.get_input(0); + auto square = std::make_shared( + input_node, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {2.0f})); + + auto mean = std::make_shared( + square, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}), true); + + float eps; + memcpy(&eps, context.get_output_op_params(0), sizeof(float)); + + auto rms = std::make_shared( + std::make_shared(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {eps}))); + + auto reciprocal = + std::make_shared(ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f}), rms); + + auto res = std::make_shared(input_node, reciprocal); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp new file mode 100644 index 0000000000000..4b1e3b500cf3e --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -0,0 +1,110 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_rope(const NodeContext& context) { + num_inputs_check(context, 2, 3); + + int op_case = context.get_op_case(); + + ov::Output res; + + auto data_node = context.get_input(0).get_node_shared_ptr(); + auto output_shape = context.get_output_shape(0).to_shape(); + int32_t* op_params = context.get_output_op_params(0); + + Output cos_theta_node; + Output sin_theta_node; + if (context.has_input("rope_cos")) { + cos_theta_node = context.get_input("rope_cos"); + sin_theta_node = context.get_input("rope_sin"); + } else { + auto inp_pos = context.get_input(1).get_node_shared_ptr(); + std::shared_ptr rope_freqs_weight; + if (context.get_input_size() == 3) { + rope_freqs_weight = context.get_input(2).get_node_shared_ptr(); + } + auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight); + sin_theta_node = sin_cos.first; + cos_theta_node = sin_cos.second; + } + + if (op_case == 2) { + // The input comes from a VIEW + int slice_len = output_shape[1] * output_shape[2]; + data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr(); + auto data_shape = ov::op::v0::Constant::create( + ov::element::i64, {3}, std::vector{-1, (int64_t) output_shape[1], (int64_t) output_shape[2]}); + data_node = std::make_shared(data_node, data_shape, false); + } + + const int mode = op_params[2]; + constexpr int ROPE_TYPE_NEOX = 2; + constexpr int ROPE_TYPE_NORM = 0; + + if (mode == ROPE_TYPE_NORM) { + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); + auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[2]}); + auto even_slice = std::make_shared(data_node, zero, end, two, two); + auto odd_slice = std::make_shared(data_node, one, end, two, two); + + Output first_half = + std::make_shared(std::make_shared(even_slice, cos_theta_node), + std::make_shared(odd_slice, sin_theta_node)); + Output second_half = + std::make_shared(std::make_shared(even_slice, sin_theta_node), + std::make_shared(odd_slice, cos_theta_node)); + + first_half = std::make_shared(first_half, + ov::op::v0::Constant::create(ov::element::i64, {1}, {3})); + second_half = std::make_shared(second_half, + ov::op::v0::Constant::create(ov::element::i64, {1}, {3})); + auto stack = std::make_shared(OutputVector{first_half, second_half}, 3); + res = std::make_shared(stack, std::make_shared(data_node), false); + } else if (mode == ROPE_TYPE_NEOX) { + auto data_split = std::make_shared( + data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}), 2); + Output slice_data_node_0 = data_split->outputs()[0]; + Output slice_data_node_1 = data_split->outputs()[1]; + + auto first_half_node = std::make_shared( + std::make_shared(slice_data_node_0, cos_theta_node), + std::make_shared(slice_data_node_1, sin_theta_node)); + + auto second_half_node = std::make_shared( + std::make_shared(slice_data_node_0, sin_theta_node), + std::make_shared(slice_data_node_1, cos_theta_node)); + + res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, 2); + } + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/scale.cpp b/ggml/src/ggml-openvino/openvino/op/scale.cpp new file mode 100644 index 0000000000000..783440ebd967e --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/scale.cpp @@ -0,0 +1,29 @@ +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_scale(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + float scale; + memcpy(&scale, context.get_output_op_params(0), sizeof(float)); + auto scale_node = std::make_shared(ov::element::f32, ov::Shape{}, std::vector{scale}); + + auto res = std::make_shared(context.get_input(0), scale_node); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/set_rows.cpp b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp new file mode 100644 index 0000000000000..50817c8323bef --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_set_rows(const NodeContext& context) { + num_inputs_check(context, 2, 2); + + auto data = context.get_input(0); + data = std::make_shared(data, context.get_output_type(0)); + + auto dst_shape = context.get_output_shape(0).to_shape(); + FRONT_END_OP_CONVERSION_CHECK(dst_shape[0] == 1, "Unsupported shape in SET_ROWS"); + + if (context.is_static() && context.is_first_token()) { + return rename_outputs_with_suffix({data}, context.get_name()); + } + + auto indices = context.get_input(1); + auto dst = context.get_input(context.get_output_name()); + + auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto dst_reshaped = std::make_shared( + dst, + ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}), + false); + auto indices_reshaped = + std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); + auto data_reshaped = std::make_shared( + data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false); + + auto updated = std::make_shared(dst_reshaped, indices_reshaped, data_reshaped, zero); + auto res = std::make_shared(updated, std::make_shared(dst), false); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/softmax.cpp b/ggml/src/ggml-openvino/openvino/op/softmax.cpp new file mode 100644 index 0000000000000..1aa3bf76a06bb --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/softmax.cpp @@ -0,0 +1,88 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_soft_max(const NodeContext& context) { + num_inputs_check(context, 1, 2); + + auto input_node = context.get_input(0).get_node_shared_ptr(); + ov::Output res; + + float scale = 1.0f; + float max_bias = 0.0f; + auto* op_params = context.get_output_op_params(0); + memcpy(&scale, (float*) op_params + 0, sizeof(float)); + memcpy(&max_bias, (float*) op_params + 1, sizeof(float)); + auto src0_shape = context.get_input_shape(0).get_shape(); + const uint32_t h = src0_shape[2]; + const uint32_t n_head = src0_shape[0]; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + auto scale_node = std::make_shared(ov::element::f32, ov::Shape{}, std::vector{scale}); + auto scaled_input = std::make_shared(input_node, scale_node); + + if (context.get_input_size() < 2) { + res = std::make_shared(scaled_input, 2); + return rename_outputs_with_suffix({res}, context.get_name()); + } + + ov::Output mask_node_sliced; + if (context.has_input("KQ_mask_sliced")) { + mask_node_sliced = context.get_input("KQ_mask_sliced"); + } else { + auto token_len = get_dimensions(input_node, {1}); + auto mask_node = context.get_input(1); + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + mask_node_sliced = std::make_shared(mask_node, zero, token_len, one, one); + } + + if (mask_node_sliced.get_element_type() != context.get_output_type(0)) { + mask_node_sliced = std::make_shared(mask_node_sliced, context.get_output_type(0)); + } + + Output slope_mask; + if (slope != 1.0f) { + auto slope_node = + std::make_shared(ov::element::f32, ov::Shape{}, std::vector{slope}); + slope_mask = std::make_shared(mask_node_sliced, slope_node); + throw std::runtime_error("Slope != 1.0f in softmax has not been tested, verify it before use."); + } + slope_mask = mask_node_sliced; + + auto input_slope_mask_node = std::make_shared(scaled_input, slope_mask); + + res = std::make_shared(input_slope_mask_node, 2); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/transpose.cpp b/ggml/src/ggml-openvino/openvino/op/transpose.cpp new file mode 100644 index 0000000000000..c585dffa6e1b9 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/transpose.cpp @@ -0,0 +1,23 @@ +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_transpose(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + auto res = std::make_shared(context.get_input(0), + ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1})); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp new file mode 100644 index 0000000000000..2b27c0be1227c --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp @@ -0,0 +1,27 @@ +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_silu(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto sigmoid = std::make_shared(input); + auto res = std::make_shared(input, sigmoid); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/view.cpp b/ggml/src/ggml-openvino/openvino/op/view.cpp new file mode 100644 index 0000000000000..034b6df119510 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/view.cpp @@ -0,0 +1,22 @@ +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_view(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + if (context.get_op_case() == 2) { + auto dst_shape = context.get_output_shape(0).to_shape(); + return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[1] * dst_shape[2])}, context.get_name()); + } + return {context.get_input(0)}; +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp new file mode 100644 index 0000000000000..e36e8f17cc94e --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -0,0 +1,46 @@ +#include "op_table.hpp" + +#include +#include +#include +#include +#include +#include + +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +std::unordered_map get_supported_ops() { + using namespace ov::op; + return { + {"GGML_OP_ADD", op::translate_1to1_match_2_inputs }, + {"GGML_OP_ADD1", op::translate_1to1_match_2_inputs }, + {"GGML_OP_CONT", op::translate_cont }, + {"GGML_OP_DIV", op::translate_1to1_match_2_inputs }, + {"GGML_OP_GET_ROWS", op::translate_get_rows }, + {"GGML_OP_MUL", op::translate_1to1_match_2_inputs}, + {"GGML_OP_MUL_MAT", op::translate_mulmat }, + {"GGML_OP_PERMUTE", op::translate_permute }, + {"GGML_OP_RESHAPE", op::translate_reshape }, + {"GGML_OP_RMS_NORM", op::translate_rms_norm }, + {"GGML_OP_ROPE", op::translate_rope }, + {"GGML_OP_SCALE", op::translate_scale }, + {"GGML_OP_SOFT_MAX", op::translate_soft_max }, + {"GGML_OP_SUB", op::translate_1to1_match_2_inputs}, + {"GGML_OP_TRANSPOSE", op::translate_transpose }, + {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, + {"GGML_OP_VIEW", op::translate_view }, + {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, + {"GGML_GLU_OP_GEGLU", op::translate_glu_geglu }, + {"GGML_OP_SET_ROWS", op::translate_set_rows }, + {"GGML_OP_CPY", op::translate_cpy }, + {"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext }, + }; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.hpp b/ggml/src/ggml-openvino/openvino/op_table.hpp new file mode 100644 index 0000000000000..5d4f0538604d1 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op_table.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "node_context.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +namespace op { + +#define GGML_OP_CONVERTER(op) OutputVector op(const NodeContext& context) + +GGML_OP_CONVERTER(translate_add); +GGML_OP_CONVERTER(translate_cont); +GGML_OP_CONVERTER(translate_get_rows); +GGML_OP_CONVERTER(translate_mul); +GGML_OP_CONVERTER(translate_mulmat); +GGML_OP_CONVERTER(translate_permute); +GGML_OP_CONVERTER(translate_reshape); +GGML_OP_CONVERTER(translate_rms_norm); +GGML_OP_CONVERTER(translate_rope); +GGML_OP_CONVERTER(translate_scale); +GGML_OP_CONVERTER(translate_unary_silu); +GGML_OP_CONVERTER(translate_soft_max); +GGML_OP_CONVERTER(translate_transpose); +GGML_OP_CONVERTER(translate_view); +GGML_OP_CONVERTER(translate_glu_swiglu); +GGML_OP_CONVERTER(translate_glu_geglu); +GGML_OP_CONVERTER(translate_set_rows); +GGML_OP_CONVERTER(translate_cpy); +GGML_OP_CONVERTER(translate_flash_attn_ext); + +} // namespace op + +std::unordered_map get_supported_ops(); + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp new file mode 100644 index 0000000000000..4759e86e1ea34 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp @@ -0,0 +1,117 @@ +#include "eliminate_zp.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +EliminateZeroPoints::EliminateZeroPoints() { + // Find pattern: + // (Multiply Any(scale) + // (Subtract (Convert Constant(data))) + // (Convert Constant(zero_point))) + // where zero_point is a scalar + // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val + // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant + + auto m_data_constant = ov::pass::pattern::wrap_type(); + auto m_data_convert = ov::pass::pattern::wrap_type({m_data_constant}); + + auto m_zp_constant = ov::pass::pattern::wrap_type(); + auto m_zp_convert = ov::pass::pattern::wrap_type({m_zp_constant}); + + auto m_subtract = ov::pass::pattern::wrap_type({m_data_convert, m_zp_convert}); + auto m_scale = ov::pass::pattern::any_input(); + auto m_multiply = ov::pass::pattern::wrap_type({m_scale, m_subtract}); + + const auto callback = [=](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + + auto multiply_node = std::dynamic_pointer_cast(pattern_map.at(m_multiply).get_node_shared_ptr()); + auto subtract_node = std::dynamic_pointer_cast(pattern_map.at(m_subtract).get_node_shared_ptr()); + auto data_constant = std::dynamic_pointer_cast(pattern_map.at(m_data_constant).get_node_shared_ptr()); + auto zp_constant = std::dynamic_pointer_cast(pattern_map.at(m_zp_constant).get_node_shared_ptr()); + + if (!multiply_node || !subtract_node || !data_constant || !zp_constant) { + return false; + } + + if (ov::shape_size(zp_constant->get_shape()) != 1) { + return false; + } + + auto data_type = data_constant->get_element_type(); + auto zp_data = zp_constant->cast_vector(); + + if (zp_data.empty()) { + return false; + } + + int zp_value = zp_data[0]; + + bool should_eliminate = false; + ov::element::Type target_type; + + if (data_type == ov::element::u4 && zp_value == 8) { + should_eliminate = true; + target_type = ov::element::i4; + } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) { + should_eliminate = true; + target_type = ov::element::i8; + } + + if (!should_eliminate) { + return false; + } + + auto data_shape = data_constant->get_shape(); + size_t total_elements = ov::shape_size(data_shape); + + std::shared_ptr new_constant; + + // TODO improve performance + if (data_type == ov::element::u4) { + auto data_values = data_constant->cast_vector(); + std::vector adjusted_values(total_elements); + + ov::parallel_for(total_elements, [&](size_t i) { + adjusted_values[i] = static_cast(static_cast(data_values[i]) - 8); + }); + + new_constant = std::make_shared(target_type, data_shape, adjusted_values); + } else if (data_type == ov::element::u8) { + auto data_values = data_constant->cast_vector(); + std::vector adjusted_values(total_elements); + + ov::parallel_for(total_elements, [&, zp_value](size_t i) { + adjusted_values[i] = static_cast(static_cast(data_values[i]) - zp_value); + }); + + new_constant = std::make_shared(target_type, data_shape, adjusted_values); + } + + auto new_convert = std::make_shared(new_constant, subtract_node->get_output_element_type(0)); + ov::replace_node(subtract_node, new_convert); + + return true; + }; + + register_matcher(std::make_shared(m_multiply, "ov::frontend::ggml::pass::EliminateZeroPoints"), + callback); +} + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.hpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.hpp new file mode 100644 index 0000000000000..edd3cd718d9b0 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.hpp @@ -0,0 +1,17 @@ +#include "openvino/pass/matcher_pass.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +class EliminateZeroPoints : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::EliminateZeroPoints") + EliminateZeroPoints(); +}; + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp new file mode 100644 index 0000000000000..f38c0837d1374 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp @@ -0,0 +1,60 @@ +#include "fuse_to_sdpa.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +FuseToSDPA::FuseToSDPA() { + // Not maintained since FLASH_ATTN_EXT has replaced this pattern + const auto m_k = ov::pass::pattern::any_input(); + const auto m_q = ov::pass::pattern::any_input(); + const auto m_qk = ov::pass::pattern::wrap_type({m_q, m_k}); + const auto m_qk_f32 = ov::pass::pattern::wrap_type({m_qk}); + const auto m_scale = ov::pass::pattern::any_input(); + const auto m_scaled_qk = ov::pass::pattern::wrap_type({m_qk_f32, m_scale}); + const auto m_mask = ov::pass::pattern::any_input(); + const auto m_masked_qk = ov::pass::pattern::wrap_type({m_scaled_qk, m_mask}); + const auto m_softmax_qk = ov::pass::pattern::wrap_type({m_masked_qk}); + const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type({m_softmax_qk}); + const auto m_v = ov::pass::pattern::any_input(); + const auto m_qkv = ov::pass::pattern::wrap_type({m_softmax_qk_f16, m_v}); + + const auto callback = [=](ov::pass::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto k = pattern_to_output[m_k]; + auto q = pattern_to_output[m_q]; + auto v = pattern_to_output[m_v]; + auto mask = pattern_to_output[m_mask]; + auto scale = pattern_to_output[m_scale]; + + auto mask_f16 = register_new_node(mask, ov::element::f16); + auto scale_f16 = register_new_node(scale, ov::element::f16); + auto sdpa = std::make_shared(q, k, v, mask_f16, scale_f16, false); + + ov::replace_node(m.get_match_root(), sdpa); + ov::copy_runtime_info(m.get_matched_nodes(), sdpa); + + return true; + }; + register_matcher(std::make_shared(m_qkv, "ov::frontend::ggml::pass::FuseToSDPA"), + callback); +} + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.hpp b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.hpp new file mode 100644 index 0000000000000..8b5164d232932 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.hpp @@ -0,0 +1,17 @@ +#include "openvino/pass/matcher_pass.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +class FuseToSDPA : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::FuseToSDPA") + FuseToSDPA(); +}; + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.hpp b/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.hpp new file mode 100644 index 0000000000000..b40eaf4205703 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include "mark_decompression_convert_constant_folding.hpp" +#include "openvino/pass/matcher_pass.hpp" +#include "openvino/core/visibility.hpp" + +#ifdef OPENVINO_STATIC_LIBRARY +# define TRANSFORMATIONS_API +#else +# ifdef IMPLEMENT_OPENVINO_API +# define TRANSFORMATIONS_API OPENVINO_CORE_EXPORTS +# else +# define TRANSFORMATIONS_API OPENVINO_CORE_IMPORTS +# endif // IMPLEMENT_OPENVINO_API +#endif // OPENVINO_STATIC_LIBRARY + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API MarkCompressedFloatConstants; + +} // namespace pass +} // namespace ov + +class ov::pass::MarkCompressedFloatConstants : public MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("MarkCompressedFloatConstants") + MarkCompressedFloatConstants(); +}; diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp new file mode 100644 index 0000000000000..944381968226d --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -0,0 +1,246 @@ +#include "translate_session.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-openvino/openvino/node_context.hpp" +#include "ggml-openvino/openvino/utils.hpp" +#include "input_model.hpp" +#include "pass/eliminate_zp.hpp" +#include "pass/mark_decompression_convert_constant_folding.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +using namespace ov::op; + +namespace { + +ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs( + const std::shared_ptr& model, const std::map& kv_param_res_names) { + ov::pass::MakeStateful::ParamResPairs pairs; + const auto& params = model->get_parameters(); + const auto& results = model->get_results(); + + for (const auto& param_res : kv_param_res_names) { + const auto& param_name = param_res.first; + const auto& res_name = param_res.second; + + auto param_it = std::find_if(params.begin(), params.end(), [&](const std::shared_ptr& node) { + return node->get_friendly_name() == param_name; + }); + + OPENVINO_ASSERT(param_it != params.end(), "The tensor name ", param_name, + " is not associated with any of " + "Parameters in the network."); + + auto res_it = std::find_if(results.begin(), results.end(), [&](const std::shared_ptr& node) { + return node->get_friendly_name() == res_name; + }); + + OPENVINO_ASSERT(res_it != results.end(), "The tensor name ", res_name, + " is not associated with any of " + "Results in the network."); + + std::shared_ptr param = *param_it; + std::shared_ptr res = *res_it; + pairs.emplace_back(param, res); + } + return pairs; +} + +void add_token_len(TensorMap& tensor_map) { + auto inp_tokens = tensor_map.at("inp_tokens").get_node_shared_ptr(); + auto token_len = get_dimensions(inp_tokens, {2}); + token_len->set_friendly_name("token_len"); + tensor_map.insert({"token_len", token_len->output(0)}); +} + +void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) { + auto token_len = tensor_map.at("token_len").get_node_shared_ptr(); + + auto create_sliced_mask = [&](const std::string& mask_name, const std::string& sliced_name, bool is_static) { + if (tensor_map.find(mask_name) != tensor_map.end()) { + auto mask = tensor_map.at(mask_name).get_node_shared_ptr(); + std::shared_ptr mask_sliced; + if (is_static) { + mask_sliced = mask; + } else { + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + mask_sliced = std::make_shared(mask, zero, token_len, one, one); + mask_sliced = std::make_shared(mask_sliced, ov::element::f16); + mask_sliced->set_friendly_name(sliced_name); + } + tensor_map.insert({sliced_name, mask_sliced->output(0)}); + } + }; + + create_sliced_mask("KQ_mask", "KQ_mask_sliced", ggml_model_decoder.is_static()); + create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static()); +} + +void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) { + int32_t* rope_params = ggml_model_decoder.get_rope_params(); + auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); + std::shared_ptr rope_freqs_weight; + if (tensor_map.find("rope_freqs_weight") != tensor_map.end()) { + rope_freqs_weight = tensor_map.at("rope_freqs.weight").get_node_shared_ptr(); + } + + auto sin_cos = make_sin_cos(rope_params, inp_pos, rope_freqs_weight); + auto sin_theta = sin_cos.first; + auto cos_theta = sin_cos.second; + + cos_theta.get_node_shared_ptr()->set_friendly_name("rope_cos"); + sin_theta.get_node_shared_ptr()->set_friendly_name("rope_sin"); + tensor_map.insert({"rope_cos", cos_theta}); + tensor_map.insert({"rope_sin", sin_theta}); +} + +// Create common patterns +void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) { + add_token_len(tensor_map); + add_sliced_mask(tensor_map, ggml_model_decoder); + add_rope_sin_cos(tensor_map, ggml_model_decoder); +} + +} // namespace + +TranslateSession::TranslateSession(const frontend::InputModel::Ptr& input_model, + const std::unordered_map& translator_map, + bool naive) : + m_input_model(input_model), + m_translator_map(translator_map), + m_ov_model(nullptr), + m_naive(naive) {} + +std::shared_ptr TranslateSession::get_converted_model() { + if (m_ov_model) { + return m_ov_model; + } + m_ov_model = translate_graph(m_input_model); + return m_ov_model; +} + +std::shared_ptr TranslateSession::translate_graph(const frontend::InputModel::Ptr& input_model) { + ov::ParameterVector params; + ov::ResultVector results; + auto tensor_map = std::make_shared(); + std::shared_ptr resulting_model; + + const auto& ggml_model = std::dynamic_pointer_cast(input_model); + std::shared_ptr ggml_model_decoder = ggml_model->get_model_decoder(); + + for (const auto& it : ggml_model_decoder->get_model_inputs()) { + params.push_back(std::dynamic_pointer_cast(it.second)); + (*tensor_map)[it.first] = it.second; + } + + for (const auto& it : ggml_model_decoder->get_model_extra_inputs()) { + params.push_back(std::dynamic_pointer_cast(it.second)); + (*tensor_map)[it.first] = it.second; + } + + for (const auto& it : ggml_model_decoder->get_model_weights()) { + (*tensor_map)[it.first] = it.second; + } + + auto node_visitor = [&](std::shared_ptr node) { + auto operation_type = node->get_op_type(); + if (operation_type == "GGML_OP_NONE") { + return; + } + + ov::OutputVector converted_outputs; + auto it = m_translator_map.find(operation_type); + FRONT_END_OP_CONVERSION_CHECK(it != m_translator_map.end(), + "Translation for operation type ", + operation_type, + " is not implemented."); + NodeContext node_context(node, tensor_map, this); + converted_outputs = it->second(node_context); + + const auto& node_output_names = node->get_output_names(); + FRONT_END_OP_CONVERSION_CHECK(node_output_names.size() == converted_outputs.size(), + "Number of ", + operation_type, + " outputs greater than number of converted outputs, which are ", + node_output_names.size(), + " and ", + converted_outputs.size(), + " respectively."); + + for (size_t i = 0; i < node_output_names.size(); ++i) { + auto output_name = node_output_names[i]; + if (i < converted_outputs.size() && converted_outputs[i].get_node_shared_ptr() != nullptr) { + (*tensor_map)[output_name] = converted_outputs[i]; + } + } + }; + + if (!m_naive) { + preprocess(*tensor_map, *ggml_model_decoder); + } + ggml_model_decoder->visit_subgraph(node_visitor); + + for (const auto& name : ggml_model_decoder->get_model_output_names()) { + FRONT_END_GENERAL_CHECK(tensor_map->find(name) != tensor_map->end(), + "Output name not found in tensor map: ", + name); + auto result = std::make_shared(tensor_map->at(name)); + result->set_friendly_name(name); + results.push_back(result); + } + + resulting_model = std::make_shared(results, params); + + apply_transformations(resulting_model); + return resulting_model; +} + +std::shared_ptr TranslateSession::apply_transformations(std::shared_ptr model) { + auto ggml_model_decoder = std::dynamic_pointer_cast(m_input_model)->get_model_decoder(); + { + ov::pass::Manager manager; + manager.set_per_pass_validation(true); + manager.register_pass(); + + if (!ggml_model_decoder->is_static()) { + const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names(); + const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names); + manager.register_pass(kv_param_res_pairs); + } + + // if (ggml_model_decoder->is_static()) { + manager.register_pass(); + // } + manager.run_passes(model); + } + return model; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.hpp b/ggml/src/ggml-openvino/openvino/translate_session.hpp new file mode 100644 index 0000000000000..7072d4a9e8b1a --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/translate_session.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "input_model.hpp" +#include "node_context.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +class TranslateSession { +public: + TranslateSession(const frontend::InputModel::Ptr& input_model, + const std::unordered_map& translator_map, bool naive = false); + + std::shared_ptr get_converted_model(); + std::shared_ptr translate_graph(const frontend::InputModel::Ptr& input_model); + +private: + std::shared_ptr apply_transformations(std::shared_ptr model); + const frontend::InputModel::Ptr m_input_model; + const std::unordered_map& m_translator_map; + std::shared_ptr m_ov_model; + bool m_naive; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp new file mode 100644 index 0000000000000..f70cb91a17fe0 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -0,0 +1,205 @@ +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-impl.h" + +namespace ov { +namespace frontend { +namespace ggml { + +std::string getCurrentTime() { + std::time_t now = std::time(nullptr); + char buf[100]; + std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now)); + return buf; +} + +void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) { + auto input_size = context.get_input_size(); + FRONT_END_OP_CONVERSION_CHECK(input_size >= min_inputs, "Got less inputs than expected"); + FRONT_END_OP_CONVERSION_CHECK(input_size <= max_inputs, "Got more inputs than expected"); +} + +int non_cont_dim(std::vector ne, std::vector nb) { + int dim = nb.size() - 1; + size_t bytes = nb[dim]; + for (int i = dim; i > 0; i--) { + bytes *= ne[i]; + if (bytes != nb[i - 1]) { + return i; + } + } + return 0; +} + +std::shared_ptr get_dimensions(const std::shared_ptr& shape, + const std::vector& dims) { + using namespace ov::op; + const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared(shape, dims_const, zero); +} + +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { + return get_dimensions(std::make_shared(node), dims); +} + +OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::string& suffix) { + for (const auto& output : outputs) { + auto node = output.get_node_shared_ptr(); + std::string name = node->get_friendly_name(); + name += "_"; + name += suffix; + node->set_friendly_name(name); + // std::cout << name << " " << output.get_partial_shape() << std::endl; + } + return outputs; +} + +namespace { +ov::Output rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], float ext_factor) { + int half_n_dims = n_dims / 2; + std::vector dim_ids_vec(half_n_dims); + std::iota(dim_ids_vec.begin(), dim_ids_vec.end(), 0); + auto dim_ids = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, (size_t) half_n_dims}, dim_ids_vec); + auto corr_low = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {corr_dims[0]}); + auto corr_high = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {corr_dims[1]}); + auto denom = + std::make_shared(std::make_shared(corr_high, corr_low), + ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {0.001f})); + auto ramp_y = + std::make_shared(std::make_shared(dim_ids, corr_low), denom); + auto ramp_clamped = std::make_shared(ramp_y, 0.0f, 1.0f); + auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor}); + auto ramp_mix = std::make_shared(ramp_clamped, ext_factor_node); + return ramp_mix; +} + +float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { +#ifndef M_PI +# define M_PI 3.14159265358979323846 +#endif + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float) M_PI)) / (2 * logf(base)); +} + +void ggml_rope_yarn_corr_dims(int n_dims, + int n_ctx_orig, + float freq_base, + float beta_fast, + float beta_slow, + float dims[2]) { + float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); + dims[0] = std::max(0.0f, start); + dims[1] = std::min(static_cast(n_dims - 1), end); +} +} // namespace + +std::pair, ov::Output> make_sin_cos(int32_t* rope_params, + std::shared_ptr inp_pos, + std::shared_ptr rope_freqs_weight) { + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_perm = + std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{2, 1, 0}); + inp_pos = std::make_shared(inp_pos, pos_perm); + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + const int n_dims = rope_params[1]; + const int n_ctx_orig = rope_params[4]; + memcpy(&freq_base, rope_params + 5, sizeof(float)); + memcpy(&freq_scale, rope_params + 6, sizeof(float)); + memcpy(&ext_factor, rope_params + 7, sizeof(float)); + memcpy(&attn_factor, rope_params + 8, sizeof(float)); + memcpy(&beta_fast, rope_params + 9, sizeof(float)); + memcpy(&beta_slow, rope_params + 10, sizeof(float)); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + std::vector factor(n_dims / 2); + factor[0] = 1.0f; + for (size_t i = 1; i < factor.size(); i++) { + factor[i] = theta_scale * factor[i - 1]; + } + + Output freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); + if (rope_freqs_weight) { + freq_factors = std::make_shared(freq_factors, rope_freqs_weight); + } + + auto theta_extrap = std::make_shared(freq_factors, inp_pos); + auto theta_interp = std::make_shared( + theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); + + Output theta; + float mscale = attn_factor; + if (ext_factor == 0.0f) { + theta = theta_interp; + } else { + auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); + auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + auto one_minus_ramp = std::make_shared(one, ramp_mix); + + theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), + std::make_shared(theta_extrap, ramp_mix)); + mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + } + + Output cos_theta = std::make_shared(theta); + Output sin_theta = std::make_shared(theta); + + auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + + cos_theta = std::make_shared(cos_theta, mscale_node); + sin_theta = std::make_shared(sin_theta, mscale_node); + return std::make_pair(sin_theta, cos_theta); +} + +ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len) { + // Only works for VIEW operations that slice at the lowest dimension + // If the VIEW also reshape the result, `slice_len` should be provided + auto input = context.get_input(input_index); + int32_t* op_params = context.get_input_op_params(input_index); + auto src1_stride = context.get_input_stride(input_index); + + int64_t split_addr = op_params[0] / src1_stride[2]; + if (slice_len == 0) { + slice_len = context.get_input_shape(input_index)[2].get_length(); + } + int64_t slice_end = split_addr + slice_len; + + auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr}); + auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end}); + auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); + auto sliced = std::make_shared(input, begin, end, stride, axes); + return sliced; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/utils.hpp b/ggml/src/ggml-openvino/openvino/utils.hpp new file mode 100644 index 0000000000000..6c6d2ae8d4f23 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/utils.hpp @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "node_context.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +std::string getCurrentTime(); + +void dump_ov_model(std::shared_ptr model); + +void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs); + +int non_cont_dim(std::vector ne, std::vector nb); + +template +std::vector argsort_descend(const std::vector& v) { + std::vector idx(v.size()); + std::iota(idx.begin(), idx.end(), 0); + std::sort(idx.begin(), idx.end(), [&v](int i1, int i2) { + return v[i1] > v[i2]; + }); + return idx; +} + +template +std::vector sorted_descend(std::vector v) { + std::sort(v.begin(), v.end(), [](T a, T b) { + return a > b; + }); + return v; +} + +template +bool is_permuted(const std::vector& strides) { + for (size_t i = 0; i < strides.size() - 1; ++i) { + if (strides[i] < strides[i + 1]) { + return true; + } + } + return false; +} + +template +std::vector permute(const std::vector& x, const std::vector& perm) { + std::vector result; + result.reserve(perm.size()); + for (int i : perm) { + result.push_back(x[i]); + } + return result; +} + +std::shared_ptr get_dimensions(const std::shared_ptr& shape, + const std::vector& dims); +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims); + +OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::string& suffix); + +std::pair, ov::Output> make_sin_cos(int32_t* rope_params, + std::shared_ptr inp_pos, + std::shared_ptr rope_freqs_weight = nullptr); + +ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len = 0); + +namespace op { +template +OutputVector translate_1to1_match_2_inputs(const NodeContext& context) { + num_inputs_check(context, 2, 2); + auto res = std::make_shared(context.get_input(0), context.get_input(1)); + return rename_outputs_with_suffix({res}, context.get_name()); +} +} // namespace op + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp new file mode 100644 index 0000000000000..0ec815f07f4f9 --- /dev/null +++ b/ggml/src/ggml-openvino/utils.cpp @@ -0,0 +1,507 @@ +#include "utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ggml-impl.h" +#include "ggml-openvino/ggml-decoder.h" +#include "ggml.h" +#include "openvino/frontend.hpp" +#include "openvino/input_model.hpp" + +ov::Tensor convert_ggml_input_to_ov(std::shared_ptr ggml_decoder, const std::string& name) { + const auto* ggml_tensor = ggml_decoder->get_input_ggml_tensor(name); + auto* input_data = ggml_tensor->data; + ov::Shape input_shape; + if (name.find("cache_k") == 0 || name.find("cache_v") == 0) { + input_shape = ggml_decoder->get_graph_input_shape(ggml_tensor).to_shape(); + } else if (ggml_tensor->op == GGML_OP_VIEW) { + // This case is added to make test-backend-ops work + input_shape = ggml_decoder->get_graph_input_shape(ggml_tensor->view_src).to_shape(); + } else { + input_shape = ggml_decoder->get_input_shape(name).to_shape(); + } + auto input_tensor = ov::Tensor(ggml_decoder->get_input_type(name), input_shape, input_data); + return input_tensor; +} + +std::map get_ggml_graph_output_dst(std::shared_ptr ggml_decoder) { + std::map output_tensors; + auto output_names = ggml_decoder->get_model_output_names(); + for (size_t inp = 0; inp < output_names.size(); ++inp) { + auto name = output_names[inp]; + const auto* tensor = ggml_decoder->get_output_ggml_tensor(name); + auto* output_data = tensor->view_src ? tensor->view_src->data : tensor->data; + output_tensors[name] = output_data; + } + return output_tensors; +} + +static ov::frontend::FrontEnd::Ptr get_ggml_frontend() { + auto fem = ov::frontend::FrontEndManager(); + auto front_end = fem.load_by_framework("ggml"); + return front_end; +} + +enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_cgraph* cgraph) { + static ov::Core core; + + static std::string device = getenv("GGML_OPENVINO_DEVICE") ? getenv("GGML_OPENVINO_DEVICE") : ""; + if (device.empty()) { + const std::vector preferred_device = { "GPU", "CPU", "NPU" }; + const auto available_devices = core.get_available_devices(); + for (const auto& dev : preferred_device) { + if (std::find(available_devices.begin(), available_devices.end(), dev) != available_devices.end()) { + device = dev; + break; + } + } + } + + bool is_static = device == "NPU" ? true : false; + ov::AnyMap config; + if (device == "GPU") { + config = { + {"GPU_ENABLE_SDPA_OPTIMIZATION", "0"} + }; + } + + if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) { + std::string filename = "cgraph.txt"; + GgmlOvDecoder::dump_cgraph(cgraph, filename); + } + + if (is_naive(cgraph)) { + return naive_compute(cgraph, core, device, config); + } + + auto start_time = ggml_time_us(); + + auto* cache_dir = getenv("GGML_OPENVINO_CACHE_DIR"); + if (cache_dir && !is_static) { + core.set_property(ov::cache_dir(cache_dir)); + } + + static std::mutex cache_mutex; + static std::unordered_map> infer_request_cache; + static std::unordered_map> ov_input_names_cache; + static std::unordered_map> ov_output_names_cache; + // For NPU, store the kvcache model, since we cannot create two infer_request + static std::unordered_map compiled_model_cache; + + std::shared_ptr ggml_decoder; + ov::InferRequest infer_request; + + int64_t decoder_end_time; + int64_t conversion_end_time; + int64_t compile_end_time; + + { + std::lock_guard lock(cache_mutex); + + auto it = infer_request_cache.find(cgraph); + if (it != infer_request_cache.end()) { + std::map> model_weights; + ggml_decoder = std::make_shared(cgraph, model_weights, is_static, false); + decoder_end_time = ggml_time_us(); + + // For NPU for the first time we call kvcache modle, pop the compiled kvcache model from cache + if (is_static && compiled_model_cache.find(cgraph) != compiled_model_cache.end()) { + infer_request_cache[cgraph] = + std::make_shared(compiled_model_cache[cgraph].create_infer_request()); + compiled_model_cache.erase(cgraph); + } + infer_request = *infer_request_cache[cgraph]; + + conversion_end_time = ggml_time_us(); + compile_end_time = conversion_end_time; + } else { + std::shared_ptr model; + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, get_types_to_requant(device)); + + if (is_static) { + ggml_decoder = std::make_shared(cgraph, model_weights, is_static, true); + auto ggml_decoder_kvcache = std::make_shared(cgraph, model_weights, is_static, false); + decoder_end_time = ggml_time_us(); + + auto input_model = std::make_shared(ggml_decoder); + auto input_model_kvcache = std::make_shared(ggml_decoder_kvcache); + + model = ov::frontend::ggml::FrontEnd::convert(input_model); + ggml_decoder->clear_model_weights(); + auto model_kvcache = ov::frontend::ggml::FrontEnd::convert(input_model_kvcache); + ggml_decoder_kvcache->clear_model_weights(); + conversion_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DUMP_IR")) { + char timestamped_filename[64]; + auto timestamp = (long long) ggml_time_us(); + snprintf(timestamped_filename, sizeof(timestamped_filename), "model_prefill_%lld.xml", timestamp); + ov::serialize(model, timestamped_filename); + snprintf(timestamped_filename, sizeof(timestamped_filename), "model_kvcache_%lld.xml", timestamp); + ov::serialize(model_kvcache, timestamped_filename); + } + + auto compiled_model = core.compile_model(model, device, get_npu_prefill_config()); + auto compiled_model_kvcache = core.compile_model(model_kvcache, device, get_npu_generate_config()); + compiled_model_cache[cgraph] = compiled_model_kvcache; + compile_end_time = ggml_time_us(); + + infer_request_cache[cgraph] = std::make_shared(compiled_model.create_infer_request()); + infer_request = *infer_request_cache[cgraph]; + compiled_model_cache[cgraph] = compiled_model_kvcache; + } else { + ggml_decoder = std::make_shared(cgraph, model_weights, is_static, true); + decoder_end_time = ggml_time_us(); + + auto input_model = std::make_shared(ggml_decoder); + model = ov::frontend::ggml::FrontEnd::convert(input_model); + ggml_decoder->clear_model_weights(); + conversion_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DUMP_IR")) { + char timestamped_filename[64]; + auto timestamp = (long long) ggml_time_us(); + snprintf(timestamped_filename, sizeof(timestamped_filename), "model_%lld.xml", timestamp); + ov::serialize(model, timestamped_filename); + } + + auto compiled_model = core.compile_model(model, device, config); + compile_end_time = ggml_time_us(); + infer_request_cache[cgraph] = std::make_shared(compiled_model.create_infer_request()); + infer_request = *infer_request_cache[cgraph]; + } + + std::vector ov_input_names; + std::vector ov_output_names; + for (const auto& ov_param : model->get_parameters()) { + ov_input_names.push_back(ov_param->get_friendly_name()); + } + for (const auto& ov_output : model->get_results()) { + ov_output_names.push_back(ov_output->get_friendly_name()); + } + ov_input_names_cache[cgraph] = ov_input_names; + ov_output_names_cache[cgraph] = ov_output_names; + } + } + + auto ov_input_names = ov_input_names_cache[cgraph]; + auto ov_output_names = ov_output_names_cache[cgraph]; + for (size_t i = 0; i < ov_input_names.size(); i++) { + auto param_name = ov_input_names[i]; + auto input_tensor = get_ov_input_tensor(ggml_decoder, param_name); + infer_request.set_input_tensor(i, input_tensor); + + if (getenv("GGML_OPENVINO_DEBUG_INPUT")) { + print_input_tensor_info(param_name, input_tensor); + } + } + auto input_end_time = ggml_time_us(); + + infer_request.infer(); + auto infer_end_time = ggml_time_us(); + + auto gguf_tensor_addrs = get_ggml_graph_output_dst(ggml_decoder); + for (size_t i = 0; i < ov_output_names.size(); i++) { + auto& result_name = ov_output_names[i]; + const auto output_tensor = infer_request.get_output_tensor(i); + + std::memcpy(gguf_tensor_addrs[result_name], output_tensor.data(), output_tensor.get_byte_size()); + + if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { + print_output_tensor_info(result_name, output_tensor, gguf_tensor_addrs); + } + } + auto end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_PROFILING")) { + GGML_LOG_INFO("GGML OpenVINO Backend: \n"); + GGML_LOG_INFO(" - Graph decoder Time: %ld ms \n", (decoder_end_time - start_time) / 1000); + GGML_LOG_INFO(" - Graph conversion Time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000); + GGML_LOG_INFO(" - Graph compile Time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000); + GGML_LOG_INFO(" - Graph Input Time: %ld ms \n", (input_end_time - compile_end_time) / 1000); + GGML_LOG_INFO(" - Graph Inference Time: %ld ms \n", (infer_end_time - input_end_time) / 1000); + GGML_LOG_INFO(" - Graph Output Time: %ld ms \n", (end_time - infer_end_time) / 1000); + } + + return GGML_STATUS_SUCCESS; + GGML_UNUSED(backend); +} + +namespace { +ov::AnyMap get_npu_base_config() { + return { + {"NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add_RMSNorm" }, + {"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES" }, + {"NPU_USE_NPUW", "YES" }, + {"NPUW_DEVICES", "NPU" }, + {"NPUW_FOLD", "YES" }, + {"NPUW_WEIGHTS_BANK", "shared" }, + {"NPUW_FUNCALL_FOR_ALL", "YES" }, + {"NPUW_FUNCALL_ASYNC", "YES" }, + {"NPUW_DQ", "YES" }, + {"NPUW_DQ_FULL", "NO" }, + {"NPUW_CACHE_DIR", getenv("GGML_OPENVINO_CACHE_DIR") ? getenv("GGML_OPENVINO_CACHE_DIR") : ""}, + }; +} +} // namespace + +ov::AnyMap get_npu_prefill_config() { + auto config = get_npu_base_config(); + return config; +} + +ov::AnyMap get_npu_generate_config() { + auto config = get_npu_base_config(); + return config; +} + +std::map get_types_to_requant(const std::string& device) { + if (device == "NPU") { + return { + {GGML_TYPE_Q4_0, ExtraQuantType::Q4_0_128}, + {GGML_TYPE_Q4_1, ExtraQuantType::Q4_0_128}, + {GGML_TYPE_Q4_K, ExtraQuantType::Q4_0_128}, + {GGML_TYPE_Q6_K, ExtraQuantType::F16 }, + {GGML_TYPE_Q5_K, ExtraQuantType::F16 }, + }; + } + if (device == "GPU") { + return { + // gs16 is WIP + {GGML_TYPE_Q6_K, ExtraQuantType::Q8_0_32}, + }; + } + return {}; +} + +bool is_naive(struct ggml_cgraph* cgraph) { + constexpr int naive_graph_size_threshold = 20; + return cgraph->n_nodes < naive_graph_size_threshold; +} + +enum ggml_status naive_compute(struct ggml_cgraph* cgraph, + ov::Core& core, + const std::string& device, + const ov::AnyMap& config) { + if (cgraph->n_nodes == 1 && (cgraph->nodes[0]->op == GGML_OP_NONE || cgraph->nodes[0]->op == GGML_OP_VIEW)) { + return GGML_STATUS_SUCCESS; + } + if (cgraph->nodes[0]->op == GGML_OP_FLASH_ATTN_EXT) { + return GGML_STATUS_FAILED; + } + + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); + auto decoder = std::make_shared(cgraph, model_weights); + auto input_model = std::make_shared(decoder); + auto naive = true; + auto model = ov::frontend::ggml::FrontEnd::convert(input_model, naive); + if (getenv("GGML_OPENVINO_DUMP_IR")) { + ov::serialize(model, "IR_naive.xml"); + } + auto infer_request = core.compile_model(model, device, config).create_infer_request(); + + auto ov_params = model->get_parameters(); + for (size_t i = 0; i < ov_params.size(); i++) { + auto param_name = ov_params[i]->get_friendly_name(); + auto input_tensor = get_ov_input_tensor(decoder, param_name); + infer_request.set_input_tensor(i, input_tensor); + } + + infer_request.infer(); + + auto gguf_tensor_addrs = get_ggml_graph_output_dst(decoder); + auto ov_results = model->get_results(); + for (size_t i = 0; i < ov_results.size(); i++) { + auto result_name = ov_results[i]->get_friendly_name(); + const auto output_tensor = infer_request.get_output_tensor(i); + + std::memcpy(gguf_tensor_addrs[result_name], output_tensor.data(), output_tensor.get_byte_size()); + } + return GGML_STATUS_SUCCESS; +} + +ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, const std::string& param_name) { + bool is_static = ggml_decoder->is_static(); + bool is_first_token = ggml_decoder->is_first_token(); + + ov::Tensor input_tensor; + if (ggml_decoder->get_model_extra_inputs().find(param_name) != ggml_decoder->get_model_extra_inputs().end()) { + input_tensor = *ggml_decoder->get_model_extra_input_values().at(param_name); + + } else if (!is_static) { + input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name); + + } else { + if (param_name == "inp_tokens" || param_name == "inp_pos") { + if (is_first_token) { + size_t context_size = ggml_decoder->get_context_size(); + const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name); + std::vector padded_data = pad_input(input_tensor_ggml, 1, context_size, 0); + input_tensor = ov::Tensor(ov::element::i32, ov::Shape{1, 1, context_size}); + auto* data_ptr = input_tensor.data(); + std::copy(padded_data.begin(), padded_data.end(), data_ptr); + } else { + input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name); + } + + } else if (param_name.find("KQ_mask") == 0) { + size_t context_size = ggml_decoder->get_context_size(); + const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name); + if (is_first_token) { + std::vector padded_data = + pad_input(input_tensor_ggml, context_size, context_size, -INFINITY); + set_zero_diagonal(padded_data, context_size); + input_tensor = ov::Tensor(ov::element::f32, ov::Shape{1, context_size, context_size}); + auto* data_ptr = input_tensor.data(); + std::copy(padded_data.begin(), padded_data.end(), data_ptr); + } else { + std::vector padded_data = pad_input(input_tensor_ggml, 1, context_size, -INFINITY); + input_tensor = ov::Tensor(ov::element::f32, ov::Shape{1, 1, context_size}); + auto* data_ptr = input_tensor.data(); + std::copy(padded_data.begin(), padded_data.end(), data_ptr); + } + + } else if (const auto* op = ggml_decoder->get_tensor_used_op(ggml_decoder->get_tensor_from_name(param_name)); + op && op->op == GGML_OP_SET_ROWS && is_static && is_first_token) { + input_tensor = ov::Tensor(ov::element::i64, ov::Shape{1, 1, 1}); + } else { + input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name); + } + } + return input_tensor; +} + +size_t checksum(const void* data, size_t size) { + const uint8_t* bytes = static_cast(data); + size_t sum = 0; + for (size_t i = 0; i < size; ++i) { + sum += (uint8_t) i; + sum += bytes[i]; + } + return sum; +} + +// Suppress deprecation warning for ov::Tensor::data() +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + +void print_input_tensor_info(const std::string& name, const ov::Tensor& tensor) { + std::cout << "Input name: " << name << ", Input shape: " << tensor.get_shape() << ", Address: " << tensor.data() + << std::endl; + switch (tensor.get_element_type()) { + case ov::element::f32: + std::cout << *(tensor.data()) << std::endl; + break; + case ov::element::f16: + std::cout << *(tensor.data()) << std::endl; + break; + case ov::element::i32: + for (size_t i = 0; i < tensor.get_size(); ++i) { + std::cout << tensor.data()[i] << " "; + } + std::cout << std::endl; + break; + case ov::element::i64: + std::cout << *(tensor.data()) << std::endl; + break; + default: + break; + } +} + +void print_output_tensor_info(const std::string& name, const ov::Tensor& tensor, + std::map& output_dst) { + std::cout << "Output name: " << name << ", Output shape: " << tensor.get_shape() + << ", Address: " << output_dst[name] << std::endl; + + auto print_float_stats = [](const std::string& type_name, size_t size, auto get_value) { + if (size == 0) { + return; + } + + float first = get_value(0); + float min = first; + float max = first; + double sum = first; + + for (size_t i = 1; i < size; ++i) { + float v = get_value(i); + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + sum += v; + } + double mean = sum / size; + + std::cout << std::right << std::setw(6) << type_name << std::right << std::setw(12) << "First" << std::setw(12) + << "Min" << std::setw(12) << "Max" << std::setw(12) << "Mean" << std::endl; + std::cout << std::right << std::setw(6) << "" << std::right << std::setw(12) << first << std::setw(12) << min + << std::setw(12) << max << std::setw(12) << mean << std::endl; + }; + + switch (tensor.get_element_type()) { + case ov::element::f32: { + const float* data = tensor.data(); + size_t size = tensor.get_size(); + print_float_stats("[f32]", size, [data](size_t i) { return data[i]; }); + break; + } + case ov::element::f16: { + const ov::float16* data = tensor.data(); + size_t size = tensor.get_size(); + print_float_stats("[f16]", size, [data](size_t i) { return static_cast(data[i]); }); + break; + } + default: + break; + } +} + +#pragma GCC diagnostic pop + +void set_zero_diagonal(std::vector& matrix, size_t dim) { + for (size_t i = 0; i < dim; ++i) { + matrix[i * dim + i] = 0.0f; + } +} + +bool is_prefill(struct ggml_cgraph* cgraph) { + for (int i = 0; i < cgraph->n_nodes; ++i) { + auto* op = cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; ++j) { + auto* src = op->src[j]; + if (src == nullptr) { + break; + } + if (std::string(src->name) == "inp_tokens") { + return src->ne[0] != 1; + } + } + } + GGML_LOG_ERROR("is_prefill: inp_tokens not found in cgraph"); + throw std::runtime_error("is_prefill: inp_tokens not found in cgraph"); +} diff --git a/ggml/src/ggml-openvino/utils.cpp.bak b/ggml/src/ggml-openvino/utils.cpp.bak new file mode 100644 index 0000000000000..8fef1985f91ae --- /dev/null +++ b/ggml/src/ggml-openvino/utils.cpp.bak @@ -0,0 +1,72 @@ +void model_cut() { + ov::Core core; + std::shared_ptr model = + core.read_model("/home/zijun/dev/llama.cpp-ov/tmp/fold_graph/Model1_01_0x5555601c5ac0.xml"); + + ov::ParameterVector new_params; + + auto ops = model->get_ops(); + std::shared_ptr node_a; + std::shared_ptr node_b; + for (const auto& op : ops) { + if (op->get_friendly_name() == "Multiply_4636_ffn_norm-0") { + node_a = op; + } else if (op->get_friendly_name() == "Multiply_4645_ffn_gate_par-0") { + node_b = op; + } else if (op->get_friendly_name() == "Parameter_39914") { + auto param = std::dynamic_pointer_cast(op); + new_params.push_back(param); + } else if (op->get_friendly_name() == "Parameter_39915") { + auto param = std::dynamic_pointer_cast(op); + new_params.push_back(param); + } + } + + auto subgraph_input_tensor = node_a->output(0); + auto subgraph_output_tensor = node_b->output(0); + + auto new_input = std::make_shared(subgraph_input_tensor.get_element_type(), + subgraph_input_tensor.get_shape()); + new_input->set_friendly_name("subgraph_input"); + new_params.push_back(new_input); + + // Rewire: replace all consumers of original tensor with new input + subgraph_input_tensor.replace(new_input); + + auto result = std::make_shared(subgraph_output_tensor); + result->set_friendly_name("subgraph_output"); + + auto subgraph = std::make_shared(ov::ResultVector{result}, new_params, "trimmed_subgraph"); + + ov::serialize(subgraph, "/home/zijun/dev/llama.cpp-ov/tmp/subgraph.xml"); + + assert(false); +} + +void create_graph() { + // Input shapes: [256, 1, 1] + ov::Shape input_shape{256, 1, 1}; + + // Define input parameters + auto input0 = std::make_shared(ov::element::f32, input_shape); + auto input1 = std::make_shared(ov::element::f32, input_shape); + + // Concat on axis 2 -> shape becomes [256, 1, 2] + auto concat = std::make_shared(ov::OutputVector{input0, input1}, 2); + + // Target shape constant for reshape: [256, 2] + auto reshape_shape = ov::op::v0::Constant::create(ov::element::i64, {2}, {256, 2}); + + // special_zero = false + auto reshape = std::make_shared(concat, reshape_shape, false); + + // Define result node + auto result = std::make_shared(reshape); + + // Create model + auto model = std::make_shared(ov::ResultVector{result}, ov::ParameterVector{input0, input1}, "ReshapeConcatModel"); + + ov::serialize(subgraph, "/home/zijun/dev/llama.cpp-ov/tmp/subgraph3.xml"); + + exit(0); +} diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h new file mode 100644 index 0000000000000..42686c593b3ce --- /dev/null +++ b/ggml/src/ggml-openvino/utils.h @@ -0,0 +1,53 @@ +#include +#include + +#include "ggml-backend-impl.h" +#include "ggml-decoder.h" +#include "ggml-impl.h" + +enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_cgraph* cgraph); + +std::shared_ptr get_ggml_decoder(struct ggml_cgraph* cgraph, bool is_static, bool is_first_token); + +ov::Tensor convert_ggml_input_to_ov(std::shared_ptr ggml_decoder, const std::string& name); + +std::map get_ggml_graph_output_dst(std::shared_ptr ggml_decoder); + +size_t checksum(const void* data, size_t size); + +void print_input_tensor_info(const std::string& name, const ov::Tensor& tensor); + +void print_output_tensor_info(const std::string& name, + const ov::Tensor& tensor, + std::map& output_dst); + +template +std::vector pad_input(const ggml_tensor* tensor, size_t padded_rows, size_t padded_cols, T pad_value) { + std::vector padded_data(padded_rows * padded_cols, pad_value); + size_t rows = tensor->ne[1]; + size_t cols = tensor->ne[0]; + T* data = static_cast(tensor->data); + + for (size_t i = 0; i < std::min(rows, padded_rows); ++i) { + for (size_t j = 0; j < std::min(cols, padded_cols); ++j) { + padded_data[i * padded_cols + j] = data[i * cols + j]; + } + } + return padded_data; +} + +void set_zero_diagonal(std::vector& matrix, size_t dim); + +bool is_prefill(struct ggml_cgraph * cgraph); + +ov::AnyMap get_npu_prefill_config(); +ov::AnyMap get_npu_generate_config(); + +std::map get_types_to_requant(const std::string& device); + +ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, const std::string& param_name); + +bool is_naive(struct ggml_cgraph* cgraph); + +enum ggml_status naive_compute(struct ggml_cgraph* cgraph, ov::Core& core, const std::string& device, + const ov::AnyMap& config); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 90cd885a60a4f..5dcdd2c230148 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1093,7 +1093,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - //cb(inp->tokens, "inp_tokens", -1); + cb(inp->tokens, "inp_tokens", -1); ggml_set_input(inp->tokens); res->t_tokens = inp->tokens; @@ -1141,6 +1141,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { auto & cur = inp->pos; cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd()); + cb(cur, "inp_pos", -1); ggml_set_input(cur); res->add_input(std::move(inp)); @@ -1176,6 +1177,7 @@ ggml_tensor * llm_graph_context::build_inp_out_ids() const { auto & cur = inp->out_ids; cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); + cb(cur, "inp_out_ids", -1); ggml_set_input(cur); res->add_input(std::move(inp)); @@ -1420,6 +1422,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + cb(inp->kq_mask, "KQ_mask", -1); ggml_set_input(inp->kq_mask); inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; @@ -1466,7 +1469,7 @@ ggml_tensor * llm_graph_context::build_attn( } if (wo_b) { - //cb(cur, "kqv_wo", il); + cb(cur, "kqv_wo", il); } if (wo_b) { @@ -1496,6 +1499,7 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + ggml_set_name(inp->self_kq_mask, "KQ_mask"); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1622,7 +1626,7 @@ ggml_tensor * llm_graph_context::build_attn( } if (wo_b) { - //cb(cur, "kqv_wo", il); + cb(cur, "kqv_wo", il); } if (wo_b) { @@ -1677,7 +1681,7 @@ ggml_tensor * llm_graph_context::build_attn( } if (wo_b) { - //cb(cur, "kqv_wo", il); + cb(cur, "kqv_wo", il); } if (wo_b) { @@ -1704,6 +1708,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + ggml_set_name(inp->self_kq_mask, "KQ_mask"); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1718,6 +1723,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + ggml_set_name(inp->self_kq_mask_swa, "KQ_mask_swa"); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f4ce..677d4e01d8c60 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -201,7 +201,9 @@ if (NOT LLAMA_SANITIZE_ADDRESS) llama_build_and_test(test-opt.cpp) endif() llama_build_and_test(test-gguf.cpp) -llama_build_and_test(test-backend-ops.cpp) +if (NOT GGML_OPENVINO) + llama_build_and_test(test-backend-ops.cpp) +endif() llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model")