diff --git a/.github/actions/custom-build-and-push/action.yml b/.github/actions/custom-build-and-push/action.yml new file mode 100644 index 00000000000..48344237059 --- /dev/null +++ b/.github/actions/custom-build-and-push/action.yml @@ -0,0 +1,76 @@ +name: 'Build and Push Docker Image with Retry' +description: 'Attempts to build and push a Docker image, with a retry on failure' +inputs: + context: + description: 'Build context' + required: true + file: + description: 'Dockerfile location' + required: true + platforms: + description: 'Target platforms' + required: true + pull: + description: 'Always attempt to pull a newer version of the image' + required: false + default: 'true' + push: + description: 'Push the image to registry' + required: false + default: 'true' + load: + description: 'Load the image into Docker daemon' + required: false + default: 'true' + tags: + description: 'Image tags' + required: true + cache-from: + description: 'Cache sources' + required: false + cache-to: + description: 'Cache destinations' + required: false + retry-wait-time: + description: 'Time to wait before retry in seconds' + required: false + default: '5' + +runs: + using: "composite" + steps: + - name: Build and push Docker image (First Attempt) + id: buildx1 + uses: docker/build-push-action@v5 + continue-on-error: true + with: + context: ${{ inputs.context }} + file: ${{ inputs.file }} + platforms: ${{ inputs.platforms }} + pull: ${{ inputs.pull }} + push: ${{ inputs.push }} + load: ${{ inputs.load }} + tags: ${{ inputs.tags }} + cache-from: ${{ inputs.cache-from }} + cache-to: ${{ inputs.cache-to }} + + - name: Wait to retry + if: steps.buildx1.outcome != 'success' + run: | + echo "First attempt failed. Waiting ${{ inputs.retry-wait-time }} seconds before retry..." 
+ sleep ${{ inputs.retry-wait-time }} + shell: bash + + - name: Build and push Docker image (Retry Attempt) + if: steps.buildx1.outcome != 'success' + uses: docker/build-push-action@v5 + with: + context: ${{ inputs.context }} + file: ${{ inputs.file }} + platforms: ${{ inputs.platforms }} + pull: ${{ inputs.pull }} + push: ${{ inputs.push }} + load: ${{ inputs.load }} + tags: ${{ inputs.tags }} + cache-from: ${{ inputs.cache-from }} + cache-to: ${{ inputs.cache-to }} diff --git a/.github/workflows/pr-helm-chart-testing.yml.disabled.txt b/.github/workflows/pr-helm-chart-testing.yml.disabled.txt new file mode 100644 index 00000000000..7c4903a07f7 --- /dev/null +++ b/.github/workflows/pr-helm-chart-testing.yml.disabled.txt @@ -0,0 +1,67 @@ +# This workflow is intentionally disabled while we're still working on it +# It's close to ready, but a race condition needs to be fixed with +# API server and Vespa startup, and it needs to have a way to build/test against +# local containers + +name: Helm - Lint and Test Charts + +on: + merge_group: + pull_request: + branches: [ main ] + +jobs: + lint-test: + runs-on: Amd64 + + # fetch-depth 0 is required for helm/chart-testing-action + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Helm + uses: azure/setup-helm@v4.2.0 + with: + version: v3.14.4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + cache: 'pip' + cache-dependency-path: | + backend/requirements/default.txt + backend/requirements/dev.txt + backend/requirements/model_server.txt + - run: | + python -m pip install --upgrade pip + pip install -r backend/requirements/default.txt + pip install -r backend/requirements/dev.txt + pip install -r backend/requirements/model_server.txt + + - name: Set up chart-testing + uses: helm/chart-testing-action@v2.6.1 + + - name: Run chart-testing (list-changed) + id: list-changed + run: | + changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }}) + if [[ -n "$changed" ]]; then + echo "changed=true" >> "$GITHUB_OUTPUT" + fi + + - name: Run chart-testing (lint) +# if: steps.list-changed.outputs.changed == 'true' + run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }} + + - name: Create kind cluster +# if: steps.list-changed.outputs.changed == 'true' + uses: helm/kind-action@v1.10.0 + + - name: Run chart-testing (install) +# if: steps.list-changed.outputs.changed == 'true' + run: ct install --all --config ct.yaml +# run: ct install --target-branch ${{ github.event.repository.default_branch }} + \ No newline at end of file diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml new file mode 100644 index 00000000000..00b92c9b003 --- /dev/null +++ b/.github/workflows/pr-python-connector-tests.yml @@ -0,0 +1,57 @@ +name: Connector Tests + +on: + pull_request: + branches: [main] + schedule: + # This cron expression runs the job daily at 16:00 UTC (9am PT) + - cron: "0 16 * * *" + +env: + # Confluence + CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }} + CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }} + CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }} + CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }} + CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }} + CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }} + +jobs: + connectors-check: + runs-on: ubuntu-latest + + env: + 
PYTHONPATH: ./backend + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + cache: "pip" + cache-dependency-path: | + backend/requirements/default.txt + backend/requirements/dev.txt + + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + pip install -r backend/requirements/default.txt + pip install -r backend/requirements/dev.txt + + - name: Run Tests + shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}" + run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors + + - name: Alert on Failure + if: failure() && github.event_name == 'schedule' + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} + run: | + curl -X POST \ + -H 'Content-type: application/json' \ + --data '{"text":"Scheduled Connector Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \ + $SLACK_WEBHOOK diff --git a/.github/workflows/run-it.yml b/.github/workflows/run-it.yml index 7c0c1814c3b..0ca0031c64c 100644 --- a/.github/workflows/run-it.yml +++ b/.github/workflows/run-it.yml @@ -28,30 +28,20 @@ jobs: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - - name: Build Web Docker image - uses: docker/build-push-action@v5 - with: - context: ./web - file: ./web/Dockerfile - platforms: linux/arm64 - pull: true - push: true - load: true - tags: danswer/danswer-web-server:it - cache-from: type=registry,ref=danswer/danswer-web-server:it - cache-to: | - type=registry,ref=danswer/danswer-web-server:it,mode=max - type=inline + # NOTE: we don't need to build the Web Docker image since it's not used + # during the IT for now. We have a separate action to verify it builds + # successfully + - name: Pull Web Docker image + run: | + docker pull danswer/danswer-web-server:latest + docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:it - name: Build Backend Docker image - uses: docker/build-push-action@v5 + uses: ./.github/actions/custom-build-and-push with: context: ./backend file: ./backend/Dockerfile platforms: linux/arm64 - pull: true - push: true - load: true tags: danswer/danswer-backend:it cache-from: type=registry,ref=danswer/danswer-backend:it cache-to: | @@ -59,14 +49,11 @@ jobs: type=inline - name: Build Model Server Docker image - uses: docker/build-push-action@v5 + uses: ./.github/actions/custom-build-and-push with: context: ./backend file: ./backend/Dockerfile.model_server platforms: linux/arm64 - pull: true - push: true - load: true tags: danswer/danswer-model-server:it cache-from: type=registry,ref=danswer/danswer-model-server:it cache-to: | @@ -74,14 +61,11 @@ jobs: type=inline - name: Build integration test Docker image - uses: docker/build-push-action@v5 + uses: ./.github/actions/custom-build-and-push with: context: ./backend file: ./backend/tests/integration/Dockerfile platforms: linux/arm64 - pull: true - push: true - load: true tags: danswer/integration-test-runner:it cache-from: type=registry,ref=danswer/integration-test-runner:it cache-to: | @@ -92,8 +76,11 @@ jobs: run: | cd deployment/docker_compose ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \ + AUTH_TYPE=basic \ + REQUIRE_EMAIL_VERIFICATION=false \ + DISABLE_TELEMETRY=true \ IMAGE_TAG=it \ - docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build + docker compose -f docker-compose.dev.yml -p danswer-stack up -d id: start_docker - name: Wait for service to be ready @@ -137,6 +124,7
@@ jobs: -e POSTGRES_PASSWORD=password \ -e POSTGRES_DB=postgres \ -e VESPA_HOST=index \ + -e REDIS_HOST=cache \ -e API_SERVER_HOST=api_server \ -e OPENAI_API_KEY=${OPENAI_API_KEY} \ danswer/integration-test-runner:it diff --git a/.gitignore b/.gitignore index 15bed8a5983..a851e719116 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,6 @@ .idea .python-version /deployment/data/nginx/app.conf -.vscode/launch.json +.vscode/ *.sw? /backend/tests/regression/answer_quality/search_test_config.yaml diff --git a/.vscode/env_template.txt b/.vscode/env_template.txt index b3fae8cee73..89faca0abf0 100644 --- a/.vscode/env_template.txt +++ b/.vscode/env_template.txt @@ -1,5 +1,5 @@ -# Copy this file to .env at the base of the repo and fill in the values -# This will help with development iteration speed and reduce repeat tasks for dev +# Copy this file to .env in the .vscode folder +# Fill in the values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI # Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes # For local dev, often user Authentication is not needed @@ -15,7 +15,7 @@ LOG_LEVEL=debug # This passes top N results to LLM an additional time for reranking prior to answer generation # This step is quite heavy on token usage so we disable it for dev generally -DISABLE_LLM_DOC_RELEVANCE=True +DISABLE_LLM_DOC_RELEVANCE=False # Useful if you want to toggle auth on/off (google_oauth/OIDC specifically) @@ -27,9 +27,9 @@ REQUIRE_EMAIL_VERIFICATION=False # Set these so if you wipe the DB, you don't end up having to go through the UI every time GEN_AI_API_KEY= -# If answer quality isn't important for dev, use 3.5 turbo due to it being cheaper -GEN_AI_MODEL_VERSION=gpt-3.5-turbo -FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo +# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper +GEN_AI_MODEL_VERSION=gpt-4o +FAST_GEN_AI_MODEL_VERSION=gpt-4o # For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time # Only needed if using DanswerBot @@ -38,7 +38,7 @@ FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo # Python stuff -PYTHONPATH=./backend +PYTHONPATH=../backend PYTHONUNBUFFERED=1 @@ -49,4 +49,3 @@ BING_API_KEY= # Enable the full set of Danswer Enterprise Edition features # NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development) ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False - diff --git a/.vscode/launch.template.jsonc b/.vscode/launch.template.jsonc index 9aaadb32acf..c733800981c 100644 --- a/.vscode/launch.template.jsonc +++ b/.vscode/launch.template.jsonc @@ -1,15 +1,23 @@ -/* - - Copy this file into '.vscode/launch.json' or merge its - contents into your existing configurations. - -*/ +/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */ { // Use IntelliSense to learn about possible attributes. // Hover to view descriptions of existing attributes. 
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", + "compounds": [ + { + "name": "Run All Danswer Services", + "configurations": [ + "Web Server", + "Model Server", + "API Server", + "Indexing", + "Background Jobs", + "Slack Bot" + ] + } + ], "configurations": [ { "name": "Web Server", @@ -17,7 +25,7 @@ "request": "launch", "cwd": "${workspaceRoot}/web", "runtimeExecutable": "npm", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "runtimeArgs": [ "run", "dev" ], @@ -25,11 +33,12 @@ }, { "name": "Model Server", - "type": "python", + "consoleName": "Model Server", + "type": "debugpy", "request": "launch", "module": "uvicorn", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1" @@ -39,16 +48,16 @@ "--reload", "--port", "9000" - ], - "consoleTitle": "Model Server" + ] }, { "name": "API Server", - "type": "python", + "consoleName": "API Server", + "type": "debugpy", "request": "launch", "module": "uvicorn", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_DANSWER_MODEL_INTERACTIONS": "True", "LOG_LEVEL": "DEBUG", @@ -59,32 +68,32 @@ "--reload", "--port", "8080" - ], - "consoleTitle": "API Server" + ] }, { "name": "Indexing", - "type": "python", + "consoleName": "Indexing", + "type": "debugpy", "request": "launch", "program": "danswer/background/update.py", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "ENABLE_MULTIPASS_INDEXING": "false", "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." 
- }, - "consoleTitle": "Indexing" + } }, // Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev { "name": "Background Jobs", - "type": "python", + "consoleName": "Background Jobs", + "type": "debugpy", "request": "launch", "program": "scripts/dev_run_background_jobs.py", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_DANSWER_MODEL_INTERACTIONS": "True", "LOG_LEVEL": "DEBUG", @@ -93,18 +102,18 @@ }, "args": [ "--no-indexing" - ], - "consoleTitle": "Background Jobs" + ] }, // For the listener to access the Slack API, // DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project { "name": "Slack Bot", - "type": "python", + "consoleName": "Slack Bot", + "type": "debugpy", "request": "launch", "program": "danswer/danswerbot/slack/listener.py", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", @@ -113,11 +122,12 @@ }, { "name": "Pytest", - "type": "python", + "consoleName": "Pytest", + "type": "debugpy", "request": "launch", "module": "pytest", "cwd": "${workspaceFolder}/backend", - "envFile": "${workspaceFolder}/.env", + "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", @@ -128,18 +138,16 @@ // Specify a specific module/test to run or provide nothing to run all tests //"tests/unit/danswer/llm/answering/test_prune_and_merge.py" ] - } - ], - "compounds": [ + }, { - "name": "Run Danswer", - "configurations": [ - "Web Server", - "Model Server", - "API Server", - "Indexing", - "Background Jobs", - ] + "name": "Clear and Restart External Volumes and Containers", + "type": "node", + "request": "launch", + "runtimeExecutable": "bash", + "runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal", + "stopOnEntry": true } ] } diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 116e78b6f19..3e4415188a1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,23 +48,26 @@ We would love to see you there! ## Get Started 🚀 -Danswer being a fully functional app, relies on some external pieces of software, specifically: +Danswer, being a fully functional app, relies on some external software, specifically: - [Postgres](https://www.postgresql.org/) (Relational DB) - [Vespa](https://vespa.ai/) (Vector DB/Search Engine) +- [Redis](https://redis.io/) (Cache) +- [Nginx](https://nginx.org/) (Not needed for development flows generally) -This guide provides instructions to set up the Danswer specific services outside of Docker because it's easier for -development purposes but also feel free to just use the containers and update with local changes by providing the -`--build` flag. + +> **Note:** +> This guide provides instructions to build and run Danswer locally from source with Docker containers providing the above external software. We believe this combination is easier for > development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Danswer stack within Docker below. ### Local Set Up -It is recommended to use Python version 3.11 +Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
If using a lower version, modifications will have to be made to the code. -If using a higher version, the version of Tensorflow we use may not be available for your platform. +If using a higher version, some libraries may not be available (e.g. we had problems with Tensorflow in the past with higher versions of Python). -#### Installing Requirements +#### Backend: Python requirements Currently, we use pip and recommend creating a virtual environment. For convenience here's a command for it: @@ -73,8 +76,9 @@ python -m venv .venv source .venv/bin/activate ``` ---> Note that this virtual environment MUST NOT be set up WITHIN the danswer -directory +> **Note:** +> This virtual environment MUST NOT be set up WITHIN the danswer directory if you plan on using mypy within certain IDEs. +> For simplicity, we recommend setting up the virtual environment outside of the danswer directory. _For Windows, activate the virtual environment using Command Prompt:_ ```bash @@ -89,34 +93,38 @@ Install the required python dependencies: ```bash pip install -r danswer/backend/requirements/default.txt pip install -r danswer/backend/requirements/dev.txt +pip install -r danswer/backend/requirements/ee.txt pip install -r danswer/backend/requirements/model_server.txt ``` -Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend. -Once the above is done, navigate to `danswer/web` run: +Install Playwright for Python (headless browser required by the Web Connector) + +In the activated Python virtualenv, install Playwright for Python by running: ```bash -npm i +playwright install ``` -Install Playwright (required by the Web Connector) +You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path. -> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again. -This will update the path to include playwright +#### Frontend: Node dependencies -Then install Playwright by running: +Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend. +Once the above is done, navigate to `danswer/web` and run: ```bash -playwright install +npm i ``` +#### Docker containers for external software +You will need Docker installed to run these containers. -#### Dependent Docker Containers -First navigate to `danswer/deployment/docker_compose`, then start up Vespa and Postgres with: +First navigate to `danswer/deployment/docker_compose`, then start up Postgres/Vespa/Redis with: ```bash -docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db +docker compose -f docker-compose.dev.yml -p danswer-stack up -d index relational_db cache ``` -(index refers to Vespa and relational_db refers to Postgres) +(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis) + -#### Running Danswer +#### Running Danswer locally To start the frontend, navigate to `danswer/web` and run: ```bash npm run dev @@ -127,11 +135,10 @@ Navigate to `danswer/backend` and run: ```bash uvicorn model_server.main:app --reload --port 9000 ``` + _For Windows (for compatibility with both PowerShell and Command Prompt):_ ```bash -powershell -Command " - uvicorn model_server.main:app --reload --port 9000 -" +powershell -Command "uvicorn model_server.main:app --reload --port 9000" ``` The first time running Danswer, you will need to run the DB migrations for Postgres.
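The migration command itself is not shown in this hunk; as a minimal sketch (assuming the repo's standard Alembic setup that lives under `danswer/backend`, as the new migration files in this diff suggest), running the migrations would look like:

```bash
# Hypothetical invocation: run from danswer/backend with the virtualenv active,
# using the alembic.ini that ships with the backend.
alembic upgrade head
```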
@@ -154,6 +161,7 @@ To run the backend API server, navigate back to `danswer/backend` and run: ```bash AUTH_TYPE=disabled uvicorn danswer.main:app --reload --port 8080 ``` + _For Windows (for compatibility with both PowerShell and Command Prompt):_ ```bash powershell -Command " @@ -162,20 +170,58 @@ powershell -Command " " ``` -Note: if you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services. +> **Note:** +> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services. + +#### Wrapping up + +You should now have 4 servers running: + +- Web server +- Backend API +- Model server +- Background jobs + +Now, visit `http://localhost:3000` in your browser. You should see the Danswer onboarding wizard where you can connect your external LLM provider to Danswer. + +You've successfully set up a local Danswer instance! 🏁 + +#### Running the Danswer application in a container + +You can run the full Danswer application stack from pre-built images including all external software dependencies. + +Navigate to `danswer/deployment/docker_compose` and run: + +```bash +docker compose -f docker-compose.dev.yml -p danswer-stack up -d +``` + +After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Danswer. + +If you want to make changes to Danswer and run those changes in Docker, you can also build a local version of the Danswer container images that incorporates your changes like so: + +```bash +docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build +``` ### Formatting and Linting #### Backend For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports). First, install pre-commit (if you don't have it already) following the instructions [here](https://pre-commit.com/#installation). + +With the virtual environment active, install the pre-commit library with: +```bash +pip install pre-commit +``` + Then, from the `danswer/backend` directory, run: ```bash pre-commit install ``` Additionally, we use `mypy` for static type checking. -Danswer is fully type-annotated, and we would like to keep it that way! +Danswer is fully type-annotated, and we want to keep it that way! To run the mypy checks manually, run `python -m mypy .` from the `danswer/backend` directory. @@ -186,6 +232,7 @@ Please double check that prettier passes before creating a pull request. ### Release Process -Danswer follows the semver versioning standard. +Danswer loosely follows the SemVer versioning standard. +Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes. A set of Docker containers will be pushed automatically to DockerHub with every tag. You can see the containers [here](https://hub.docker.com/search?q=danswer%2F). diff --git a/CONTRIBUTING_MACOS.md b/CONTRIBUTING_MACOS.md new file mode 100644 index 00000000000..519eccffd51 --- /dev/null +++ b/CONTRIBUTING_MACOS.md @@ -0,0 +1,31 @@ +## Some additional notes for Mac Users +The base instructions to set up the development environment are located in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md). + +### Setting up Python +Ensure [Homebrew](https://brew.sh/) is already set up. + +Then install python 3.11. 
+```bash +brew install python@3.11 +``` + +Add python 3.11 to your path: add the following line to ~/.zshrc +``` +export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH" +``` + +> **Note:** +> You will need to open a new terminal for the path change above to take effect. + + +### Setting up Docker +On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and +ensure it is running before continuing with the docker commands. + + +### Formatting and Linting +MacOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly. +After installing pre-commit, run the following command: +```bash +sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit +``` \ No newline at end of file diff --git a/backend/Dockerfile b/backend/Dockerfile index 17e0be8c239..fc7bcc586d7 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -75,8 +75,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')" # Pre-downloading NLTK for setups with limited egress RUN python -c "import nltk; \ nltk.download('stopwords', quiet=True); \ -nltk.download('wordnet', quiet=True); \ nltk.download('punkt', quiet=True);" +# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed # Set up application files WORKDIR /app diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 8c028202bfc..154d6ff3d66 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -16,7 +16,9 @@ # Interpret the config file for Python logging. # This line sets up loggers basically. -if config.config_file_name is not None: +if config.config_file_name is not None and config.attributes.get( + "configure_logger", True +): fileConfig(config.config_file_name) # add your model's MetaData object here diff --git a/backend/alembic/versions/0ebb1d516877_add_ccpair_deletion_failure_message.py b/backend/alembic/versions/0ebb1d516877_add_ccpair_deletion_failure_message.py new file mode 100644 index 00000000000..526c9449fce --- /dev/null +++ b/backend/alembic/versions/0ebb1d516877_add_ccpair_deletion_failure_message.py @@ -0,0 +1,27 @@ +"""add ccpair deletion failure message + +Revision ID: 0ebb1d516877 +Revises: 52a219fb5233 +Create Date: 2024-09-10 15:03:48.233926 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "0ebb1d516877" +down_revision = "52a219fb5233" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "connector_credential_pair", + sa.Column("deletion_failure_message", sa.String(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("connector_credential_pair", "deletion_failure_message") diff --git a/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py b/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py new file mode 100644 index 00000000000..f284c7b4bf1 --- /dev/null +++ b/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py @@ -0,0 +1,66 @@ +"""Add last synced and last modified to document table + +Revision ID: 52a219fb5233 +Revises: f17bf3b0d9f1 +Create Date: 2024-08-28 17:40:46.077470 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql import func + +# revision identifiers, used by Alembic. 
+revision = "52a219fb5233" +down_revision = "f7e58d357687" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # last modified represents the last time anything needing syncing to vespa changed + # including row metadata and the document itself. This obviously does not include + # the last_synced column. + op.add_column( + "document", + sa.Column( + "last_modified", + sa.DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + ) + + # last synced represents the last time this document was synced to Vespa + op.add_column( + "document", + sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True), + ) + + # Set last_synced to the same value as last_modified for existing rows + op.execute( + """ + UPDATE document + SET last_synced = last_modified + """ + ) + + op.create_index( + op.f("ix_document_last_modified"), + "document", + ["last_modified"], + unique=False, + ) + + op.create_index( + op.f("ix_document_last_synced"), + "document", + ["last_synced"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index(op.f("ix_document_last_synced"), table_name="document") + op.drop_index(op.f("ix_document_last_modified"), table_name="document") + op.drop_column("document", "last_synced") + op.drop_column("document", "last_modified") diff --git a/backend/alembic/versions/a3795dce87be_migration_confluence_to_be_explicit.py b/backend/alembic/versions/a3795dce87be_migration_confluence_to_be_explicit.py new file mode 100644 index 00000000000..20e33d0e227 --- /dev/null +++ b/backend/alembic/versions/a3795dce87be_migration_confluence_to_be_explicit.py @@ -0,0 +1,158 @@ +"""migration confluence to be explicit + +Revision ID: a3795dce87be +Revises: 1f60f60c3401 +Create Date: 2024-09-01 13:52:12.006740 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql import table, column + +revision = "a3795dce87be" +down_revision = "1f60f60c3401" +branch_labels: None = None +depends_on: None = None + + +def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]: + from urllib.parse import urlparse + + def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]: + parsed_url = urlparse(wiki_url) + wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}" + path_parts = parsed_url.path.split("/") + space = path_parts[3] + page_id = path_parts[5] if len(path_parts) > 5 else "" + return wiki_base, space, page_id + + def _extract_confluence_keys_from_datacenter_url( + wiki_url: str, + ) -> tuple[str, str, str]: + DISPLAY = "/display/" + PAGE = "/pages/" + parsed_url = urlparse(wiki_url) + wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}" + space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0] + page_id = "" + if (content := parsed_url.path.split(PAGE)) and len(content) > 1: + page_id = content[1] + return wiki_base, space, page_id + + is_confluence_cloud = ( + ".atlassian.net/wiki/spaces/" in wiki_url + or ".jira.com/wiki/spaces/" in wiki_url + ) + + if is_confluence_cloud: + wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url) + else: + wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url( + wiki_url + ) + + return wiki_base, space, page_id, is_confluence_cloud + + +def reconstruct_confluence_url( + wiki_base: str, space: str, page_id: str, is_cloud: bool +) -> str: + if is_cloud: + url = f"{wiki_base}/spaces/{space}" 
+ if page_id: + url += f"/pages/{page_id}" + else: + url = f"{wiki_base}/display/{space}" + if page_id: + url += f"/pages/{page_id}" + return url + + +def upgrade() -> None: + connector = table( + "connector", + column("id", sa.Integer), + column("source", sa.String()), + column("input_type", sa.String()), + column("connector_specific_config", postgresql.JSONB), + ) + + # Fetch all Confluence connectors + connection = op.get_bind() + confluence_connectors = connection.execute( + sa.select(connector).where( + sa.and_( + connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL" + ) + ) + ).fetchall() + + for row in confluence_connectors: + config = row.connector_specific_config + wiki_page_url = config["wiki_page_url"] + wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url( + wiki_page_url + ) + + new_config = { + "wiki_base": wiki_base, + "space": space, + "page_id": page_id, + "is_cloud": is_cloud, + } + + for key, value in config.items(): + if key not in ["wiki_page_url"]: + new_config[key] = value + + op.execute( + connector.update() + .where(connector.c.id == row.id) + .values(connector_specific_config=new_config) + ) + + +def downgrade() -> None: + connector = table( + "connector", + column("id", sa.Integer), + column("source", sa.String()), + column("input_type", sa.String()), + column("connector_specific_config", postgresql.JSONB), + ) + + confluence_connectors = ( + op.get_bind() + .execute( + sa.select(connector).where( + connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL" + ) + ) + .fetchall() + ) + + for row in confluence_connectors: + config = row.connector_specific_config + if all(key in config for key in ["wiki_base", "space", "is_cloud"]): + wiki_page_url = reconstruct_confluence_url( + config["wiki_base"], + config["space"], + config.get("page_id", ""), + config["is_cloud"], + ) + + new_config = {"wiki_page_url": wiki_page_url} + new_config.update( + { + k: v + for k, v in config.items() + if k not in ["wiki_base", "space", "page_id", "is_cloud"] + } + ) + + op.execute( + connector.update() + .where(connector.c.id == row.id) + .values(connector_specific_config=new_config) + ) diff --git a/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py b/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py new file mode 100644 index 00000000000..2d45a15f2c6 --- /dev/null +++ b/backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py @@ -0,0 +1,26 @@ +"""add support for litellm proxy in reranking + +Revision ID: ba98eba0f66a +Revises: bceb1e139447 +Create Date: 2024-09-06 10:36:04.507332 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = "ba98eba0f66a" +down_revision = "bceb1e139447" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.add_column( + "search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("search_settings", "rerank_api_url") diff --git a/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py b/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py new file mode 100644 index 00000000000..968500e6aaf --- /dev/null +++ b/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py @@ -0,0 +1,26 @@ +"""Add base_url to CloudEmbeddingProvider + +Revision ID: bceb1e139447 +Revises: a3795dce87be +Create Date: 2024-08-28 17:00:52.554580 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "bceb1e139447" +down_revision = "a3795dce87be" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.add_column( + "embedding_provider", sa.Column("api_url", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("embedding_provider", "api_url") diff --git a/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py new file mode 100644 index 00000000000..2d8e7402e48 --- /dev/null +++ b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py @@ -0,0 +1,26 @@ +"""add has_web_login column to user + +Revision ID: f7e58d357687 +Revises: bceb1e139447 +Create Date: 2024-09-07 20:20:54.522620 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "f7e58d357687" +down_revision = "ba98eba0f66a" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.add_column( + "user", + sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"), + ) + + +def downgrade() -> None: + op.drop_column("user", "has_web_login") diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py index 5501980ab48..9088ddf8425 100644 --- a/backend/danswer/access/access.py +++ b/backend/danswer/access/access.py @@ -3,21 +3,49 @@ from danswer.access.models import DocumentAccess from danswer.access.utils import prefix_user from danswer.configs.constants import PUBLIC_DOC_PAT -from danswer.db.document import get_acccess_info_for_documents +from danswer.db.document import get_access_info_for_document +from danswer.db.document import get_access_info_for_documents from danswer.db.models import User from danswer.utils.variable_functionality import fetch_versioned_implementation +def _get_access_for_document( + document_id: str, + db_session: Session, +) -> DocumentAccess: + info = get_access_info_for_document( + db_session=db_session, + document_id=document_id, + ) + + if not info: + return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) + + return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2]) + + +def get_access_for_document( + document_id: str, + db_session: Session, +) -> DocumentAccess: + versioned_get_access_for_document_fn = fetch_versioned_implementation( + "danswer.access.access", "_get_access_for_document" + ) + return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore + + def _get_access_for_documents( document_ids: list[str], db_session: Session, ) -> dict[str, DocumentAccess]: - 
document_access_info = get_acccess_info_for_documents( + document_access_info = get_access_info_for_documents( db_session=db_session, document_ids=document_ids, ) return { - document_id: DocumentAccess.build(user_ids, [], is_public) + document_id: DocumentAccess.build( + user_ids=user_ids, user_groups=[], is_public=is_public + ) for document_id, user_ids, is_public in document_access_info } diff --git a/backend/danswer/auth/schemas.py b/backend/danswer/auth/schemas.py index 9e0553991cc..db8a97ceb04 100644 --- a/backend/danswer/auth/schemas.py +++ b/backend/danswer/auth/schemas.py @@ -33,7 +33,9 @@ class UserRead(schemas.BaseUser[uuid.UUID]): class UserCreate(schemas.BaseUserCreate): role: UserRole = UserRole.BASIC + has_web_login: bool | None = True class UserUpdate(schemas.BaseUserUpdate): role: UserRole + has_web_login: bool | None = True diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index c3851ff1990..56d9a99eb33 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -16,7 +16,9 @@ from fastapi import Request from fastapi import Response from fastapi import status +from fastapi.security import OAuth2PasswordRequestForm from fastapi_users import BaseUserManager +from fastapi_users import exceptions from fastapi_users import FastAPIUsers from fastapi_users import models from fastapi_users import schemas @@ -33,6 +35,7 @@ from danswer.auth.invited_users import get_invited_users from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRole +from danswer.auth.schemas import UserUpdate from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_AUTH from danswer.configs.app_configs import EMAIL_FROM @@ -67,23 +70,6 @@ logger = setup_logger() -def validate_curator_request(groups: list | None, is_public: bool) -> None: - if is_public: - detail = "Curators cannot create public objects" - logger.error(detail) - raise HTTPException( - status_code=401, - detail=detail, - ) - if not groups: - detail = "Curators must specify 1+ groups" - logger.error(detail) - raise HTTPException( - status_code=401, - detail=detail, - ) - - def is_user_admin(user: User | None) -> bool: if AUTH_TYPE == AuthType.DISABLED: return True @@ -201,7 +187,7 @@ async def create( user_create: schemas.UC | UserCreate, safe: bool = False, request: Optional[Request] = None, - ) -> models.UP: + ) -> User: verify_email_is_invited(user_create.email) verify_email_domain(user_create.email) if hasattr(user_create, "role"): @@ -210,7 +196,27 @@ async def create( user_create.role = UserRole.ADMIN else: user_create.role = UserRole.BASIC - return await super().create(user_create, safe=safe, request=request) # type: ignore + user = None + try: + user = await super().create(user_create, safe=safe, request=request) # type: ignore + except exceptions.UserAlreadyExists: + user = await self.get_by_email(user_create.email) + # Handle case where user has used product outside of web and is now creating an account through web + if ( + not user.has_web_login + and hasattr(user_create, "has_web_login") + and user_create.has_web_login + ): + user_update = UserUpdate( + password=user_create.password, + has_web_login=True, + role=user_create.role, + is_verified=user_create.is_verified, + ) + user = await self.update(user_update, user) + else: + raise exceptions.UserAlreadyExists() + return user async def oauth_callback( self: "BaseUserManager[models.UOAP, models.ID]", @@ -250,6 +256,17 @@ async def oauth_callback( if user.oidc_expiry and 
not TRACK_EXTERNAL_IDP_EXPIRY: await self.user_db.update(user, update_dict={"oidc_expiry": None}) + # Handle case where user has used product outside of web and is now creating an account through web + if not user.has_web_login: + await self.user_db.update( + user, + update_dict={ + "is_verified": is_verified_by_default, + "has_web_login": True, + }, + ) + user.is_verified = is_verified_by_default + user.has_web_login = True return user async def on_after_register( @@ -278,6 +295,22 @@ async def on_after_request_verify( send_user_verification_email(user.email, token) + async def authenticate( + self, credentials: OAuth2PasswordRequestForm + ) -> Optional[User]: + user = await super().authenticate(credentials) + if user is None: + try: + user = await self.get_by_email(credentials.username) + if not user.has_web_login: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD", + ) + except exceptions.UserNotExists: + pass + return user + async def get_user_manager( user_db: SQLAlchemyUserDatabase = Depends(get_user_db), diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py index 1c0e949d068..a48d8aa4a15 100644 --- a/backend/danswer/background/celery/celery_app.py +++ b/backend/danswer/background/celery/celery_app.py @@ -1,60 +1,87 @@ import json +import traceback from datetime import timedelta from typing import Any from typing import cast -from celery import Celery # type: ignore +import redis +from celery import Celery +from celery import signals +from celery import Task from celery.contrib.abortable import AbortableTask # type: ignore +from celery.exceptions import SoftTimeLimitExceeded from celery.exceptions import TaskRevokedError +from celery.signals import beat_init +from celery.signals import worker_init +from celery.states import READY_STATES +from celery.utils.log import get_task_logger +from redis import Redis +from sqlalchemy import inspect from sqlalchemy import text from sqlalchemy.orm import Session +from danswer.access.access import get_access_for_document +from danswer.background.celery.celery_redis import RedisConnectorCredentialPair +from danswer.background.celery.celery_redis import RedisDocumentSet +from danswer.background.celery.celery_redis import RedisUserGroup from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair from danswer.background.celery.celery_utils import should_prune_cc_pair -from danswer.background.celery.celery_utils import should_sync_doc_set from danswer.background.connector_deletion import delete_connector_credential_pair from danswer.background.connector_deletion import delete_connector_credential_pair_batch from danswer.background.task_utils import build_celery_task_wrapper from danswer.background.task_utils import name_cc_cleanup_task from danswer.background.task_utils import name_cc_prune_task -from danswer.background.task_utils import name_document_set_sync_task from danswer.configs.app_configs import JOB_TIMEOUT -from danswer.configs.constants import POSTGRES_CELERY_APP_NAME +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerRedisLocks +from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME +from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME from 
danswer.configs.constants import PostgresAdvisoryLocks from danswer.connectors.factory import instantiate_connector from danswer.connectors.models import InputType -from danswer.db.connector_credential_pair import get_connector_credential_pair +from danswer.db.connector_credential_pair import add_deletion_failure_message +from danswer.db.connector_credential_pair import ( + get_connector_credential_pair, +) from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed +from danswer.db.document import count_documents_by_needs_sync +from danswer.db.document import get_document from danswer.db.document import get_documents_for_connector_credential_pair -from danswer.db.document import prepare_to_modify_documents +from danswer.db.document import mark_document_as_synced from danswer.db.document_set import delete_document_set +from danswer.db.document_set import fetch_document_set_for_document from danswer.db.document_set import fetch_document_sets -from danswer.db.document_set import fetch_document_sets_for_documents -from danswer.db.document_set import fetch_documents_for_document_set_paginated from danswer.db.document_set import get_document_set_by_id from danswer.db.document_set import mark_document_set_as_synced -from danswer.db.engine import build_connection_string from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.engine import SYNC_DB_API +from danswer.db.engine import init_sqlalchemy_engine from danswer.db.models import DocumentSet +from danswer.db.models import UserGroup from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import UpdateRequest +from danswer.redis.redis_pool import RedisPool from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import fetch_versioned_implementation +from danswer.utils.variable_functionality import ( + fetch_versioned_implementation_with_fallback, +) +from danswer.utils.variable_functionality import noop_fallback logger = setup_logger() -connection_string = build_connection_string( - db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME -) -celery_broker_url = f"sqla+{connection_string}" -celery_backend_url = f"db+{connection_string}" -celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url) +# use this within celery tasks to get celery task specific logging +task_logger = get_task_logger(__name__) +redis_pool = RedisPool() -_SYNC_BATCH_SIZE = 100 +celery_app = Celery(__name__) +celery_app.config_from_object( + "danswer.background.celery.celeryconfig" +) # Load configuration from 'celeryconfig.py' ##### @@ -72,6 +99,7 @@ def cleanup_connector_credential_pair_task( Needs to potentially update a large number of Postgres and Vespa docs, including deleting them or updating the ACL""" engine = get_sqlalchemy_engine() + with Session(engine) as db_session: # validate that the connector / credential pair is deletable cc_pair = get_connector_credential_pair( @@ -84,14 +112,13 @@ def cleanup_connector_credential_pair_task( f"Cannot run deletion attempt - connector_credential_pair with Connector ID: " f"{connector_id} and Credential ID: {credential_id} does not exist." 
) - - deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed( - connector_credential_pair=cc_pair, db_session=db_session - ) - if deletion_attempt_disallowed_reason: - raise ValueError(deletion_attempt_disallowed_reason) - try: + deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed( + connector_credential_pair=cc_pair, db_session=db_session + ) + if deletion_attempt_disallowed_reason: + raise ValueError(deletion_attempt_disallowed_reason) + # The bulk of the work is in here, updates Postgres and Vespa curr_ind_name, sec_ind_name = get_both_index_names(db_session) document_index = get_default_document_index( @@ -102,8 +129,16 @@ def cleanup_connector_credential_pair_task( document_index=document_index, cc_pair=cc_pair, ) + except Exception as e: - logger.exception(f"Failed to run connector_deletion due to {e}") + stack_trace = traceback.format_exc() + error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}" + add_deletion_failure_message(db_session, cc_pair.id, error_message) + task_logger.exception( + f"Failed to run connector_deletion. " + f"connector_id={connector_id} credential_id={credential_id}\n" + f"Stack Trace:\n{stack_trace}" + ) raise e @@ -122,7 +157,9 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: ) if not cc_pair: - logger.warning(f"ccpair not found for {connector_id} {credential_id}") + task_logger.warning( + f"ccpair not found for {connector_id} {credential_id}" + ) return runnable_connector = instantiate_connector( @@ -154,12 +191,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: ) if len(doc_ids_to_remove) == 0: - logger.info( + task_logger.info( f"No docs to prune from {cc_pair.connector.source} connector" ) return - logger.info( + task_logger.info( f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector" ) delete_connector_credential_pair_batch( @@ -169,124 +206,202 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: document_index=document_index, ) except Exception as e: - logger.exception( - f"Failed to run pruning for connector id {connector_id} due to {e}" + task_logger.exception( + f"Failed to run pruning for connector id {connector_id}." ) raise e -@build_celery_task_wrapper(name_document_set_sync_task) -@celery_app.task(soft_time_limit=JOB_TIMEOUT) -def sync_document_set_task(document_set_id: int) -> None: - """For document sets marked as not up to date, sync the state from postgres - into the datastore. 
Also handles deletions.""" - - def _sync_document_batch(document_ids: list[str], db_session: Session) -> None: - logger.debug(f"Syncing document sets for: {document_ids}") - - # Acquires a lock on the documents so that no other process can modify them - with prepare_to_modify_documents( - db_session=db_session, document_ids=document_ids - ): - # get current state of document sets for these documents - document_set_map = { - document_id: document_sets - for document_id, document_sets in fetch_document_sets_for_documents( - document_ids=document_ids, db_session=db_session - ) - } +def try_generate_stale_document_sync_tasks( + db_session: Session, r: Redis, lock_beat: redis.lock.Lock +) -> int | None: + # the fence is up, do nothing + if r.exists(RedisConnectorCredentialPair.get_fence_key()): + return None - # get current state of document sets for these documents - document_set_map = { - document_id: document_sets - for document_id, document_sets in fetch_document_sets_for_documents( - document_ids=document_ids, db_session=db_session - ) - } + r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset + + # add tasks to celery and build up the task set to monitor in redis + stale_doc_count = count_documents_by_needs_sync(db_session) + if stale_doc_count == 0: + return None + + task_logger.info( + f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair." + ) + + # rkuo: we could technically sync all stale docs in one big pass. + # but I feel it's more understandable to group the docs by cc_pair + total_tasks_generated = 0 + cc_pairs = get_connector_credential_pairs(db_session) + for cc_pair in cc_pairs: + rc = RedisConnectorCredentialPair(cc_pair.id) + tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat) - # update Vespa - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + if tasks_generated is None: + continue + + if tasks_generated == 0: + continue + + task_logger.info( + f"RedisConnector.generate_tasks finished. " + f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" ) - update_requests = [ - UpdateRequest( - document_ids=[document_id], - document_sets=set(document_set_map.get(document_id, [])), - ) - for document_id in document_ids - ] - document_index.update(update_requests=update_requests) - # Commit to release the locks - db_session.commit() + total_tasks_generated += tasks_generated - with Session(get_sqlalchemy_engine()) as db_session: - try: - cursor = None - while True: - document_batch, cursor = fetch_documents_for_document_set_paginated( - document_set_id=document_set_id, - db_session=db_session, - current_only=False, - last_document_id=cursor, - limit=_SYNC_BATCH_SIZE, - ) - _sync_document_batch( - document_ids=[document.id for document in document_batch], - db_session=db_session, - ) - if cursor is None: - break - - # if there are no connectors, then delete the document set. Otherwise, just - # mark it as successfully synced. - document_set = cast( - DocumentSet, - get_document_set_by_id( - db_session=db_session, document_set_id=document_set_id - ), - ) # casting since we "know" a document set with this ID exists - if not document_set.connector_credential_pairs: - delete_document_set( - document_set_row=document_set, db_session=db_session - ) - logger.info( - f"Successfully deleted document set with ID: '{document_set_id}'!" 
- ) - else: - mark_document_set_as_synced( - document_set_id=document_set_id, db_session=db_session - ) - logger.info(f"Document set sync for '{document_set_id}' complete!") + task_logger.info( + f"All per connector generate_tasks finished. total_tasks_generated={total_tasks_generated}" + ) + + r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated) + return total_tasks_generated - except Exception: - logger.exception("Failed to sync document set %s", document_set_id) - raise + +def try_generate_document_set_sync_tasks( + document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock +) -> int | None: + lock_beat.reacquire() + + rds = RedisDocumentSet(document_set.id) + + # don't generate document set sync tasks if tasks are still pending + if r.exists(rds.fence_key): + return None + + # don't generate sync tasks if we're up to date + if document_set.is_up_to_date: + return None + + # add tasks to celery and build up the task set to monitor in redis + r.delete(rds.taskset_key) + + task_logger.info( + f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}" + ) + + # Add all documents that need to be updated into the queue + tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat) + if tasks_generated is None: + return None + + # Currently we are allowing the sync to proceed with 0 tasks. + # It's possible for sets/groups to be generated initially with no entries + # and they still need to be marked as up to date. + # if tasks_generated == 0: + # return 0 + + task_logger.info( + f"RedisDocumentSet.generate_tasks finished. " + f"document_set_id={document_set.id} tasks_generated={tasks_generated}" + ) + + # set this only after all tasks have been added + r.set(rds.fence_key, tasks_generated) + return tasks_generated + + +def try_generate_user_group_sync_tasks( + usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock +) -> int | None: + lock_beat.reacquire() + + rug = RedisUserGroup(usergroup.id) + + # don't generate sync tasks if tasks are still pending + if r.exists(rug.fence_key): + return None + + if usergroup.is_up_to_date: + return None + + # add tasks to celery and build up the task set to monitor in redis + r.delete(rug.taskset_key) + + # Add all documents that need to be updated into the queue + task_logger.info(f"generate_tasks starting. usergroup_id={usergroup.id}") + tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat) + if tasks_generated is None: + return None + + # Currently we are allowing the sync to proceed with 0 tasks. + # It's possible for sets/groups to be generated initially with no entries + # and they still need to be marked as up to date. + # if tasks_generated == 0: + # return 0 + + task_logger.info( + f"generate_tasks finished. 
" + f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}" + ) + + # set this only after all tasks have been added + r.set(rug.fence_key, tasks_generated) + return tasks_generated ##### # Periodic Tasks ##### @celery_app.task( - name="check_for_document_sets_sync_task", + name="check_for_vespa_sync_task", soft_time_limit=JOB_TIMEOUT, ) -def check_for_document_sets_sync_task() -> None: - """Runs periodically to check if any sync tasks should be run and adds them - to the queue""" - with Session(get_sqlalchemy_engine()) as db_session: - # check if any document sets are not synced - document_set_info = fetch_document_sets( - user_id=None, db_session=db_session, include_outdated=True - ) - for document_set, _ in document_set_info: - if should_sync_doc_set(document_set, db_session): - logger.info(f"Syncing the {document_set.name} document set") - sync_document_set_task.apply_async( - kwargs=dict(document_set_id=document_set.id), +def check_for_vespa_sync_task() -> None: + """Runs periodically to check if any document needs syncing. + Generates sets of tasks for Celery if syncing is needed.""" + + r = redis_pool.get_client() + + lock_beat = r.lock( + DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return + + with Session(get_sqlalchemy_engine()) as db_session: + try_generate_stale_document_sync_tasks(db_session, r, lock_beat) + + # check if any document sets are not synced + document_set_info = fetch_document_sets( + user_id=None, db_session=db_session, include_outdated=True + ) + for document_set, _ in document_set_info: + try_generate_document_set_sync_tasks( + document_set, db_session, r, lock_beat ) + # check if any user groups are not synced + try: + fetch_user_groups = fetch_versioned_implementation( + "danswer.db.user_group", "fetch_user_groups" + ) + + user_groups = fetch_user_groups( + db_session=db_session, only_up_to_date=False + ) + for usergroup in user_groups: + try_generate_user_group_sync_tasks( + usergroup, db_session, r, lock_beat + ) + except ModuleNotFoundError: + # Always exceptions on the MIT version, which is expected + pass + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." + ) + except Exception: + task_logger.exception("Unexpected exception") + finally: + if lock_beat.owned(): + lock_beat.release() + @celery_app.task( name="check_for_cc_pair_deletion_task", @@ -295,11 +410,14 @@ def check_for_document_sets_sync_task() -> None: def check_for_cc_pair_deletion_task() -> None: """Runs periodically to check if any deletion tasks should be run""" with Session(get_sqlalchemy_engine()) as db_session: - # check if any document sets are not synced + # check if any cc pairs are up for deletion cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: if should_kick_off_deletion_of_cc_pair(cc_pair, db_session): - logger.notice(f"Deleting the {cc_pair.name} connector credential pair") + task_logger.info( + f"Deleting the {cc_pair.name} connector credential pair" + ) + cleanup_connector_credential_pair_task.apply_async( kwargs=dict( connector_id=cc_pair.connector.id, @@ -346,7 +464,9 @@ def kombu_message_cleanup_task(self: Any) -> int: db_session.commit() if ctx["deleted"] > 0: - logger.info(f"Deleted {ctx['deleted']} orphaned messages from kombu_message.") + task_logger.info( + f"Deleted {ctx['deleted']} orphaned messages from kombu_message." 
+ ) return ctx["deleted"] @@ -371,6 +491,15 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool: bool: Returns True if there are more rows to process, False if not. """ + inspector = inspect(db_session.bind) + if not inspector: + return False + + # With the move to redis as celery's broker and backend, kombu tables may not even exist. + # We can fail silently. + if not inspector.has_table("kombu_message"): + return False + query = text( """ SELECT id, timestamp, payload @@ -411,12 +540,6 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool: ) if result.rowcount > 0: # type: ignore ctx["deleted"] += 1 - else: - task_name = payload["headers"]["task"] - logger.warning( - f"Message found for task older than {ctx['cleanup_age']} days. " - f"id={task_id} name={task_name}" - ) ctx["last_processed_id"] = msg[0] @@ -440,7 +563,7 @@ def check_for_prune_task() -> None: credential=cc_pair.credential, db_session=db_session, ): - logger.info(f"Pruning the {cc_pair.connector.name} connector") + task_logger.info(f"Pruning the {cc_pair.connector.name} connector") prune_documents_task.apply_async( kwargs=dict( @@ -450,19 +573,331 @@ def check_for_prune_task() -> None: ) +@celery_app.task( + name="vespa_metadata_sync_task", + bind=True, + soft_time_limit=45, + time_limit=60, + max_retries=3, +) +def vespa_metadata_sync_task(self: Task, document_id: str) -> bool: + task_logger.info(f"document_id={document_id}") + + try: + with Session(get_sqlalchemy_engine()) as db_session: + curr_ind_name, sec_ind_name = get_both_index_names(db_session) + document_index = get_default_document_index( + primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + ) + + doc = get_document(document_id, db_session) + if not doc: + return False + + # document set sync + doc_sets = fetch_document_set_for_document(document_id, db_session) + update_doc_sets: set[str] = set(doc_sets) + + # User group sync + doc_access = get_access_for_document( + document_id=document_id, db_session=db_session + ) + update_request = UpdateRequest( + document_ids=[document_id], + document_sets=update_doc_sets, + access=doc_access, + boost=doc.boost, + hidden=doc.hidden, + ) + + # update Vespa + document_index.update(update_requests=[update_request]) + + # update db last. Worst case = we crash right before this and + # the sync might repeat again later + mark_document_as_synced(document_id, db_session) + except SoftTimeLimitExceeded: + task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}") + except Exception as e: + task_logger.exception("Unexpected exception") + + # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64 + countdown = 2 ** (self.request.retries + 4) + self.retry(exc=e, countdown=countdown) + + return True + + +@signals.task_postrun.connect +def celery_task_postrun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + retval: Any | None = None, + state: str | None = None, + **kwds: Any, +) -> None: + """We handle this signal in order to remove completed tasks + from their respective tasksets. This allows us to track the progress of document set + and user group syncs. + + This function runs after any task completes (both success and failure) + Note that this signal does not fire on a task that failed to complete and is going + to be retried. 
+ """ + if not task: + return + + task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}") + # logger.debug(f"Result: {retval}") + + if state not in READY_STATES: + return + + if not task_id: + return + + if task_id.startswith(RedisConnectorCredentialPair.PREFIX): + r = redis_pool.get_client() + r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id) + return + + if task_id.startswith(RedisDocumentSet.PREFIX): + r = redis_pool.get_client() + document_set_id = RedisDocumentSet.get_id_from_task_id(task_id) + if document_set_id is not None: + rds = RedisDocumentSet(document_set_id) + r.srem(rds.taskset_key, task_id) + return + + if task_id.startswith(RedisUserGroup.PREFIX): + r = redis_pool.get_client() + usergroup_id = RedisUserGroup.get_id_from_task_id(task_id) + if usergroup_id is not None: + rug = RedisUserGroup(usergroup_id) + r.srem(rug.taskset_key, task_id) + return + + +def monitor_connector_taskset(r: Redis) -> None: + fence_value = r.get(RedisConnectorCredentialPair.get_fence_key()) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = r.scard(RedisConnectorCredentialPair.get_taskset_key()) + task_logger.info(f"Stale documents: remaining={count} initial={initial_count}") + if count == 0: + r.delete(RedisConnectorCredentialPair.get_taskset_key()) + r.delete(RedisConnectorCredentialPair.get_fence_key()) + task_logger.info(f"Successfully synced stale documents. count={initial_count}") + + +def monitor_document_set_taskset( + key_bytes: bytes, r: Redis, db_session: Session +) -> None: + fence_key = key_bytes.decode("utf-8") + document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key) + if document_set_id is None: + task_logger.warning("could not parse document set id from {key}") + return + + rds = RedisDocumentSet(document_set_id) + + fence_value = r.get(rds.fence_key) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = cast(int, r.scard(rds.taskset_key)) + task_logger.info( + f"document_set_id={document_set_id} remaining={count} initial={initial_count}" + ) + if count > 0: + return + + document_set = cast( + DocumentSet, + get_document_set_by_id(db_session=db_session, document_set_id=document_set_id), + ) # casting since we "know" a document set with this ID exists + if document_set: + if not document_set.connector_credential_pairs: + # if there are no connectors, then delete the document set. + delete_document_set(document_set_row=document_set, db_session=db_session) + task_logger.info( + f"Successfully deleted document set with ID: '{document_set_id}'!" + ) + else: + mark_document_set_as_synced(document_set_id, db_session) + task_logger.info( + f"Successfully synced document set with ID: '{document_set_id}'!" 
+ ) + + r.delete(rds.taskset_key) + r.delete(rds.fence_key) + + +def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None: + key = key_bytes.decode("utf-8") + usergroup_id = RedisUserGroup.get_id_from_fence_key(key) + if not usergroup_id: + task_logger.warning("Could not parse usergroup id from {key}") + return + + rug = RedisUserGroup(usergroup_id) + fence_value = r.get(rug.fence_key) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = cast(int, r.scard(rug.taskset_key)) + task_logger.info( + f"usergroup_id={usergroup_id} remaining={count} initial={initial_count}" + ) + if count > 0: + return + + try: + fetch_user_group = fetch_versioned_implementation( + "danswer.db.user_group", "fetch_user_group" + ) + except ModuleNotFoundError: + task_logger.exception( + "fetch_versioned_implementation failed to look up fetch_user_group." + ) + return + + user_group: UserGroup | None = fetch_user_group( + db_session=db_session, user_group_id=usergroup_id + ) + if user_group: + if user_group.is_up_for_deletion: + delete_user_group = fetch_versioned_implementation_with_fallback( + "danswer.db.user_group", "delete_user_group", noop_fallback + ) + + delete_user_group(db_session=db_session, user_group=user_group) + task_logger.info(f" Deleted usergroup. id='{usergroup_id}'") + else: + mark_user_group_as_synced = fetch_versioned_implementation_with_fallback( + "danswer.db.user_group", "mark_user_group_as_synced", noop_fallback + ) + + mark_user_group_as_synced(db_session=db_session, user_group=user_group) + task_logger.info(f"Synced usergroup. id='{usergroup_id}'") + + r.delete(rug.taskset_key) + r.delete(rug.fence_key) + + +@celery_app.task(name="monitor_vespa_sync", soft_time_limit=300) +def monitor_vespa_sync() -> None: + """This is a celery beat task that monitors and finalizes metadata sync tasksets. + It scans for fence values and then gets the counts of any associated tasksets. + If the count is 0, that means all tasks finished and we should clean up. + + This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't + do anything too expensive in this function! + """ + r = redis_pool.get_client() + + lock_beat = r.lock( + DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # prevent overlapping tasks + if not lock_beat.acquire(blocking=False): + return + + with Session(get_sqlalchemy_engine()) as db_session: + if r.exists(RedisConnectorCredentialPair.get_fence_key()): + monitor_connector_taskset(r) + + for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): + monitor_document_set_taskset(key_bytes, r, db_session) + + for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): + monitor_usergroup_taskset(key_bytes, r, db_session) + + # + # r_celery = celery_app.broker_connection().channel().client + # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery) + # task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}") + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." 
+ ) + finally: + if lock_beat.owned(): + lock_beat.release() + + +@beat_init.connect +def on_beat_init(sender: Any, **kwargs: Any) -> None: + init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME) + + +@worker_init.connect +def on_worker_init(sender: Any, **kwargs: Any) -> None: + init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME) + + # TODO(rkuo): this is singleton work that should be done on startup exactly once + # if we run multiple workers, we'll need to centralize where this cleanup happens + r = redis_pool.get_client() + + r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK) + r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) + + r.delete(RedisConnectorCredentialPair.get_taskset_key()) + r.delete(RedisConnectorCredentialPair.get_fence_key()) + + for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): + r.delete(key) + + ##### # Celery Beat (Periodic Tasks) Settings ##### celery_app.conf.beat_schedule = { - "check-for-document-set-sync": { - "task": "check_for_document_sets_sync_task", + "check-for-vespa-sync": { + "task": "check_for_vespa_sync_task", "schedule": timedelta(seconds=5), + "options": {"priority": DanswerCeleryPriority.HIGH}, }, "check-for-cc-pair-deletion": { "task": "check_for_cc_pair_deletion_task", # don't need to check too often, since we kick off a deletion initially # during the API call that actually marks the CC pair for deletion "schedule": timedelta(minutes=1), + "options": {"priority": DanswerCeleryPriority.HIGH}, }, } celery_app.conf.beat_schedule.update( @@ -470,6 +905,7 @@ def check_for_prune_task() -> None: "check-for-prune": { "task": "check_for_prune_task", "schedule": timedelta(seconds=5), + "options": {"priority": DanswerCeleryPriority.HIGH}, }, } ) @@ -478,6 +914,16 @@ def check_for_prune_task() -> None: "kombu-message-cleanup": { "task": "kombu_message_cleanup_task", "schedule": timedelta(seconds=3600), + "options": {"priority": DanswerCeleryPriority.LOWEST}, + }, + } +) +celery_app.conf.beat_schedule.update( + { + "monitor-vespa-sync": { + "task": "monitor_vespa_sync", + "schedule": timedelta(seconds=5), + "options": {"priority": DanswerCeleryPriority.HIGH}, }, } ) diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py new file mode 100644 index 00000000000..bf82f0a7274 --- /dev/null +++ b/backend/danswer/background/celery/celery_redis.py @@ -0,0 +1,299 @@ +# These are helper objects for tracking the keys we need to write in redis +import time +from abc import ABC +from abc import abstractmethod +from typing import cast +from uuid import uuid4 + +import redis +from celery import Celery +from redis import Redis +from sqlalchemy.orm import Session + +from danswer.background.celery.celeryconfig import CELERY_SEPARATOR +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import DanswerCeleryQueues +from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id +from danswer.db.document import ( + construct_document_select_for_connector_credential_pair_by_needs_sync, +) +from danswer.db.document_set import construct_document_select_by_docset +from danswer.utils.variable_functionality import 
fetch_versioned_implementation + + +class RedisObjectHelper(ABC): + PREFIX = "base" + FENCE_PREFIX = PREFIX + "_fence" + TASKSET_PREFIX = PREFIX + "_taskset" + + def __init__(self, id: int): + self._id: int = id + + @property + def task_id_prefix(self) -> str: + return f"{self.PREFIX}_{self._id}" + + @property + def fence_key(self) -> str: + # example: documentset_fence_1 + return f"{self.FENCE_PREFIX}_{self._id}" + + @property + def taskset_key(self) -> str: + # example: documentset_taskset_1 + return f"{self.TASKSET_PREFIX}_{self._id}" + + @staticmethod + def get_id_from_fence_key(key: str) -> int | None: + """ + Extracts the object ID from a fence key in the format `PREFIX_fence_X`. + + Args: + key (str): The fence key string. + + Returns: + Optional[int]: The extracted ID if the key is in the correct format, otherwise None. + """ + parts = key.split("_") + if len(parts) != 3: + return None + + try: + object_id = int(parts[2]) + except ValueError: + return None + + return object_id + + @staticmethod + def get_id_from_task_id(task_id: str) -> int | None: + """ + Extracts the object ID from a task ID string. + + This method assumes the task ID is formatted as `prefix_objectid_suffix`, where: + - `prefix` is an arbitrary string (e.g., the name of the task or entity), + - `objectid` is the ID you want to extract, + - `suffix` is another arbitrary string (e.g., a UUID). + + Example: + If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`, + this method will return the string `"1"`. + + Args: + task_id (str): The task ID string from which to extract the object ID. + + Returns: + str | None: The extracted object ID if the task ID is in the correct format, otherwise None. + """ + # example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc + parts = task_id.split("_") + if len(parts) != 3: + return None + + try: + object_id = int(parts[1]) + except ValueError: + return None + + return object_id + + @abstractmethod + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock, + ) -> int | None: + pass + + +class RedisDocumentSet(RedisObjectHelper): + PREFIX = "documentset" + FENCE_PREFIX = PREFIX + "_fence" + TASKSET_PREFIX = PREFIX + "_taskset" + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock, + ) -> int | None: + last_lock_time = time.monotonic() + + async_results = [] + stmt = construct_document_select_by_docset(self._id) + for doc in db_session.scalars(stmt).yield_per(1): + current_time = time.monotonic() + if current_time - last_lock_time >= ( + CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 + ): + lock.reacquire() + last_lock_time = current_time + + # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" + # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" + # we prefix the task id so it's easier to keep track of who created the task + # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac" + custom_task_id = f"{self.task_id_prefix}_{uuid4()}" + + # add to the set BEFORE creating the task. 
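A minimal standalone sketch of the custom task-id convention the comments above describe (prefix, object id, and UUID separated by underscores); the names and values here are illustrative, not taken from the diff:

```python
from uuid import uuid4

# Matches RedisDocumentSet.PREFIX; the id is an example value.
prefix = "documentset"
document_set_id = 1

custom_task_id = f"{prefix}_{document_set_id}_{uuid4()}"
# e.g. "documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc"

# get_id_from_task_id() can later recover the object id from this string:
parts = custom_task_id.split("_")
recovered_id = int(parts[1]) if len(parts) == 3 else None
assert recovered_id == document_set_id
```

Because the task id is added to the taskset before the task is sent, the postrun handler always finds an entry to remove when the task completes.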
+ redis_client.sadd(self.taskset_key, custom_task_id) + + result = celery_app.send_task( + "vespa_metadata_sync_task", + kwargs=dict(document_id=doc.id), + queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, + task_id=custom_task_id, + priority=DanswerCeleryPriority.LOW, + ) + + async_results.append(result) + + return len(async_results) + + +class RedisUserGroup(RedisObjectHelper): + PREFIX = "usergroup" + FENCE_PREFIX = PREFIX + "_fence" + TASKSET_PREFIX = PREFIX + "_taskset" + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock, + ) -> int | None: + last_lock_time = time.monotonic() + + async_results = [] + + try: + construct_document_select_by_usergroup = fetch_versioned_implementation( + "danswer.db.user_group", + "construct_document_select_by_usergroup", + ) + except ModuleNotFoundError: + return 0 + + stmt = construct_document_select_by_usergroup(self._id) + for doc in db_session.scalars(stmt).yield_per(1): + current_time = time.monotonic() + if current_time - last_lock_time >= ( + CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 + ): + lock.reacquire() + last_lock_time = current_time + + # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" + # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" + # we prefix the task id so it's easier to keep track of who created the task + # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac" + custom_task_id = f"{self.task_id_prefix}_{uuid4()}" + + # add to the set BEFORE creating the task. + redis_client.sadd(self.taskset_key, custom_task_id) + + result = celery_app.send_task( + "vespa_metadata_sync_task", + kwargs=dict(document_id=doc.id), + queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, + task_id=custom_task_id, + priority=DanswerCeleryPriority.LOW, + ) + + async_results.append(result) + + return len(async_results) + + +class RedisConnectorCredentialPair(RedisObjectHelper): + PREFIX = "connectorsync" + FENCE_PREFIX = PREFIX + "_fence" + TASKSET_PREFIX = PREFIX + "_taskset" + + @classmethod + def get_fence_key(cls) -> str: + return RedisConnectorCredentialPair.FENCE_PREFIX + + @classmethod + def get_taskset_key(cls) -> str: + return RedisConnectorCredentialPair.TASKSET_PREFIX + + @property + def taskset_key(self) -> str: + """Notice that this is intentionally reusing the same taskset for all + connector syncs""" + # example: connector_taskset + return f"{self.TASKSET_PREFIX}" + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock, + ) -> int | None: + last_lock_time = time.monotonic() + + async_results = [] + cc_pair = get_connector_credential_pair_from_id(self._id, db_session) + if not cc_pair: + return None + + stmt = construct_document_select_for_connector_credential_pair_by_needs_sync( + cc_pair.connector_id, cc_pair.credential_id + ) + for doc in db_session.scalars(stmt).yield_per(1): + current_time = time.monotonic() + if current_time - last_lock_time >= ( + CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 + ): + lock.reacquire() + last_lock_time = current_time + + # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" + # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" + # we prefix the task id so it's easier to keep track of who created the task + # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac" + custom_task_id = f"{self.task_id_prefix}_{uuid4()}" + + # add to the tracking taskset in redis BEFORE 
creating the celery task. + # note that for the moment we are using a single taskset key, not differentiated by cc_pair id + redis_client.sadd( + RedisConnectorCredentialPair.get_taskset_key(), custom_task_id + ) + + # Priority on sync's triggered by new indexing should be medium + result = celery_app.send_task( + "vespa_metadata_sync_task", + kwargs=dict(document_id=doc.id), + queue=DanswerCeleryQueues.VESPA_METADATA_SYNC, + task_id=custom_task_id, + priority=DanswerCeleryPriority.MEDIUM, + ) + + async_results.append(result) + + return len(async_results) + + +def celery_get_queue_length(queue: str, r: Redis) -> int: + """This is a redis specific way to get the length of a celery queue. + It is priority aware and knows how to count across the multiple redis lists + used to implement task prioritization. + This operation is not atomic.""" + total_length = 0 + for i in range(len(DanswerCeleryPriority)): + queue_name = queue + if i > 0: + queue_name += CELERY_SEPARATOR + queue_name += str(i) + + length = r.llen(queue_name) + total_length += cast(int, length) + + return total_length diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index e4d4d13bb1d..a51bd8cca35 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -5,7 +5,6 @@ from danswer.background.task_utils import name_cc_cleanup_task from danswer.background.task_utils import name_cc_prune_task -from danswer.background.task_utils import name_document_set_sync_task from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( @@ -22,7 +21,6 @@ from danswer.db.models import Connector from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential -from danswer.db.models import DocumentSet from danswer.db.models import TaskQueueState from danswer.db.tasks import check_task_is_live_and_not_timed_out from danswer.db.tasks import get_latest_task @@ -81,21 +79,6 @@ def should_kick_off_deletion_of_cc_pair( return True -def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool: - if document_set.is_up_to_date: - return False - - task_name = name_document_set_sync_task(document_set.id) - latest_sync = get_latest_task(task_name, db_session) - - if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session): - logger.info(f"Document set '{document_set.id}' is already syncing. 
Skipping.") - return False - - logger.info(f"Document set {document_set.id} syncing now.") - return True - - def should_prune_cc_pair( connector: Connector, credential: Credential, db_session: Session ) -> bool: diff --git a/backend/danswer/background/celery/celeryconfig.py b/backend/danswer/background/celery/celeryconfig.py new file mode 100644 index 00000000000..898cfd4b920 --- /dev/null +++ b/backend/danswer/background/celery/celeryconfig.py @@ -0,0 +1,44 @@ +# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html +from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY +from danswer.configs.app_configs import REDIS_HOST +from danswer.configs.app_configs import REDIS_PASSWORD +from danswer.configs.app_configs import REDIS_PORT +from danswer.configs.app_configs import REDIS_SSL +from danswer.configs.app_configs import REDIS_SSL_CA_CERTS +from danswer.configs.app_configs import REDIS_SSL_CERT_REQS +from danswer.configs.constants import DanswerCeleryPriority + +CELERY_SEPARATOR = ":" + +CELERY_PASSWORD_PART = "" +if REDIS_PASSWORD: + CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@" + +REDIS_SCHEME = "redis" + +# SSL-specific query parameters for Redis URL +SSL_QUERY_PARAMS = "" +if REDIS_SSL: + REDIS_SCHEME = "rediss" + SSL_QUERY_PARAMS = f"?ssl_cert_reqs={REDIS_SSL_CERT_REQS}" + if REDIS_SSL_CA_CERTS: + SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}" + +# example celery_broker_url: "redis://:password@localhost:6379/15" +broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}" + +result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}" + +# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks +# however, prefetching is bad when tasks are lengthy as those tasks +# can stall other tasks. 
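A hedged, standalone sketch of how the broker URL above is assembled; the host, port, password, and DB number are hypothetical examples, not values from this config:

```python
# Hypothetical values, for illustration only.
redis_password = "s3cret"
redis_host, redis_port, redis_db = "redis.internal", 6380, 15
use_ssl = True

password_part = f":{redis_password}@" if redis_password else ""
scheme = "rediss" if use_ssl else "redis"
ssl_query = "?ssl_cert_reqs=CERT_NONE" if use_ssl else ""

broker_url = f"{scheme}://{password_part}{redis_host}:{redis_port}/{redis_db}{ssl_query}"
print(broker_url)  # rediss://:s3cret@redis.internal:6380/15?ssl_cert_reqs=CERT_NONE
```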
+worker_prefetch_multiplier = 4 + +broker_transport_options = { + "priority_steps": list(range(len(DanswerCeleryPriority))), + "sep": CELERY_SEPARATOR, + "queue_order_strategy": "priority", +} + +task_default_priority = DanswerCeleryPriority.MEDIUM +task_acks_late = True diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py index 90883564910..c904c804d06 100644 --- a/backend/danswer/background/connector_deletion.py +++ b/backend/danswer/background/connector_deletion.py @@ -151,8 +151,7 @@ def delete_connector_credential_pair( # index attempts delete_index_attempts( db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, + cc_pair_id=cc_pair.id, ) # document sets diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index a98f4e1f5ad..86b4285361f 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -118,19 +118,19 @@ def _run_indexing( db_cc_pair = index_attempt.connector_credential_pair db_connector = index_attempt.connector_credential_pair.connector db_credential = index_attempt.connector_credential_pair.credential + earliest_index_time = ( + db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0 + ) last_successful_index_time = ( - db_connector.indexing_start.timestamp() - if index_attempt.from_beginning and db_connector.indexing_start is not None - else ( - 0.0 - if index_attempt.from_beginning - else get_last_successful_attempt_time( - connector_id=db_connector.id, - credential_id=db_credential.id, - search_settings=index_attempt.search_settings, - db_session=db_session, - ) + earliest_index_time + if index_attempt.from_beginning + else get_last_successful_attempt_time( + connector_id=db_connector.id, + credential_id=db_credential.id, + earliest_index=earliest_index_time, + search_settings=index_attempt.search_settings, + db_session=db_session, ) ) @@ -384,17 +384,22 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA return attempt -def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None: +def run_indexing_entrypoint( + index_attempt_id: int, connector_credential_pair_id: int, is_ee: bool = False +) -> None: """Entrypoint for indexing run when using dask distributed. 
Wraps the actual logic in a `try` block so that we can catch any exceptions and mark the attempt as failed.""" + try: if is_ee: global_version.set_ee() # set the indexing attempt ID so that all log messages from this process # will have it added as a prefix - IndexAttemptSingleton.set_index_attempt_id(index_attempt_id) + IndexAttemptSingleton.set_cc_and_index_id( + index_attempt_id, connector_credential_pair_id + ) with Session(get_sqlalchemy_engine()) as db_session: # make sure that it is valid to run this indexing attempt + mark it diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py index 6e122678813..e746e43abae 100644 --- a/backend/danswer/background/task_utils.py +++ b/backend/danswer/background/task_utils.py @@ -93,9 +93,16 @@ def wrapped_fn( kwargs_for_build_name = kwargs or {} task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name) with Session(get_sqlalchemy_engine()) as db_session: - # mark the task as started + # register_task must come before fn = apply_async or else the task + # might run mark_task_start (and crash) before the task row exists + db_task = register_task(task_name, db_session) + task = fn(args, kwargs, *other_args, **other_kwargs) - register_task(task.id, task_name, db_session) + + # we update the celery task id for diagnostic purposes + # but it isn't currently used by any code + db_task.task_id = task.id + db_session.commit() return task diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 28abb481143..5fde7cb3da0 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -17,6 +17,7 @@ from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from danswer.configs.app_configs import NUM_INDEXING_WORKERS from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS +from danswer.configs.constants import DocumentSource from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME from danswer.db.connector import fetch_connectors from danswer.db.connector_credential_pair import fetch_connector_credential_pairs @@ -46,7 +47,6 @@ from shared_configs.configs import LOG_LEVEL from shared_configs.configs import MODEL_SERVER_PORT - logger = setup_logger() # If the indexing dies, it's most likely due to resource constraints, @@ -67,6 +67,10 @@ def _should_create_new_indexing( ) -> bool: connector = cc_pair.connector + # don't kick off indexing for `NOT_APPLICABLE` sources + if connector.source == DocumentSource.NOT_APPLICABLE: + return False + # User can still manually create single indexing attempts via the UI for the # currently in use index if DISABLE_INDEX_UPDATE_ON_SWAP: @@ -337,6 +341,7 @@ def kickoff_indexing_jobs( run = secondary_client.submit( run_indexing_entrypoint, attempt.id, + attempt.connector_credential_pair_id, global_version.get_is_ee_version(), pure=False, ) @@ -344,6 +349,7 @@ def kickoff_indexing_jobs( run = client.submit( run_indexing_entrypoint, attempt.id, + attempt.connector_credential_pair_id, global_version.get_is_ee_version(), pure=False, ) diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index 0690f08b759..1839b3a5f23 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -96,7 +96,17 @@ def load_personas_from_yaml( # Set specific overrides for image generation persona if persona.get("image_generation"): llm_model_version_override = "gpt-4o" - + + # Load Internet Search Tool. 
+ if persona.get("internet_search"): + internet_search_tool = ( + db_session.query(ToolDBModel) + .filter(ToolDBModel.name == "InternetSearchTool") + .first() + ) + if internet_search_tool: + tool_ids.append(internet_search_tool.id) + existing_persona = ( db_session.query(Persona) .filter(Persona.name == persona["name"]) diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 6d12d68df08..97d5b9e7275 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -1,5 +1,6 @@ from collections.abc import Iterator from datetime import datetime +from enum import Enum from typing import Any from pydantic import BaseModel @@ -44,8 +45,26 @@ def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: return initial_dict +class StreamStopReason(Enum): + CONTEXT_LENGTH = "context_length" + CANCELLED = "cancelled" + + +class StreamStopInfo(BaseModel): + stop_reason: StreamStopReason + + def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore + data = super().model_dump(mode="json", *args, **kwargs) # type: ignore + data["stop_reason"] = self.stop_reason.name + return data + + class LLMRelevanceFilterResponse(BaseModel): - relevant_chunk_indices: list[int] + llm_selected_doc_indices: list[int] + + +class FinalUsedContextDocsResponse(BaseModel): + final_context_docs: list[LlmDoc] class RelevanceAnalysis(BaseModel): @@ -78,6 +97,16 @@ class CitationInfo(BaseModel): document_id: str +class AllCitations(BaseModel): + citations: list[CitationInfo] + + +# This is a mapping of the citation number to the document index within +# the result search doc set +class MessageSpecificCitations(BaseModel): + citation_map: dict[int, int] + + class MessageResponseIDInfo(BaseModel): user_message_id: int | None reserved_assistant_message_id: int @@ -123,7 +152,7 @@ class QAResponse(SearchResponse, DanswerAnswer): predicted_flow: QueryFlow predicted_search: SearchType eval_res_valid: bool | None = None - llm_chunks_indices: list[int] | None = None + llm_selected_doc_indices: list[int] | None = None error_msg: str | None = None @@ -144,6 +173,7 @@ class CustomToolResponse(BaseModel): | ImageGenerationDisplay | CustomToolResponse | StreamingError + | StreamStopInfo ) diff --git a/backend/danswer/chat/personas.yaml b/backend/danswer/chat/personas.yaml index 9955b1d73c5..4c97d79efbc 100644 --- a/backend/danswer/chat/personas.yaml +++ b/backend/danswer/chat/personas.yaml @@ -19,11 +19,11 @@ personas: # Default number of chunks to include as context, set to 0 to disable retrieval # Remove the field to set to the system default number of chunks/tokens to pass to Gen AI # Each chunk is 512 tokens long - num_chunks: 50 + num_chunks: 20 # Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine # if the chunk is useful or not towards the latest user query # This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable - llm_relevance_filter: false + llm_relevance_filter: true # Enable/Disable usage of the LLM to extract query time filters including source type and time range filters llm_filter_extraction: true # Decay documents priority as they age, options are: @@ -44,11 +44,11 @@ personas: document_sets: [] icon_shape: 23013 icon_color: "#6FB1FF" - display_priority: 1 + display_priority: 0 is_visible: true - id: 1 - name: "General" + name: "General GPT" description: > Assistant with no access to documents. Chat with just the Large Language Model. 
prompts: @@ -60,16 +60,16 @@ personas: document_sets: [] icon_shape: 50910 icon_color: "#FF6F6F" - display_priority: 0 + display_priority: 1 is_visible: true - id: 2 - name: "Paraphrase" + name: "GPT Internet Search" description: > - Assistant that is heavily constrained and only provides exact quotes from Connected Sources. + Use this Assistant to search the Internet for you (via Bing) and getting the answer prompts: - - "Paraphrase" - num_chunks: 10 + - "InternetSearch" + num_chunks: 0 llm_relevance_filter: true llm_filter_extraction: true recency_bias: "auto" @@ -77,7 +77,8 @@ personas: icon_shape: 45519 icon_color: "#6FFF8D" display_priority: 2 - is_visible: false + is_visible: true + internet_search: true - id: 3 @@ -95,4 +96,4 @@ personas: icon_color: "#9B59B6" image_generation: true display_priority: 3 - is_visible: true + is_visible: false diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 2eea2cfc20f..fa13f245ccb 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -7,12 +7,15 @@ from sqlalchemy.orm import Session from danswer.chat.chat_utils import create_chat_chain +from danswer.chat.models import AllCitations from danswer.chat.models import CitationInfo from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import FinalUsedContextDocsResponse from danswer.chat.models import ImageGenerationDisplay from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import MessageResponseIDInfo +from danswer.chat.models import MessageSpecificCitations from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.configs.chat_configs import BING_API_KEY @@ -85,6 +88,7 @@ ) from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse from danswer.tools.internet_search.internet_search_tool import InternetSearchTool +from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID from danswer.tools.search.search_tool import SearchResponseSummary from danswer.tools.search.search_tool import SearchTool @@ -100,9 +104,9 @@ logger = setup_logger() -def translate_citations( +def _translate_citations( citations_list: list[CitationInfo], db_docs: list[DbSearchDoc] -) -> dict[int, int]: +) -> MessageSpecificCitations: """Always cites the first instance of the document_id, assumes the db_docs are sorted in the order displayed in the UI""" doc_id_to_saved_doc_id_map: dict[str, int] = {} @@ -117,7 +121,7 @@ def translate_citations( citation.citation_num ] = doc_id_to_saved_doc_id_map[citation.document_id] - return citation_to_saved_doc_id_map + return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map) def _handle_search_tool_response_summary( @@ -239,11 +243,14 @@ def _get_force_search_settings( StreamingError | QADocsResponse | LLMRelevanceFilterResponse + | FinalUsedContextDocsResponse | ChatMessageDetail | DanswerAnswerPiece + | AllCitations | CitationInfo | ImageGenerationDisplay | CustomToolResponse + | MessageSpecificCitations | MessageResponseIDInfo ) ChatPacketStream = Iterator[ChatPacket] @@ -688,9 +695,13 @@ def stream_chat_message_objects( ) yield LLMRelevanceFilterResponse( - relevant_chunk_indices=llm_indices + llm_selected_doc_indices=llm_indices ) + elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID: + yield FinalUsedContextDocsResponse( 
+ final_context_docs=packet.response + ) elif packet.id == IMAGE_GENERATION_RESPONSE_ID: img_generation_response = cast( list[ImageGenerationResponse], packet.response @@ -743,12 +754,13 @@ def stream_chat_message_objects( # Post-LLM answer processing try: - db_citations = None + message_specific_citations: MessageSpecificCitations | None = None if reference_db_search_docs: - db_citations = translate_citations( + message_specific_citations = _translate_citations( citations_list=answer.citations, db_docs=reference_db_search_docs, ) + yield AllCitations(citations=answer.citations) # Saving Gen AI answer and responding with message info tool_name_to_tool_id: dict[str, int] = {} @@ -765,7 +777,9 @@ def stream_chat_message_objects( reference_docs=reference_db_search_docs, files=ai_message_files, token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), - citations=db_citations, + citations=message_specific_citations.citation_map + if message_specific_citations + else None, error=None, tool_calls=[ ToolCall( diff --git a/backend/danswer/chat/prompts.yaml b/backend/danswer/chat/prompts.yaml index b3b9bae6467..4fe3c447495 100644 --- a/backend/danswer/chat/prompts.yaml +++ b/backend/danswer/chat/prompts.yaml @@ -107,3 +107,18 @@ prompts: directly from the documents. datetime_aware: true include_citations: true + + - name: "InternetSearch" + description: "Use this Assistant to search the Internet for you (via Bing) and getting the answer" + system: > + You are an intelligent AI agent designed to assist users by providing accurate and relevant information through internet searches. Your primary objectives are: + Information Retrieval: Search the internet to find reliable and up-to-date information based on user queries. Ensure that the sources you reference are credible and trustworthy. + Context Understanding: Analyze user questions to understand context and intent. Provide answers that are directly related to the user's needs, offering additional context when necessary. + Summarization: When presenting information, summarize findings clearly and concisely. Highlight key points and relevant details to enhance user understanding. + User Engagement: Maintain a friendly and engaging tone in your responses. Encourage users to ask follow-up questions or request further information. + Privacy and Safety: Respect user privacy and ensure that any personal information is handled securely. Avoid sharing sensitive or inappropriate content. + Continuous Learning: Adapt and improve your responses based on user interactions and feedback. Stay updated with the latest information and trends to provide the best assistance. + task: > + Search the internet for relevant information based on the user query. Provide a concise summary of the findings and include the sources of information. 
+ datetime_aware: true + include_citations: true diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index f6b218c5f56..4b5109b5ee7 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -126,6 +126,7 @@ except ValueError: INDEX_BATCH_SIZE = 16 + # Below are intended to match the env variables names used by the official postgres docker image # https://hub.docker.com/_/postgres POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres" @@ -149,6 +150,20 @@ except ValueError: POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT +REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true" +REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost" +REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379)) +REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or "" + +# Used for general redis things +REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0)) + +# Used by celery as broker and backend +REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) + +REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE") +REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "") + ##### # Connector Configs ##### diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py index 454412ff87e..b7c10ea36fb 100644 --- a/backend/danswer/configs/chat_configs.py +++ b/backend/danswer/configs/chat_configs.py @@ -83,8 +83,15 @@ # Stops streaming answers back to the UI if this pattern is seen: STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None -# The backend logic for this being True isn't fully supported yet -HARD_DELETE_CHATS = False +# Set this to "true" to hard delete chats +# This will make chats unviewable by admins after a user deletes them +# As opposed to soft deleting them, which just hides them from non-admin users +HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true" # Internet Search BING_API_KEY = os.environ.get("BING_API_KEY") or None + +# Enable in-house model for detecting connector-based filtering in queries +ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False) + +VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2) diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 64c162d7bef..e807f381e87 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -57,9 +57,12 @@ KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time" KV_SETTINGS_KEY = "danswer_settings" KV_CUSTOMER_UUID_KEY = "customer_uuid" +KV_INSTANCE_DOMAIN_KEY = "instance_domain" KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings" KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__" +CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60 + class DocumentSource(str, Enum): # Special case, document passed in via Danswer APIs without specifying a source type @@ -166,3 +169,23 @@ class FileOrigin(str, Enum): class PostgresAdvisoryLocks(Enum): KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto() + + +class DanswerCeleryQueues: + VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator" + VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator" + VESPA_METADATA_SYNC = "vespa_metadata_sync" + CONNECTOR_DELETION = "connector_deletion" + + +class DanswerRedisLocks: + CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat" + MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat" + + +class DanswerCeleryPriority(int, Enum): + 
HIGHEST = 0 + HIGH = auto() + MEDIUM = auto() + LOW = auto() + LOWEST = auto() diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index 9e323c2b539..c9668cd8136 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -39,9 +39,13 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ") ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ") # Purely an optimization, memory limitation consideration -BATCH_SIZE_ENCODE_CHUNKS = 8 + +# User's set embedding batch size overrides the default encoding batch sizes +EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE") or 0) or None + +BATCH_SIZE_ENCODE_CHUNKS = EMBEDDING_BATCH_SIZE or 8 # don't send over too many chunks at once, as sending too many could cause timeouts -BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512 +BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = EMBEDDING_BATCH_SIZE or 512 # For score display purposes, only way is to know the expected ranges CROSS_ENCODER_RANGE_MAX = 1 CROSS_ENCODER_RANGE_MIN = 0 @@ -51,33 +55,11 @@ # Generative AI Model Configs ##### -# If changing GEN_AI_MODEL_PROVIDER or GEN_AI_MODEL_VERSION from the default, -# be sure to use one that is LiteLLM compatible: -# https://litellm.vercel.app/docs/providers/azure#completion---using-env-variables -# The provider is the prefix before / in the model argument - -# Additionally Danswer supports GPT4All and custom request library based models -# Set GEN_AI_MODEL_PROVIDER to "custom" to use the custom requests approach -# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally -GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai" -# If using Azure, it's the engine name, for example: Danswer +# NOTE: the 3 below should only be used for dev. +GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY") GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") - -# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need -# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION") -# If the Generative AI model requires an API key for access, otherwise can leave blank -GEN_AI_API_KEY = ( - os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY")) or None -) - -# API Base, such as (for Azure): https://danswer.openai.azure.com/ -GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None -# API Version, such as (for Azure): 2023-09-15-preview -GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None -# LiteLLM custom_llm_provider -GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None # Override the auto-detection of LLM max context length GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None diff --git a/backend/danswer/connectors/README.md b/backend/danswer/connectors/README.md index b50232fa256..ef6c63d2697 100644 --- a/backend/danswer/connectors/README.md +++ b/backend/danswer/connectors/README.md @@ -59,6 +59,8 @@ if __name__ == "__main__": latest_docs = test_connector.poll_source(one_day_ago, current) ``` +> Note: Be sure to set PYTHONPATH to danswer/backend before running the above main. 
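One possible way to satisfy that note without exporting the variable in the parent shell (a sketch; the connector path is an example and assumes the command is run from the repository root):

```python
# Run a connector's __main__ test block with PYTHONPATH pointed at the backend
# package. The path and connector module below are examples, not prescriptive.
import os
import subprocess
import sys

env = {**os.environ, "PYTHONPATH": os.path.abspath("danswer/backend")}
subprocess.run(
    [sys.executable, "danswer/backend/danswer/connectors/confluence/connector.py"],
    env=env,
    check=True,
)
```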
+ ### Additional Required Changes: #### Backend Changes @@ -68,17 +70,16 @@ if __name__ == "__main__": [here](https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/factory.py#L33) #### Frontend Changes -- Create the new connector directory and admin page under `danswer/web/src/app/admin/connectors/` -- Create the new icon, type, source, and filter changes -(refer to existing [PR](https://github.com/danswer-ai/danswer/pull/139)) +- Add the new Connector definition to the `SOURCE_METADATA_MAP` [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/sources.ts#L59). +- Add the definition for the new Form to the `connectorConfigs` object [here](https://github.com/danswer-ai/danswer/blob/main/web/src/lib/connectors/connectors.ts#L79). #### Docs Changes Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the -connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs - +connector in Danswer. Then create a Pull Request in https://github.com/danswer-ai/danswer-docs. ### Before opening PR 1. Be sure to fully test changes end to end with setting up the connector and updating the index with new docs from the -new connector. -2. Be sure to run the linting/formatting, refer to the formatting and linting section in +new connector. To make it easier to review, please attach a video showing the successful creation of the connector via the UI (starting from the `Add Connector` page). +2. Add a folder + tests under `backend/tests/daily/connectors` director. For an example, checkout the [test for Confluence](https://github.com/danswer-ai/danswer/blob/main/backend/tests/daily/connectors/confluence/test_confluence_basic.py). In the PR description, include a guide on how to setup the new source to pass the test. Before merging, we will re-create the environment and make sure the test(s) pass. +3. 
Be sure to run the linting/formatting, refer to the formatting and linting section in [CONTRIBUTING.md](https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md#formatting-and-linting) diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index b8dc967a3d9..78efce4ab98 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -7,7 +7,6 @@ from functools import lru_cache from typing import Any from typing import cast -from urllib.parse import urlparse import bs4 from atlassian import Confluence # type:ignore @@ -53,79 +52,6 @@ ) -def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]: - """Sample - URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview - URL w/o page: https://danswer.atlassian.net/wiki/spaces/ASAM/overview - - wiki_base is https://danswer.atlassian.net/wiki - space is 1234abcd - page_id is 5678efgh - """ - parsed_url = urlparse(wiki_url) - wiki_base = ( - parsed_url.scheme - + "://" - + parsed_url.netloc - + parsed_url.path.split("/spaces")[0] - ) - - path_parts = parsed_url.path.split("/") - space = path_parts[3] - - page_id = path_parts[5] if len(path_parts) > 5 else "" - return wiki_base, space, page_id - - -def _extract_confluence_keys_from_datacenter_url(wiki_url: str) -> tuple[str, str, str]: - """Sample - URL w/ page https://danswer.ai/confluence/display/1234abcd/pages/5678efgh/overview - URL w/o page https://danswer.ai/confluence/display/1234abcd/overview - wiki_base is https://danswer.ai/confluence - space is 1234abcd - page_id is 5678efgh - """ - # /display/ is always right before the space and at the end of the base print() - DISPLAY = "/display/" - PAGE = "/pages/" - - parsed_url = urlparse(wiki_url) - wiki_base = ( - parsed_url.scheme - + "://" - + parsed_url.netloc - + parsed_url.path.split(DISPLAY)[0] - ) - space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0] - page_id = "" - if (content := parsed_url.path.split(PAGE)) and len(content) > 1: - page_id = content[1] - return wiki_base, space, page_id - - -def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]: - is_confluence_cloud = ( - ".atlassian.net/wiki/spaces/" in wiki_url - or ".jira.com/wiki/spaces/" in wiki_url - ) - - try: - if is_confluence_cloud: - wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url( - wiki_url - ) - else: - wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url( - wiki_url - ) - except Exception as e: - error_msg = f"Not a valid Confluence Wiki Link, unable to extract wiki base, space, and page id. 
Exception: {e}" - logger.error(error_msg) - raise ValueError(error_msg) - - return wiki_base, space, page_id, is_confluence_cloud - - @lru_cache() def _get_user(user_id: str, confluence_client: Confluence) -> str: """Get Confluence Display Name based on the account-id or userkey value @@ -372,7 +298,10 @@ def _fetch_single_depth_child_pages( class ConfluenceConnector(LoadConnector, PollConnector): def __init__( self, - wiki_page_url: str, + wiki_base: str, + space: str, + is_cloud: bool, + page_id: str = "", index_recursively: bool = True, batch_size: int = INDEX_BATCH_SIZE, continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, @@ -386,15 +315,15 @@ def __init__( self.labels_to_skip = set(labels_to_skip) self.recursive_indexer: RecursiveIndexer | None = None self.index_recursively = index_recursively - ( - self.wiki_base, - self.space, - self.page_id, - self.is_cloud, - ) = extract_confluence_keys_from_url(wiki_page_url) - self.space_level_scan = False + # Remove trailing slash from wiki_base if present + self.wiki_base = wiki_base.rstrip("/") + self.space = space + self.page_id = page_id + self.is_cloud = is_cloud + + self.space_level_scan = False self.confluence_client: Confluence | None = None if self.page_id is None or self.page_id == "": @@ -414,7 +343,6 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None username=username if self.is_cloud else None, password=access_token if self.is_cloud else None, token=access_token if not self.is_cloud else None, - cloud=self.is_cloud, ) return None @@ -866,7 +794,13 @@ def poll_source( if __name__ == "__main__": - connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"]) + connector = ConfluenceConnector( + wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"], + space=os.environ["CONFLUENCE_TEST_SPACE"], + is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true", + page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""), + index_recursively=True, + ) connector.load_credentials( { "confluence_username": os.environ["CONFLUENCE_USER_NAME"], diff --git a/backend/danswer/connectors/confluence/rate_limit_handler.py b/backend/danswer/connectors/confluence/rate_limit_handler.py index 8755b78f3f4..822badb9b99 100644 --- a/backend/danswer/connectors/confluence/rate_limit_handler.py +++ b/backend/danswer/connectors/confluence/rate_limit_handler.py @@ -23,7 +23,7 @@ class ConfluenceRateLimitError(Exception): def make_confluence_call_handle_rate_limit(confluence_call: F) -> F: def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: - max_retries = 10 + max_retries = 5 starting_delay = 5 backoff = 2 max_delay = 600 @@ -32,17 +32,24 @@ def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: try: return confluence_call(*args, **kwargs) except HTTPError as e: + # Check if the response or headers are None to avoid potential AttributeError + if e.response is None or e.response.headers is None: + logger.warning("HTTPError with `None` as response or as headers") + raise e + + retry_after_header = e.response.headers.get("Retry-After") if ( e.response.status_code == 429 or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower() ): retry_after = None - try: - retry_after = int(e.response.headers.get("Retry-After")) - except (ValueError, TypeError): - pass + if retry_after_header is not None: + try: + retry_after = int(retry_after_header) + except ValueError: + pass - if retry_after: + if retry_after is not None: logger.warning( f"Rate limit hit. Retrying after {retry_after} seconds..." 
) diff --git a/backend/danswer/connectors/danswer_jira/connector.py b/backend/danswer/connectors/danswer_jira/connector.py index 9a8fbb31501..e3562f3a45c 100644 --- a/backend/danswer/connectors/danswer_jira/connector.py +++ b/backend/danswer/connectors/danswer_jira/connector.py @@ -45,10 +45,15 @@ def extract_jira_project(url: str) -> tuple[str, str]: return jira_base, jira_project -def extract_text_from_content(content: dict) -> str: +def extract_text_from_adf(adf: dict | None) -> str: + """Extracts plain text from Atlassian Document Format: + https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/ + + WARNING: This function is incomplete and will e.g. skip lists! + """ texts = [] - if "content" in content: - for block in content["content"]: + if adf is not None and "content" in adf: + for block in adf["content"]: if "content" in block: for item in block["content"]: if item["type"] == "text": @@ -72,18 +77,15 @@ def _get_comment_strs( comment_strs = [] for comment in jira.fields.comment.comments: try: - if hasattr(comment, "body"): - body_text = extract_text_from_content(comment.raw["body"]) - elif hasattr(comment, "raw"): - body = comment.raw.get("body", "No body content available") - body_text = ( - extract_text_from_content(body) if isinstance(body, dict) else body - ) - else: - body_text = "No body attribute found" + body_text = ( + comment.body + if JIRA_API_VERSION == "2" + else extract_text_from_adf(comment.raw["body"]) + ) if ( hasattr(comment, "author") + and hasattr(comment.author, "emailAddress") and comment.author.emailAddress in comment_email_blacklist ): continue # Skip adding comment if author's email is in blacklist @@ -126,11 +128,14 @@ def fetch_jira_issues_batch( ) continue + description = ( + jira.fields.description + if JIRA_API_VERSION == "2" + else extract_text_from_adf(jira.raw["fields"]["description"]) + ) comments = _get_comment_strs(jira, comment_email_blacklist) - semantic_rep = ( - f"{jira.fields.description}\n" - if jira.fields.description - else "" + "\n".join([f"Comment: {comment}" for comment in comments]) + semantic_rep = f"{description}\n" + "\n".join( + [f"Comment: {comment}" for comment in comments if comment] ) page_url = f"{jira_client.client_info()}/browse/{jira.key}" diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py index 6c5501734b0..83d0af2c12e 100644 --- a/backend/danswer/connectors/file/connector.py +++ b/backend/danswer/connectors/file/connector.py @@ -23,7 +23,7 @@ from danswer.file_processing.extract_file_text import get_file_ext from danswer.file_processing.extract_file_text import is_text_file_extension from danswer.file_processing.extract_file_text import load_files_from_zip -from danswer.file_processing.extract_file_text import pdf_to_text +from danswer.file_processing.extract_file_text import read_pdf_file from danswer.file_processing.extract_file_text import read_text_file from danswer.file_store.file_store import get_default_file_store from danswer.utils.logger import setup_logger @@ -75,7 +75,7 @@ def _process_file( # Using the PDF reader function directly to pass in password cleanly elif extension == ".pdf": - file_content_raw = pdf_to_text(file=file, pdf_pass=pdf_pass) + file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass) else: file_content_raw = extract_file_text( diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 40a9b73432f..80674b5a37d 100644 --- 
a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -41,8 +41,8 @@ from danswer.connectors.models import Document from danswer.connectors.models import Section from danswer.file_processing.extract_file_text import docx_to_text -from danswer.file_processing.extract_file_text import pdf_to_text from danswer.file_processing.extract_file_text import pptx_to_text +from danswer.file_processing.extract_file_text import read_pdf_file from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger @@ -62,6 +62,8 @@ class GDriveMimeType(str, Enum): POWERPOINT = ( "application/vnd.openxmlformats-officedocument.presentationml.presentation" ) + PLAIN_TEXT = "text/plain" + MARKDOWN = "text/markdown" GoogleDriveFileType = dict[str, Any] @@ -316,25 +318,29 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str: GDriveMimeType.PPT.value, GDriveMimeType.SPREADSHEET.value, ]: - export_mime_type = "text/plain" - if mime_type == GDriveMimeType.SPREADSHEET.value: - export_mime_type = "text/csv" - elif mime_type == GDriveMimeType.PPT.value: - export_mime_type = "text/plain" - - response = ( + export_mime_type = ( + "text/plain" + if mime_type != GDriveMimeType.SPREADSHEET.value + else "text/csv" + ) + return ( service.files() .export(fileId=file["id"], mimeType=export_mime_type) .execute() + .decode("utf-8") ) - return response.decode("utf-8") - + elif mime_type in [ + GDriveMimeType.PLAIN_TEXT.value, + GDriveMimeType.MARKDOWN.value, + ]: + return service.files().get_media(fileId=file["id"]).execute().decode("utf-8") elif mime_type == GDriveMimeType.WORD_DOC.value: response = service.files().get_media(fileId=file["id"]).execute() return docx_to_text(file=io.BytesIO(response)) elif mime_type == GDriveMimeType.PDF.value: response = service.files().get_media(fileId=file["id"]).execute() - return pdf_to_text(file=io.BytesIO(response)) + text, _ = read_pdf_file(file=io.BytesIO(response)) + return text elif mime_type == GDriveMimeType.POWERPOINT.value: response = service.files().get_media(fileId=file["id"]).execute() return pptx_to_text(file=io.BytesIO(response)) diff --git a/backend/danswer/connectors/notion/connector.py b/backend/danswer/connectors/notion/connector.py index fd607e4f97a..7878434da04 100644 --- a/backend/danswer/connectors/notion/connector.py +++ b/backend/danswer/connectors/notion/connector.py @@ -237,6 +237,14 @@ def _read_blocks( ) continue + if result_type == "external_object_instance_page": + logger.warning( + f"Skipping 'external_object_instance_page' ('{result_block_id}') for base block '{base_block_id}': " + f"Notion API does not currently support reading external blocks (as of 24/07/03) " + f"(discussion: https://github.com/danswer-ai/danswer/issues/1761)" + ) + continue + cur_result_text_arr = [] if "rich_text" in result_obj: for rich_text in result_obj["rich_text"]: diff --git a/backend/danswer/connectors/productboard/connector.py b/backend/danswer/connectors/productboard/connector.py index 9ef301aa76d..c7a2d45cae8 100644 --- a/backend/danswer/connectors/productboard/connector.py +++ b/backend/danswer/connectors/productboard/connector.py @@ -98,6 +98,15 @@ def _get_features(self) -> Generator[Document, None, None]: owner = self._get_owner_email(feature) experts = [BasicExpertInfo(email=owner)] if owner else None + metadata: dict[str, str | list[str]] = {} + entity_type = feature.get("type", "feature") + if entity_type: + metadata["entity_type"] = str(entity_type) + + 
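# A hedged sketch of the export/download dispatch in the Google Drive hunk above:
# native Workspace files are exported (Docs and Slides as text/plain, Sheets as
# text/csv) while other files are fetched verbatim via get_media. The helper below is
# illustrative only; the MIME constants are the standard Google Workspace types.
GOOGLE_WORKSPACE_EXPORT_MIME = {
    "application/vnd.google-apps.document": "text/plain",
    "application/vnd.google-apps.presentation": "text/plain",
    "application/vnd.google-apps.spreadsheet": "text/csv",
}


def pick_drive_download(mime_type: str) -> tuple[str, str | None]:
    """Return ("export", export_mime) for Workspace types, ("media", None) otherwise."""
    export_mime = GOOGLE_WORKSPACE_EXPORT_MIME.get(mime_type)
    if export_mime is not None:
        return "export", export_mime
    return "media", None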
status = feature.get("status", {}).get("name") + if status: + metadata["status"] = str(status) + yield Document( id=feature["id"], sections=[ @@ -110,10 +119,7 @@ def _get_features(self) -> Generator[Document, None, None]: source=DocumentSource.PRODUCTBOARD, doc_updated_at=time_str_to_utc(feature["updatedAt"]), primary_owners=experts, - metadata={ - "entity_type": feature["type"], - "status": feature["status"]["name"], - }, + metadata=metadata, ) def _get_components(self) -> Generator[Document, None, None]: @@ -174,6 +180,12 @@ def _get_objectives(self) -> Generator[Document, None, None]: owner = self._get_owner_email(objective) experts = [BasicExpertInfo(email=owner)] if owner else None + metadata: dict[str, str | list[str]] = { + "entity_type": "objective", + } + if objective.get("state"): + metadata["state"] = str(objective["state"]) + yield Document( id=objective["id"], sections=[ @@ -186,10 +198,7 @@ def _get_objectives(self) -> Generator[Document, None, None]: source=DocumentSource.PRODUCTBOARD, doc_updated_at=time_str_to_utc(objective["updatedAt"]), primary_owners=experts, - metadata={ - "entity_type": "release", - "state": objective["state"], - }, + metadata=metadata, ) def _is_updated_at_out_of_time_range( diff --git a/backend/danswer/connectors/sharepoint/connector.py b/backend/danswer/connectors/sharepoint/connector.py index b66c010d77f..e74dcbf7edd 100644 --- a/backend/danswer/connectors/sharepoint/connector.py +++ b/backend/danswer/connectors/sharepoint/connector.py @@ -25,7 +25,6 @@ from danswer.file_processing.extract_file_text import extract_file_text from danswer.utils.logger import setup_logger - logger = setup_logger() @@ -137,7 +136,7 @@ def _populate_sitedata_sites(self) -> None: .execute_query() ] else: - sites = self.graph_client.sites.get().execute_query() + sites = self.graph_client.sites.get_all().execute_query() self.site_data = [ SiteData(url=None, folder=None, sites=sites, driveitems=[]) ] diff --git a/backend/danswer/connectors/slack/connector.py b/backend/danswer/connectors/slack/connector.py index 6c451389932..975653f5f61 100644 --- a/backend/danswer/connectors/slack/connector.py +++ b/backend/danswer/connectors/slack/connector.py @@ -29,6 +29,7 @@ from danswer.connectors.slack.utils import SlackTextCleaner from danswer.utils.logger import setup_logger + logger = setup_logger() diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py index 6e76e404acd..bb1f64efdfe 100644 --- a/backend/danswer/connectors/web/connector.py +++ b/backend/danswer/connectors/web/connector.py @@ -1,6 +1,8 @@ import io import ipaddress import socket +from datetime import datetime +from datetime import timezone from enum import Enum from typing import Any from typing import cast @@ -27,7 +29,7 @@ from danswer.connectors.interfaces import LoadConnector from danswer.connectors.models import Document from danswer.connectors.models import Section -from danswer.file_processing.extract_file_text import pdf_to_text +from danswer.file_processing.extract_file_text import read_pdf_file from danswer.file_processing.html_utils import web_html_cleanup from danswer.utils.logger import setup_logger from danswer.utils.sitemap import list_pages_for_site @@ -85,7 +87,8 @@ def check_internet_connection(url: str) -> None: response = requests.get(url, timeout=3) response.raise_for_status() except requests.exceptions.HTTPError as e: - status_code = e.response.status_code + # Extract status code from the response, defaulting to -1 if response is None + 
status_code = e.response.status_code if e.response is not None else -1 error_msg = { 400: "Bad Request", 401: "Unauthorized", @@ -202,6 +205,15 @@ def _read_urls_file(location: str) -> list[str]: return urls +def _get_datetime_from_last_modified_header(last_modified: str) -> datetime | None: + try: + return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace( + tzinfo=timezone.utc + ) + except (ValueError, TypeError): + return None + + class WebConnector(LoadConnector): def __init__( self, @@ -284,7 +296,10 @@ def load_from_state(self) -> GenerateDocumentsOutput: if current_url.split(".")[-1] == "pdf": # PDF files are not checked for links response = requests.get(current_url) - page_text = pdf_to_text(file=io.BytesIO(response.content)) + page_text, metadata = read_pdf_file( + file=io.BytesIO(response.content) + ) + last_modified = response.headers.get("Last-Modified") doc_batch.append( Document( @@ -292,13 +307,23 @@ def load_from_state(self) -> GenerateDocumentsOutput: sections=[Section(link=current_url, text=page_text)], source=DocumentSource.WEB, semantic_identifier=current_url.split("/")[-1], - metadata={}, + metadata=metadata, + doc_updated_at=_get_datetime_from_last_modified_header( + last_modified + ) + if last_modified + else None, ) ) continue page = context.new_page() page_response = page.goto(current_url) + last_modified = ( + page_response.header_value("Last-Modified") + if page_response + else None + ) final_page = page.url if final_page != current_url: logger.info(f"Redirected to {final_page}") @@ -334,6 +359,11 @@ def load_from_state(self) -> GenerateDocumentsOutput: source=DocumentSource.WEB, semantic_identifier=parsed_html.title or current_url, metadata={}, + doc_updated_at=_get_datetime_from_last_modified_header( + last_modified + ) + if last_modified + else None, ) ) diff --git a/backend/danswer/connectors/zendesk/connector.py b/backend/danswer/connectors/zendesk/connector.py index b6d4220b9ce..f85f2efff57 100644 --- a/backend/danswer/connectors/zendesk/connector.py +++ b/backend/danswer/connectors/zendesk/connector.py @@ -3,6 +3,7 @@ import requests from retry import retry from zenpy import Zenpy # type: ignore +from zenpy.lib.api_objects import Ticket # type: ignore from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore from danswer.configs.app_configs import INDEX_BATCH_SIZE @@ -59,10 +60,15 @@ def __init__(self) -> None: class ZendeskConnector(LoadConnector, PollConnector): - def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: + def __init__( + self, + batch_size: int = INDEX_BATCH_SIZE, + content_type: str = "articles", + ) -> None: self.batch_size = batch_size self.zendesk_client: Zenpy | None = None self.content_tags: dict[str, str] = {} + self.content_type = content_type @retry(tries=3, delay=2, backoff=2) def _set_content_tags( @@ -122,16 +128,86 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None def load_from_state(self) -> GenerateDocumentsOutput: return self.poll_source(None, None) + def _ticket_to_document(self, ticket: Ticket) -> Document: + if self.zendesk_client is None: + raise ZendeskClientNotSetUpError() + + owner = None + if ticket.requester and ticket.requester.name and ticket.requester.email: + owner = [ + BasicExpertInfo( + display_name=ticket.requester.name, email=ticket.requester.email + ) + ] + update_time = time_str_to_utc(ticket.updated_at) if ticket.updated_at else None + + metadata: dict[str, str | list[str]] = {} + if ticket.status is not None: + 
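# Usage sketch for the Last-Modified parsing added to the web connector above. The
# header uses the RFC 7231 IMF-fixdate format; the example value below is illustrative.
from datetime import datetime, timezone


def parse_last_modified(last_modified: str | None) -> datetime | None:
    if not last_modified:
        return None
    try:
        # e.g. "Wed, 21 Oct 2015 07:28:00 GMT"
        return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
            tzinfo=timezone.utc
        )
    except (ValueError, TypeError):
        return None


print(parse_last_modified("Wed, 21 Oct 2015 07:28:00 GMT"))
# -> 2015-10-21 07:28:00+00:00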
metadata["status"] = ticket.status + if ticket.priority is not None: + metadata["priority"] = ticket.priority + if ticket.tags: + metadata["tags"] = ticket.tags + if ticket.type is not None: + metadata["ticket_type"] = ticket.type + + # Fetch comments for the ticket + comments = self.zendesk_client.tickets.comments(ticket=ticket) + + # Combine all comments into a single text + comments_text = "\n\n".join( + [ + f"Comment{f' by {comment.author.name}' if comment.author and comment.author.name else ''}" + f"{f' at {comment.created_at}' if comment.created_at else ''}:\n{comment.body}" + for comment in comments + if comment.body + ] + ) + + # Combine ticket description and comments + description = ( + ticket.description + if hasattr(ticket, "description") and ticket.description + else "" + ) + full_text = f"Ticket Description:\n{description}\n\nComments:\n{comments_text}" + + # Extract subdomain from ticket.url + subdomain = ticket.url.split("//")[1].split(".zendesk.com")[0] + + # Build the html url for the ticket + ticket_url = f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.id}" + + return Document( + id=f"zendesk_ticket_{ticket.id}", + sections=[Section(link=ticket_url, text=full_text)], + source=DocumentSource.ZENDESK, + semantic_identifier=f"Ticket #{ticket.id}: {ticket.subject or 'No Subject'}", + doc_updated_at=update_time, + primary_owners=owner, + metadata=metadata, + ) + def poll_source( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None ) -> GenerateDocumentsOutput: if self.zendesk_client is None: raise ZendeskClientNotSetUpError() + if self.content_type == "articles": + yield from self._poll_articles(start) + elif self.content_type == "tickets": + yield from self._poll_tickets(start) + else: + raise ValueError(f"Unsupported content_type: {self.content_type}") + + def _poll_articles( + self, start: SecondsSinceUnixEpoch | None + ) -> GenerateDocumentsOutput: articles = ( - self.zendesk_client.help_center.articles(cursor_pagination=True) + self.zendesk_client.help_center.articles(cursor_pagination=True) # type: ignore if start is None - else self.zendesk_client.help_center.articles.incremental( + else self.zendesk_client.help_center.articles.incremental( # type: ignore start_time=int(start) ) ) @@ -155,9 +231,43 @@ def poll_source( if doc_batch: yield doc_batch + def _poll_tickets( + self, start: SecondsSinceUnixEpoch | None + ) -> GenerateDocumentsOutput: + if self.zendesk_client is None: + raise ZendeskClientNotSetUpError() + + ticket_generator = self.zendesk_client.tickets.incremental(start_time=start) + + while True: + doc_batch = [] + for _ in range(self.batch_size): + try: + ticket = next(ticket_generator) + + # Check if the ticket status is deleted and skip it if so + if ticket.status == "deleted": + continue + + doc_batch.append(self._ticket_to_document(ticket)) + + if len(doc_batch) >= self.batch_size: + yield doc_batch + doc_batch.clear() + + except StopIteration: + # No more tickets to process + if doc_batch: + yield doc_batch + return + + if doc_batch: + yield doc_batch + if __name__ == "__main__": import os + import time connector = ZendeskConnector() diff --git a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py index 732be8df9db..9e1c171ee4f 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_buttons.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_buttons.py @@ -11,6 +11,7 @@ from danswer.configs.constants import MessageType from 
danswer.configs.constants import SearchFeedbackType from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI +from danswer.connectors.slack.utils import expert_info_from_slack_id from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks from danswer.danswerbot.slack.blocks import get_document_feedback_blocks @@ -87,6 +88,8 @@ def handle_generate_answer_button( message_ts = req.payload["message"]["ts"] thread_ts = req.payload["container"]["thread_ts"] user_id = req.payload["user"]["id"] + expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={}) + email = expert_info.email if expert_info else None if not thread_ts: raise ValueError("Missing thread_ts in the payload") @@ -125,6 +128,7 @@ def handle_generate_answer_button( msg_to_respond=cast(str, message_ts or thread_ts), thread_to_respond=cast(str, thread_ts or message_ts), sender=user_id or None, + email=email or None, bypass_filters=True, is_bot_msg=False, is_bot_dm=False, diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 2edbd973553..cce45331ee7 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -21,6 +21,7 @@ from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import SlackBotConfig +from danswer.db.users import add_non_web_user_if_not_exists from danswer.utils.logger import setup_logger from shared_configs.configs import SLACK_CHANNEL_ID @@ -209,6 +210,9 @@ def handle_message( logger.error(f"Was not able to react to user message due to: {e}") with Session(get_sqlalchemy_engine()) as db_session: + if message_info.email: + add_non_web_user_if_not_exists(message_info.email, db_session) + # first check if we need to respond with a standard answer used_standard_answer = handle_standard_answers( message_info=message_info, diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index e3a78917a76..7057d7c2e4b 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -38,6 +38,7 @@ from danswer.db.models import SlackBotResponseType from danswer.db.persona import fetch_persona_by_id from danswer.db.search_settings import get_current_search_settings +from danswer.db.users import get_user_by_email from danswer.llm.answering.prompts.citations_prompt import ( compute_max_document_tokens_for_persona, ) @@ -99,6 +100,12 @@ def handle_regular_answer( messages = message_info.thread_messages message_ts_to_respond_to = message_info.msg_to_respond is_bot_msg = message_info.is_bot_msg + user = None + if message_info.is_bot_dm: + if message_info.email: + engine = get_sqlalchemy_engine() + with Session(engine) as db_session: + user = get_user_by_email(message_info.email, db_session) document_set_names: list[str] | None = None persona = slack_bot_config.persona if slack_bot_config else None @@ -185,7 +192,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non # This also handles creating the query event in postgres answer = get_search_answer( query_req=new_message_request, - user=None, + user=user, max_document_tokens=max_document_tokens, 
max_history_tokens=max_history_tokens, db_session=db_session, @@ -412,7 +419,7 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non ) # Get the chunks fed to the LLM only, then fill with other docs - llm_doc_inds = answer.llm_chunks_indices or [] + llm_doc_inds = answer.llm_selected_doc_indices or [] llm_docs = [top_docs[i] for i in llm_doc_inds] remaining_docs = [ doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index c59f4caf1aa..63f8bcfcd9c 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -13,6 +13,7 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER +from danswer.connectors.slack.utils import expert_info_from_slack_id from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID @@ -256,6 +257,11 @@ def build_request_details( tagged = event.get("type") == "app_mention" message_ts = event.get("ts") thread_ts = event.get("thread_ts") + sender = event.get("user") or None + expert_info = expert_info_from_slack_id( + sender, client.web_client, user_cache={} + ) + email = expert_info.email if expert_info else None msg = remove_danswer_bot_tag(msg, client=client.web_client) @@ -286,7 +292,8 @@ def build_request_details( channel_to_respond=channel, msg_to_respond=cast(str, message_ts or thread_ts), thread_to_respond=cast(str, thread_ts or message_ts), - sender=event.get("user") or None, + sender=sender, + email=email, bypass_filters=tagged, is_bot_msg=False, is_bot_dm=event.get("channel_type") == "im", @@ -296,6 +303,10 @@ def build_request_details( channel = req.payload["channel_id"] msg = req.payload["text"] sender = req.payload["user_id"] + expert_info = expert_info_from_slack_id( + sender, client.web_client, user_cache={} + ) + email = expert_info.email if expert_info else None single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER) @@ -305,6 +316,7 @@ def build_request_details( msg_to_respond=None, thread_to_respond=None, sender=sender, + email=email, bypass_filters=True, is_bot_msg=True, is_bot_dm=False, diff --git a/backend/danswer/danswerbot/slack/models.py b/backend/danswer/danswerbot/slack/models.py index e4521a759a7..6394eab562d 100644 --- a/backend/danswer/danswerbot/slack/models.py +++ b/backend/danswer/danswerbot/slack/models.py @@ -9,6 +9,7 @@ class SlackMessageInfo(BaseModel): msg_to_respond: str | None thread_to_respond: str | None sender: str | None + email: str | None bypass_filters: bool # User has tagged @DanswerBot is_bot_msg: bool # User is using /DanswerBot is_bot_dm: bool # User is direct messaging to DanswerBot diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py index 161fdc8f10b..6d150b106cb 100644 --- a/backend/danswer/db/auth.py +++ b/backend/danswer/db/auth.py @@ -28,7 +28,7 @@ def get_default_admin_user_emails() -> list[str]: get_default_admin_user_emails_fn: Callable[ [], list[str] ] = fetch_versioned_implementation_with_fallback( - "danswer.auth.users", "get_default_admin_user_emails_", lambda: [] + "danswer.auth.users", "get_default_admin_user_emails_", lambda: 
list[str]() ) return get_default_admin_user_emails_fn() diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 3cb991dd43b..8485bb4f0ae 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -3,7 +3,6 @@ from datetime import timedelta from uuid import UUID -from sqlalchemy import and_ from sqlalchemy import delete from sqlalchemy import desc from sqlalchemy import func @@ -87,29 +86,57 @@ def get_chat_sessions_by_slack_thread_id( return db_session.scalars(stmt).all() -def get_first_messages_for_chat_sessions( - chat_session_ids: list[int], db_session: Session +def get_valid_messages_from_query_sessions( + chat_session_ids: list[int], + db_session: Session, ) -> dict[int, str]: - subquery = ( - select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id")) + user_message_subquery = ( + select( + ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id") + ) .where( - and_( - ChatMessage.chat_session_id.in_(chat_session_ids), - ChatMessage.message_type == MessageType.USER, # Select USER messages - ) + ChatMessage.chat_session_id.in_(chat_session_ids), + ChatMessage.message_type == MessageType.USER, + ) + .group_by(ChatMessage.chat_session_id) + .subquery() + ) + + assistant_message_subquery = ( + select( + ChatMessage.chat_session_id, + func.min(ChatMessage.id).label("assistant_msg_id"), + ) + .where( + ChatMessage.chat_session_id.in_(chat_session_ids), + ChatMessage.message_type == MessageType.ASSISTANT, ) .group_by(ChatMessage.chat_session_id) .subquery() ) - query = select(ChatMessage.chat_session_id, ChatMessage.message).join( - subquery, - (ChatMessage.chat_session_id == subquery.c.chat_session_id) - & (ChatMessage.id == subquery.c.min_id), + query = ( + select(ChatMessage.chat_session_id, ChatMessage.message) + .join( + user_message_subquery, + ChatMessage.chat_session_id == user_message_subquery.c.chat_session_id, + ) + .join( + assistant_message_subquery, + ChatMessage.chat_session_id == assistant_message_subquery.c.chat_session_id, + ) + .join( + ChatMessage__SearchDoc, + ChatMessage__SearchDoc.chat_message_id + == assistant_message_subquery.c.assistant_msg_id, + ) + .where(ChatMessage.id == user_message_subquery.c.user_msg_id) ) first_messages = db_session.execute(query).all() - return dict([(row.chat_session_id, row.message) for row in first_messages]) + logger.info(f"Retrieved {len(first_messages)} first messages with documents") + + return {row.chat_session_id: row.message for row in first_messages} def get_chat_sessions_by_user( @@ -253,6 +280,13 @@ def delete_chat_session( db_session: Session, hard_delete: bool = HARD_DELETE_CHATS, ) -> None: + chat_session = get_chat_session_by_id( + chat_session_id=chat_session_id, user_id=user_id, db_session=db_session + ) + + if chat_session.deleted: + raise ValueError("Cannot delete an already deleted chat session") + if hard_delete: delete_messages_and_files_from_chat_session(chat_session_id, db_session) db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id)) diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index a6848232caf..004b5a754e4 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -98,6 +98,18 @@ def get_connector_credential_pairs( return list(results.all()) +def add_deletion_failure_message( + db_session: Session, + cc_pair_id: int, + failure_message: str, +) -> None: + cc_pair = 
get_connector_credential_pair_from_id(cc_pair_id, db_session) + if not cc_pair: + return + cc_pair.deletion_failure_message = failure_message + db_session.commit() + + def get_cc_pair_groups_for_ids( db_session: Session, cc_pair_ids: list[int], @@ -159,6 +171,7 @@ def get_connector_credential_pair_from_id( def get_last_successful_attempt_time( connector_id: int, credential_id: int, + earliest_index: float, search_settings: SearchSettings, db_session: Session, ) -> float: @@ -172,7 +185,7 @@ def get_last_successful_attempt_time( connector_credential_pair is None or connector_credential_pair.last_successful_index_time is None ): - return 0.0 + return earliest_index return connector_credential_pair.last_successful_index_time.timestamp() @@ -192,11 +205,9 @@ def get_last_successful_attempt_time( .order_by(IndexAttempt.time_started.desc()) .first() ) + if not attempt or not attempt.time_started: - connector = fetch_connector_by_id(connector_id, db_session) - if connector and connector.indexing_start: - return connector.indexing_start.timestamp() - return 0.0 + return earliest_index return attempt.time_started.timestamp() @@ -335,9 +346,13 @@ def add_credential_to_connector( raise HTTPException(status_code=404, detail="Connector does not exist") if credential is None: + error_msg = ( + f"Credential {credential_id} does not exist or does not belong to user" + ) + logger.error(error_msg) raise HTTPException( status_code=401, - detail="Credential does not exist or does not belong to user", + detail=error_msg, ) existing_association = ( @@ -351,7 +366,7 @@ def add_credential_to_connector( if existing_association is not None: return StatusResponse( success=False, - message=f"Connector already has Credential {credential_id}", + message=f"Connector {connector_id} already has Credential {credential_id}", data=connector_id, ) @@ -375,8 +390,8 @@ def add_credential_to_connector( db_session.commit() return StatusResponse( - success=False, - message=f"Connector already has Credential {credential_id}", + success=True, + message=f"Creating new association between Connector {connector_id} and Credential {credential_id}", data=association.id, ) diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 77ea4e3dd9d..92b093ab587 100644 --- a/backend/danswer/db/document.py +++ b/backend/danswer/db/document.py @@ -3,6 +3,7 @@ from collections.abc import Generator from collections.abc import Sequence from datetime import datetime +from datetime import timezone from uuid import UUID from sqlalchemy import and_ @@ -10,6 +11,7 @@ from sqlalchemy import exists from sqlalchemy import func from sqlalchemy import or_ +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.dialects.postgresql import insert from sqlalchemy.engine.util import TransactionalContext @@ -38,6 +40,68 @@ def check_docs_exist(db_session: Session) -> bool: return result.scalar() or False +def count_documents_by_needs_sync(session: Session) -> int: + """Get the count of all documents where: + 1. last_modified is newer than last_synced + 2. 
last_synced is null (meaning we've never synced) + + This function executes the query and returns the count of + documents matching the criteria.""" + + count = ( + session.query(func.count()) + .select_from(DbDocument) + .filter( + or_( + DbDocument.last_modified > DbDocument.last_synced, + DbDocument.last_synced.is_(None), + ) + ) + .scalar() + ) + + return count + + +def construct_document_select_for_connector_credential_pair_by_needs_sync( + connector_id: int, credential_id: int +) -> Select: + initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( + and_( + DocumentByConnectorCredentialPair.connector_id == connector_id, + DocumentByConnectorCredentialPair.credential_id == credential_id, + ) + ) + + stmt = ( + select(DbDocument) + .where( + DbDocument.id.in_(initial_doc_ids_stmt), + or_( + DbDocument.last_modified + > DbDocument.last_synced, # last_modified is newer than last_synced + DbDocument.last_synced.is_(None), # never synced + ), + ) + .distinct() + ) + + return stmt + + +def construct_document_select_for_connector_credential_pair( + connector_id: int, credential_id: int | None = None +) -> Select: + initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( + and_( + DocumentByConnectorCredentialPair.connector_id == connector_id, + DocumentByConnectorCredentialPair.credential_id == credential_id, + ) + ) + stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct() + return stmt + + def get_documents_for_connector_credential_pair( db_session: Session, connector_id: int, credential_id: int, limit: int | None = None ) -> Sequence[DbDocument]: @@ -108,7 +172,29 @@ def get_document_cnts_for_cc_pairs( return db_session.execute(stmt).all() # type: ignore -def get_acccess_info_for_documents( +def get_access_info_for_document( + db_session: Session, + document_id: str, +) -> tuple[str, list[UUID | None], bool] | None: + """Gets access info for a single document by calling the get_access_info_for_documents function + and passing a list with a single document ID. + + Args: + db_session (Session): The database session to use. + document_id (str): The document ID to fetch access info for. + + Returns: + Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs, + and a boolean indicating if the document is globally public, or None if no results are found. 
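# The new count/select helpers above share one "needs sync" predicate: last_modified is
# newer than last_synced, or the document has never been synced at all. A standalone
# sketch of that filter; the model import follows the diff, but this is not the helpers'
# exact body.
from sqlalchemy import or_, select

from danswer.db.models import Document as DbDocument

needs_sync_stmt = select(DbDocument).where(
    or_(
        DbDocument.last_modified > DbDocument.last_synced,  # changed since last Vespa sync
        DbDocument.last_synced.is_(None),                   # never synced
    )
)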
+ """ + results = get_access_info_for_documents(db_session, [document_id]) + if not results: + return None + + return results[0] + + +def get_access_info_for_documents( db_session: Session, document_ids: list[str], ) -> Sequence[tuple[str, list[UUID | None], bool]]: @@ -173,6 +259,7 @@ def upsert_documents( semantic_id=doc.semantic_identifier, link=doc.first_link, doc_updated_at=None, # this is intentional + last_modified=datetime.now(timezone.utc), primary_owners=doc.primary_owners, secondary_owners=doc.secondary_owners, ) @@ -214,7 +301,7 @@ def upsert_document_by_connector_credential_pair( db_session.commit() -def update_docs_updated_at( +def update_docs_updated_at__no_commit( ids_to_new_updated_at: dict[str, datetime], db_session: Session, ) -> None: @@ -226,6 +313,28 @@ def update_docs_updated_at( for document in documents_to_update: document.doc_updated_at = ids_to_new_updated_at[document.id] + +def update_docs_last_modified__no_commit( + document_ids: list[str], + db_session: Session, +) -> None: + documents_to_update = ( + db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all() + ) + + now = datetime.now(timezone.utc) + for doc in documents_to_update: + doc.last_modified = now + + +def mark_document_as_synced(document_id: str, db_session: Session) -> None: + stmt = select(DbDocument).where(DbDocument.id == document_id) + doc = db_session.scalar(stmt) + if doc is None: + raise ValueError(f"No document with ID: {document_id}") + + # update last_synced + doc.last_synced = datetime.now(timezone.utc) db_session.commit() @@ -379,3 +488,12 @@ def get_documents_by_cc_pair( .filter(ConnectorCredentialPair.id == cc_pair_id) .all() ) + + +def get_document( + document_id: str, + db_session: Session, +) -> DbDocument | None: + stmt = select(DbDocument).where(DbDocument.id == document_id) + doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none() + return doc diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index 2de61a491f9..4a37f8bdced 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -248,6 +248,10 @@ def update_document_set( document_set_update_request: DocumentSetUpdateRequest, user: User | None = None, ) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]: + """If successful, this sets document_set_row.is_up_to_date = False. + That will be processed via Celery in check_for_vespa_sync_task + and trigger a long running background sync to Vespa. + """ if not document_set_update_request.cc_pair_ids: # It's cc-pairs in actuality but the UI displays this error raise ValueError("Cannot create a document set with no Connectors") @@ -519,42 +523,135 @@ def fetch_documents_for_document_set_paginated( return documents, documents[-1].id if documents else None -def fetch_document_sets_for_documents( - document_ids: list[str], - db_session: Session, -) -> Sequence[tuple[str, list[str]]]: - """Gives back a list of (document_id, list[document_set_names]) tuples""" +def construct_document_select_by_docset( + document_set_id: int, + current_only: bool = True, +) -> Select: + """This returns a statement that should be executed using + .yield_per() to minimize overhead. 
The primary consumers of this function + are background processing task generators.""" + stmt = ( - select(Document.id, func.array_agg(DocumentSetDBModel.name)) + select(Document) .join( - DocumentSet__ConnectorCredentialPair, - DocumentSetDBModel.id - == DocumentSet__ConnectorCredentialPair.document_set_id, + DocumentByConnectorCredentialPair, + DocumentByConnectorCredentialPair.id == Document.id, ) .join( ConnectorCredentialPair, - ConnectorCredentialPair.id - == DocumentSet__ConnectorCredentialPair.connector_credential_pair_id, + and_( + ConnectorCredentialPair.connector_id + == DocumentByConnectorCredentialPair.connector_id, + ConnectorCredentialPair.credential_id + == DocumentByConnectorCredentialPair.credential_id, + ), + ) + .join( + DocumentSet__ConnectorCredentialPair, + DocumentSet__ConnectorCredentialPair.connector_credential_pair_id + == ConnectorCredentialPair.id, ) .join( + DocumentSetDBModel, + DocumentSetDBModel.id + == DocumentSet__ConnectorCredentialPair.document_set_id, + ) + .where(DocumentSetDBModel.id == document_set_id) + .order_by(Document.id) + ) + + if current_only: + stmt = stmt.where( + DocumentSet__ConnectorCredentialPair.is_current == True # noqa: E712 + ) + + stmt = stmt.distinct() + return stmt + + +def fetch_document_set_for_document( + document_id: str, + db_session: Session, +) -> list[str]: + """ + Fetches the document set names for a single document ID. + + :param document_id: The ID of the document to fetch sets for. + :param db_session: The SQLAlchemy session to use for the query. + :return: A list of document set names, or None if no result is found. + """ + result = fetch_document_sets_for_documents([document_id], db_session) + if not result: + return [] + + return result[0][1] + + +def fetch_document_sets_for_documents( + document_ids: list[str], + db_session: Session, +) -> Sequence[tuple[str, list[str]]]: + """Gives back a list of (document_id, list[document_set_names]) tuples""" + + """Building subqueries""" + # NOTE: have to build these subqueries first in order to guarantee that we get one + # returned row for each specified document_id. Basically, we want to do the filters first, + # then the outer joins. 
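# The rewritten fetch_document_sets_for_documents in this hunk filters cc-pairs and
# docset links in subqueries first, then LEFT OUTER JOINs so every requested document id
# yields exactly one row, with an empty array when it belongs to no current document set.
# The generated SQL has roughly this shape; table and alias names below are assumptions
# for illustration, not the exact schema:
EXAMPLE_SQL_SHAPE = """
SELECT d.id,
       COALESCE(ARRAY_REMOVE(ARRAY_AGG(ds.name), NULL), '{}') AS document_set_names
FROM document AS d
LEFT JOIN document_by_connector_credential_pair AS dcc ON dcc.id = d.id
LEFT JOIN valid_cc_pairs AS cc
       ON cc.connector_id = dcc.connector_id AND cc.credential_id = dcc.credential_id
LEFT JOIN valid_docset_cc_pairs AS link ON link.connector_credential_pair_id = cc.id
LEFT JOIN document_set AS ds ON ds.id = link.document_set_id
WHERE d.id = ANY(:document_ids)
GROUP BY d.id
"""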
+ + # don't include CC pairs that are being deleted + # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them + # as we can assume their document sets are no longer relevant + valid_cc_pairs_subquery = aliased( + ConnectorCredentialPair, + select(ConnectorCredentialPair) + .where( + ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING + ) # noqa: E712 + .subquery(), + ) + + valid_document_set__cc_pairs_subquery = aliased( + DocumentSet__ConnectorCredentialPair, + select(DocumentSet__ConnectorCredentialPair) + .where(DocumentSet__ConnectorCredentialPair.is_current == True) # noqa: E712 + .subquery(), + ) + """End building subqueries""" + + stmt = ( + select( + Document.id, + func.coalesce( + func.array_remove(func.array_agg(DocumentSetDBModel.name), None), [] + ).label("document_set_names"), + ) + # Here we select document sets by relation: + # Document -> DocumentByConnectorCredentialPair -> ConnectorCredentialPair -> + # DocumentSet__ConnectorCredentialPair -> DocumentSet + .outerjoin( DocumentByConnectorCredentialPair, + Document.id == DocumentByConnectorCredentialPair.id, + ) + .outerjoin( + valid_cc_pairs_subquery, and_( DocumentByConnectorCredentialPair.connector_id - == ConnectorCredentialPair.connector_id, + == valid_cc_pairs_subquery.connector_id, DocumentByConnectorCredentialPair.credential_id - == ConnectorCredentialPair.credential_id, + == valid_cc_pairs_subquery.credential_id, ), ) - .join( - Document, - Document.id == DocumentByConnectorCredentialPair.id, + .outerjoin( + valid_document_set__cc_pairs_subquery, + valid_cc_pairs_subquery.id + == valid_document_set__cc_pairs_subquery.connector_credential_pair_id, + ) + .outerjoin( + DocumentSetDBModel, + DocumentSetDBModel.id + == valid_document_set__cc_pairs_subquery.document_set_id, ) .where(Document.id.in_(document_ids)) - # don't include CC pairs that are being deleted - # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them - # as we can assume their document sets are no longer relevant - .where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING) - .where(DocumentSet__ConnectorCredentialPair.is_current == True) # noqa: E712 .group_by(Document.id) ) return db_session.execute(stmt).all() # type: ignore diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index 79557f209dc..6df1f1f5051 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -1,3 +1,5 @@ +from datetime import datetime +from datetime import timezone from uuid import UUID from fastapi import HTTPException @@ -24,7 +26,6 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair from danswer.db.models import UserRole from danswer.document_index.interfaces import DocumentIndex -from danswer.document_index.interfaces import UpdateRequest from danswer.utils.logger import setup_logger logger = setup_logger() @@ -123,12 +124,11 @@ def update_document_boost( db_session: Session, document_id: str, boost: int, - document_index: DocumentIndex, user: User | None = None, ) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) stmt = _add_user_filters(stmt, user, get_editable=True) - result = db_session.execute(stmt).scalar_one_or_none() + result: DbDocument | None = db_session.execute(stmt).scalar_one_or_none() if result is None: raise HTTPException( status_code=400, detail="Document is not editable by this user" @@ -136,13 +136,9 @@ def update_document_boost( result.boost = boost - update 
= UpdateRequest( - document_ids=[document_id], - boost=boost, - ) - - document_index.update(update_requests=[update]) - + # updating last_modified triggers sync + # TODO: Should this submit to the queue directly so that the UI can update? + result.last_modified = datetime.now(timezone.utc) db_session.commit() @@ -163,13 +159,9 @@ def update_document_hidden( result.hidden = hidden - update = UpdateRequest( - document_ids=[document_id], - hidden=hidden, - ) - - document_index.update(update_requests=[update]) - + # updating last_modified triggers sync + # TODO: Should this submit to the queue directly so that the UI can update? + result.last_modified = datetime.now(timezone.utc) db_session.commit() @@ -210,11 +202,9 @@ def create_doc_retrieval_feedback( SearchFeedbackType.REJECT, SearchFeedbackType.HIDE, ]: - update = UpdateRequest( - document_ids=[document_id], boost=db_doc.boost, hidden=db_doc.hidden - ) - # Updates are generally batched for efficiency, this case only 1 doc/value is updated - document_index.update(update_requests=[update]) + # updating last_modified triggers sync + # TODO: Should this submit to the queue directly so that the UI can update? + db_doc.last_modified = datetime.now(timezone.utc) db_session.add(retrieval_feedback) db_session.commit() diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index 0932d500bbd..32e20d065c0 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -181,6 +181,45 @@ def get_last_attempt( return db_session.execute(stmt).scalars().first() +def get_latest_index_attempts_by_status( + secondary_index: bool, + db_session: Session, + status: IndexingStatus, +) -> Sequence[IndexAttempt]: + """ + Retrieves the most recent index attempt with the specified status for each connector_credential_pair. + Filters attempts based on the secondary_index flag to get either future or present index attempts. + Returns a sequence of IndexAttempt objects, one for each unique connector_credential_pair. 
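# get_latest_index_attempts_by_status, whose docstring ends just above, uses the classic
# latest-row-per-group pattern: MAX(id) per connector_credential_pair_id in a subquery,
# then a join back for the full rows. A stripped-down sketch with the status filtering
# omitted:
from sqlalchemy import func, select

from danswer.db.models import IndexAttempt

latest_per_cc_pair = (
    select(
        IndexAttempt.connector_credential_pair_id,
        func.max(IndexAttempt.id).label("max_id"),
    )
    .group_by(IndexAttempt.connector_credential_pair_id)
    .subquery()
)

latest_attempts_stmt = select(IndexAttempt).join(
    latest_per_cc_pair, IndexAttempt.id == latest_per_cc_pair.c.max_id
)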
+ """ + latest_failed_attempts = ( + select( + IndexAttempt.connector_credential_pair_id, + func.max(IndexAttempt.id).label("max_failed_id"), + ) + .join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id) + .where( + SearchSettings.status + == ( + IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT + ), + IndexAttempt.status == status, + ) + .group_by(IndexAttempt.connector_credential_pair_id) + .subquery() + ) + + stmt = select(IndexAttempt).join( + latest_failed_attempts, + ( + IndexAttempt.connector_credential_pair_id + == latest_failed_attempts.c.connector_credential_pair_id + ) + & (IndexAttempt.id == latest_failed_attempts.c.max_failed_id), + ) + + return db_session.execute(stmt).scalars().all() + + def get_latest_index_attempts( secondary_index: bool, db_session: Session, @@ -211,12 +250,12 @@ def get_latest_index_attempts( return db_session.execute(stmt).scalars().all() -def get_index_attempts_for_connector( +def count_index_attempts_for_connector( db_session: Session, connector_id: int, only_current: bool = True, disinclude_finished: bool = False, -) -> Sequence[IndexAttempt]: +) -> int: stmt = ( select(IndexAttempt) .join(ConnectorCredentialPair) @@ -232,23 +271,60 @@ def get_index_attempts_for_connector( stmt = stmt.join(SearchSettings).where( SearchSettings.status == IndexModelStatus.PRESENT ) + # Count total items for pagination + count_stmt = stmt.with_only_columns(func.count()).order_by(None) + total_count = db_session.execute(count_stmt).scalar_one() + return total_count - stmt = stmt.order_by(IndexAttempt.time_created.desc()) - return db_session.execute(stmt).scalars().all() +def get_paginated_index_attempts_for_cc_pair_id( + db_session: Session, + connector_id: int, + page: int, + page_size: int, + only_current: bool = True, + disinclude_finished: bool = False, +) -> list[IndexAttempt]: + stmt = ( + select(IndexAttempt) + .join(ConnectorCredentialPair) + .where(ConnectorCredentialPair.connector_id == connector_id) + ) + if disinclude_finished: + stmt = stmt.where( + IndexAttempt.status.in_( + [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] + ) + ) + if only_current: + stmt = stmt.join(SearchSettings).where( + SearchSettings.status == IndexModelStatus.PRESENT + ) + + stmt = stmt.order_by(IndexAttempt.time_started.desc()) + + # Apply pagination + stmt = stmt.offset((page - 1) * page_size).limit(page_size) -def get_latest_finished_index_attempt_for_cc_pair( + return list(db_session.execute(stmt).scalars().all()) + + +def get_latest_index_attempt_for_cc_pair_id( + db_session: Session, connector_credential_pair_id: int, secondary_index: bool, - db_session: Session, + only_finished: bool = True, ) -> IndexAttempt | None: - stmt = select(IndexAttempt).distinct() + stmt = select(IndexAttempt) stmt = stmt.where( IndexAttempt.connector_credential_pair_id == connector_credential_pair_id, - IndexAttempt.status.not_in( - [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] - ), ) + if only_finished: + stmt = stmt.where( + IndexAttempt.status.not_in( + [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] + ), + ) if secondary_index: stmt = stmt.join(SearchSettings).where( SearchSettings.status == IndexModelStatus.FUTURE @@ -295,14 +371,21 @@ def get_index_attempts_for_cc_pair( def delete_index_attempts( - connector_id: int, - credential_id: int, + cc_pair_id: int, db_session: Session, ) -> None: + # First, delete related entries in IndexAttemptErrors + stmt_errors = delete(IndexAttemptError).where( + 
IndexAttemptError.index_attempt_id.in_( + select(IndexAttempt.id).where( + IndexAttempt.connector_credential_pair_id == cc_pair_id + ) + ) + ) + db_session.execute(stmt_errors) + stmt = delete(IndexAttempt).where( - IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id, - ConnectorCredentialPair.connector_id == connector_id, - ConnectorCredentialPair.credential_id == credential_id, + IndexAttempt.connector_credential_pair_id == cc_pair_id, ) db_session.execute(stmt) diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index 152cb130573..a68beadc084 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -6,6 +6,7 @@ from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from danswer.db.models import LLMProvider as LLMProviderModel from danswer.db.models import LLMProvider__UserGroup +from danswer.db.models import SearchSettings from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.server.manage.embedding.models import CloudEmbeddingProvider @@ -50,6 +51,7 @@ def upsert_cloud_embedding_provider( setattr(existing_provider, key, value) else: new_provider = CloudEmbeddingProviderModel(**provider.model_dump()) + db_session.add(new_provider) existing_provider = new_provider db_session.commit() @@ -58,7 +60,7 @@ def upsert_cloud_embedding_provider( def upsert_llm_provider( - db_session: Session, llm_provider: LLMProviderUpsertRequest + llm_provider: LLMProviderUpsertRequest, db_session: Session ) -> FullLLMProvider: existing_llm_provider = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name) @@ -157,12 +159,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | def remove_embedding_provider( db_session: Session, provider_type: EmbeddingProvider ) -> None: + db_session.execute( + delete(SearchSettings).where(SearchSettings.provider_type == provider_type) + ) + + # Delete the embedding provider db_session.execute( delete(CloudEmbeddingProviderModel).where( CloudEmbeddingProviderModel.provider_type == provider_type ) ) + db_session.commit() + def remove_llm_provider(db_session: Session, provider_id: int) -> None: # Remove LLMProvider's dependent relationships @@ -178,7 +187,7 @@ def remove_llm_provider(db_session: Session, provider_id: int) -> None: db_session.commit() -def update_default_provider(db_session: Session, provider_id: int) -> None: +def update_default_provider(provider_id: int, db_session: Session) -> None: new_default = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.id == provider_id) ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3cdec323961..c0d24770704 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -61,7 +61,7 @@ class Base(DeclarativeBase): - pass + __abstract__ = True class EncryptedString(TypeDecorator): @@ -157,6 +157,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base): notifications: Mapped[list["Notification"]] = relationship( "Notification", back_populates="user" ) + # Whether the user has logged in via web. 
False if user has only used Danswer through Slack bot + has_web_login: Mapped[bool] = mapped_column(Boolean, default=True) class InputPrompt(Base): @@ -373,6 +375,9 @@ class ConnectorCredentialPair(Base): connector_id: Mapped[int] = mapped_column( ForeignKey("connector.id"), primary_key=True ) + + deletion_failure_message: Mapped[str | None] = mapped_column(String, nullable=True) + credential_id: Mapped[int] = mapped_column( ForeignKey("credential.id"), primary_key=True ) @@ -426,12 +431,27 @@ class Document(Base): semantic_id: Mapped[str] = mapped_column(String) # First Section's link link: Mapped[str | None] = mapped_column(String, nullable=True) + # The updated time is also used as a measure of the last successful state of the doc # pulled from the source (to help skip reindexing already updated docs in case of # connector retries) + # TODO: rename this column because it conflates the time of the source doc + # with the local last modified time of the doc and any associated metadata + # it should just be the server timestamp of the source doc doc_updated_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) + + # last time any vespa relevant row metadata or the doc changed. + # does not include last_synced + last_modified: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=False, index=True, default=func.now() + ) + + # last successful sync to vespa + last_synced: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True, index=True + ) # The following are not attached to User because the account/email may not be known # within Danswer # Something like the document creator @@ -448,7 +468,7 @@ class Document(Base): ) tags = relationship( "Tag", - secondary="document__tag", + secondary=Document__Tag.__table__, back_populates="documents", ) @@ -465,7 +485,7 @@ class Tag(Base): documents = relationship( "Document", - secondary="document__tag", + secondary=Document__Tag.__table__, back_populates="tags", ) @@ -576,6 +596,8 @@ class SearchSettings(Base): Enum(RerankerProvider, native_enum=False), nullable=True ) rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True) + rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True) + num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS) cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship( @@ -607,6 +629,10 @@ def __repr__(self) -> str: return f"" + @property + def api_url(self) -> str | None: + return self.cloud_provider.api_url if self.cloud_provider is not None else None + @property def api_key(self) -> str | None: return self.cloud_provider.api_key if self.cloud_provider is not None else None @@ -671,7 +697,11 @@ class IndexAttempt(Base): "SearchSettings", back_populates="index_attempts" ) - error_rows = relationship("IndexAttemptError", back_populates="index_attempt") + error_rows = relationship( + "IndexAttemptError", + back_populates="index_attempt", + cascade="all, delete-orphan", + ) __table_args__ = ( Index( @@ -806,7 +836,7 @@ class SearchDoc(Base): chat_messages = relationship( "ChatMessage", - secondary="chat_message__search_doc", + secondary=ChatMessage__SearchDoc.__table__, back_populates="search_docs", ) @@ -949,7 +979,7 @@ class ChatMessage(Base): ) search_docs: Mapped[list["SearchDoc"]] = relationship( "SearchDoc", - secondary="chat_message__search_doc", + secondary=ChatMessage__SearchDoc.__table__, back_populates="chat_messages", ) # NOTE: Should 
always be attached to the `assistant` message. @@ -1085,6 +1115,7 @@ class CloudEmbeddingProvider(Base): provider_type: Mapped[EmbeddingProvider] = mapped_column( Enum(EmbeddingProvider), primary_key=True ) + api_url: Mapped[str | None] = mapped_column(String, nullable=True) api_key: Mapped[str | None] = mapped_column(EncryptedString()) search_settings: Mapped[list["SearchSettings"]] = relationship( "SearchSettings", @@ -1400,7 +1431,7 @@ class TaskQueueState(Base): __tablename__ = "task_queue_jobs" id: Mapped[int] = mapped_column(primary_key=True) - # Celery task id + # Celery task id. currently only for readability/diagnostics task_id: Mapped[str] = mapped_column(String) # For any job type, this would be the same task_name: Mapped[str] = mapped_column(String) diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py index 1d0c218e10a..bb869c471dc 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -1,3 +1,5 @@ +from sqlalchemy import and_ +from sqlalchemy import delete from sqlalchemy import select from sqlalchemy.orm import Session @@ -13,10 +15,12 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.db.llm import fetch_embedding_provider from danswer.db.models import CloudEmbeddingProvider +from danswer.db.models import IndexAttempt from danswer.db.models import IndexModelStatus from danswer.db.models import SearchSettings from danswer.indexing.models import IndexingSetting from danswer.natural_language_processing.search_nlp_models import clean_model_name +from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder from danswer.search.models import SavedSearchSettings from danswer.server.manage.embedding.models import ( CloudEmbeddingProvider as ServerCloudEmbeddingProvider, @@ -89,6 +93,30 @@ def get_current_db_embedding_provider( return current_embedding_provider +def delete_search_settings(db_session: Session, search_settings_id: int) -> None: + current_settings = get_current_search_settings(db_session) + + if current_settings.id == search_settings_id: + raise ValueError("Cannot delete currently active search settings") + + # First, delete associated index attempts + index_attempts_query = delete(IndexAttempt).where( + IndexAttempt.search_settings_id == search_settings_id + ) + db_session.execute(index_attempts_query) + + # Then, delete the search settings + search_settings_query = delete(SearchSettings).where( + and_( + SearchSettings.id == search_settings_id, + SearchSettings.status != IndexModelStatus.PRESENT, + ) + ) + + db_session.execute(search_settings_query) + db_session.commit() + + def get_current_search_settings(db_session: Session) -> SearchSettings: query = ( select(SearchSettings) @@ -115,6 +143,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None: return latest_settings +def get_all_search_settings(db_session: Session) -> list[SearchSettings]: + query = select(SearchSettings).order_by(SearchSettings.id.desc()) + result = db_session.execute(query) + all_settings = result.scalars().all() + return list(all_settings) + + def get_multilingual_expansion(db_session: Session | None = None) -> list[str]: if db_session is None: with Session(get_sqlalchemy_engine()) as db_session: @@ -146,6 +181,14 @@ def update_current_search_settings( logger.warning("No current search settings found to update") return + # Whenever we update the current search settings, we should ensure that the local reranking model is warmed up. 
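# The condition that follows warms up the cross-encoder only when reranking runs locally
# (no cloud provider), a rerank model is configured, and the model actually changed.
# Restated as a standalone predicate for clarity (function name is hypothetical):
def should_warm_up_reranker(
    current_provider_type: str | None,
    current_rerank_model: str | None,
    new_rerank_model: str | None,
) -> bool:
    return (
        current_provider_type is None
        and new_rerank_model is not None
        and current_rerank_model != new_rerank_model
    )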
+ if ( + current_settings.provider_type is None + and search_settings.rerank_model_name is not None + and current_settings.rerank_model_name != search_settings.rerank_model_name + ): + warm_up_cross_encoder(search_settings.rerank_model_name) + update_search_settings(current_settings, search_settings, preserved_fields) db_session.commit() logger.info("Current search settings updated successfully") @@ -234,6 +277,7 @@ def get_old_default_embedding_model() -> IndexingSetting: passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""), index_name="danswer_chunk", multipass_indexing=False, + api_url=None, ) @@ -246,4 +290,5 @@ def get_new_default_embedding_model() -> IndexingSetting: passage_prefix=ASYM_PASSAGE_PREFIX, index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}", multipass_indexing=False, + api_url=None, ) diff --git a/backend/danswer/db/tasks.py b/backend/danswer/db/tasks.py index 23a7edc9882..a7aec90d260 100644 --- a/backend/danswer/db/tasks.py +++ b/backend/danswer/db/tasks.py @@ -44,12 +44,11 @@ def get_latest_task_by_type( def register_task( - task_id: str, task_name: str, db_session: Session, ) -> TaskQueueState: new_task = TaskQueueState( - task_id=task_id, task_name=task_name, status=TaskStatus.PENDING + task_id="", task_name=task_name, status=TaskStatus.PENDING ) db_session.add(new_task) diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py index d824ccfd921..61ba6e475fe 100644 --- a/backend/danswer/db/users.py +++ b/backend/danswer/db/users.py @@ -1,9 +1,11 @@ from collections.abc import Sequence from uuid import UUID +from fastapi_users.password import PasswordHelper from sqlalchemy import select from sqlalchemy.orm import Session +from danswer.auth.schemas import UserRole from danswer.db.models import User @@ -30,3 +32,22 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None: user = db_session.query(User).filter(User.id == user_id).first() # type: ignore return user + + +def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User: + user = get_user_by_email(email, db_session) + if user is not None: + return user + + fastapi_users_pw_helper = PasswordHelper() + password = fastapi_users_pw_helper.generate() + hashed_pass = fastapi_users_pw_helper.hash(password) + user = User( + email=email, + hashed_password=hashed_pass, + has_web_login=False, + role=UserRole.BASIC, + ) + db_session.add(user) + db_session.commit() + return user diff --git a/backend/danswer/document_index/vespa/app_config/services.xml b/backend/danswer/document_index/vespa/app_config/services.xml index 01f2c191ac6..03604d1070c 100644 --- a/backend/danswer/document_index/vespa/app_config/services.xml +++ b/backend/danswer/document_index/vespa/app_config/services.xml @@ -26,6 +26,17 @@ 0.75 + + + + + + SEARCH_THREAD_NUMBER + + + + + 3 750 @@ -33,4 +44,4 @@ 300 - + \ No newline at end of file diff --git a/backend/danswer/document_index/vespa/chunk_retrieval.py b/backend/danswer/document_index/vespa/chunk_retrieval.py index 6a7427630b8..e4b2ad83ce2 100644 --- a/backend/danswer/document_index/vespa/chunk_retrieval.py +++ b/backend/danswer/document_index/vespa/chunk_retrieval.py @@ -30,6 +30,7 @@ from danswer.document_index.vespa_constants import HIDDEN from danswer.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS from danswer.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE +from danswer.document_index.vespa_constants import MAX_OR_CONDITIONS from danswer.document_index.vespa_constants import METADATA from 
danswer.document_index.vespa_constants import METADATA_SUFFIX from danswer.document_index.vespa_constants import PRIMARY_OWNERS @@ -292,12 +293,11 @@ def query_vespa( if LOG_VESPA_TIMING_INFORMATION else {}, ) - - response = requests.post( - SEARCH_ENDPOINT, - json=params, - ) try: + response = requests.post( + SEARCH_ENDPOINT, + json=params, + ) response.raise_for_status() except requests.HTTPError as e: request_info = f"Headers: {response.request.headers}\nPayload: {params}" @@ -319,6 +319,12 @@ def query_vespa( logger.debug("Vespa timing info: %s", response_json.get("timing")) hits = response_json["root"].get("children", []) + if not hits: + logger.warning( + f"No hits found for YQL Query: {query_params.get('yql', 'No YQL Query')}" + ) + logger.debug(f"Vespa Response: {response.text}") + for hit in hits: if hit["fields"].get(CONTENT) is None: identifier = hit["fields"].get("documentid") or hit["id"] @@ -379,7 +385,7 @@ def batch_search_api_retrieval( capped_requests: list[VespaChunkRequest] = [] uncapped_requests: list[VespaChunkRequest] = [] chunk_count = 0 - for request in chunk_requests: + for req_ind, request in enumerate(chunk_requests, start=1): # All requests without a chunk range are uncapped # Uncapped requests are retrieved using the Visit API range = request.range @@ -387,9 +393,10 @@ def batch_search_api_retrieval( uncapped_requests.append(request) continue - # If adding the range to the chunk count is greater than the - # max query size, we need to perform a retrieval to avoid hitting the limit - if chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE: + if ( + chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE + or req_ind % MAX_OR_CONDITIONS == 0 + ): retrieved_chunks.extend( _get_chunks_via_batch_search( index_name=index_name, diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index d07da5b06bb..0153f372fd4 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -16,6 +16,7 @@ from danswer.configs.chat_configs import DOC_TIME_DECAY from danswer.configs.chat_configs import NUM_RETURNED_HITS from danswer.configs.chat_configs import TITLE_CONTENT_RATIO +from danswer.configs.chat_configs import VESPA_SEARCHER_THREADS from danswer.configs.constants import KV_REINDEX_KEY from danswer.document_index.interfaces import DocumentIndex from danswer.document_index.interfaces import DocumentInsertionRecord @@ -52,6 +53,7 @@ from danswer.document_index.vespa_constants import DOCUMENT_SETS from danswer.document_index.vespa_constants import HIDDEN from danswer.document_index.vespa_constants import NUM_THREADS +from danswer.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT from danswer.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT from danswer.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT from danswer.document_index.vespa_constants import VESPA_TIMEOUT @@ -134,6 +136,10 @@ def ensure_indices_exist( doc_lines = _create_document_xml_lines(schema_names) services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines) + services = services.replace( + SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS) + ) + kv_store = get_dynamic_config_store() needs_reindexing = False @@ -282,7 +288,7 @@ def _update_chunk( raise requests.HTTPError(failure_msg) from e def update(self, update_requests: list[UpdateRequest]) -> None: - logger.info(f"Updating {len(update_requests)} documents in Vespa") + logger.debug(f"Updating {len(update_requests)} 
documents in Vespa") # Handle Vespa character limitations # Mutating update_requests but it's not used later anyway diff --git a/backend/danswer/document_index/vespa/indexing_utils.py b/backend/danswer/document_index/vespa/indexing_utils.py index 1b16cfc4947..6b6ba8709d5 100644 --- a/backend/danswer/document_index/vespa/indexing_utils.py +++ b/backend/danswer/document_index/vespa/indexing_utils.py @@ -162,14 +162,16 @@ def _index_vespa_chunk( METADATA_SUFFIX: chunk.metadata_suffix_keyword, EMBEDDINGS: embeddings_name_vector_map, TITLE_EMBEDDING: chunk.title_embedding, - BOOST: chunk.boost, DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at), PRIMARY_OWNERS: get_experts_stores_representations(document.primary_owners), SECONDARY_OWNERS: get_experts_stores_representations(document.secondary_owners), # the only `set` vespa has is `weightedset`, so we have to give each # element an arbitrary weight + # rkuo: acl, docset and boost metadata are also updated through the metadata sync queue + # which only calls VespaIndex.update ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()}, DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets}, + BOOST: chunk.boost, } vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}" diff --git a/backend/danswer/document_index/vespa_constants.py b/backend/danswer/document_index/vespa_constants.py index 0b8949b4264..8409efe1dea 100644 --- a/backend/danswer/document_index/vespa_constants.py +++ b/backend/danswer/document_index/vespa_constants.py @@ -7,6 +7,7 @@ VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM" DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME" DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT" +SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER" DATE_REPLACEMENT = "DATE_REPLACEMENT" # config server @@ -25,6 +26,9 @@ 32 # since Vespa doesn't allow batching of inserts / updates, we use threads ) MAX_ID_SEARCH_QUERY_SIZE = 400 +# Suspect that adding too many "or" conditions will cause Vespa to timeout and return +# an empty list of hits (with no error status and coverage: 0 and degraded) +MAX_OR_CONDITIONS = 10 # up from 500ms for now, since we've seen quite a few timeouts # in the long term, we are looking to improve the performance of Vespa # so that we can bring this back to default diff --git a/backend/danswer/file_processing/extract_file_text.py b/backend/danswer/file_processing/extract_file_text.py index 7143b428714..36df08ac465 100644 --- a/backend/danswer/file_processing/extract_file_text.py +++ b/backend/danswer/file_processing/extract_file_text.py @@ -8,6 +8,7 @@ from email.parser import Parser as EmailParser from pathlib import Path from typing import Any +from typing import Dict from typing import IO import chardet @@ -178,6 +179,17 @@ def read_text_file( def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str: + """Extract text from a PDF file.""" + # Return only the extracted text from read_pdf_file + text, _ = read_pdf_file(file, pdf_pass) + return text + + +def read_pdf_file( + file: IO[Any], + pdf_pass: str | None = None, +) -> tuple[str, dict]: + metadata: Dict[str, Any] = {} try: pdf_reader = PdfReader(file) @@ -189,16 +201,33 @@ def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str: decrypt_success = pdf_reader.decrypt(pdf_pass) != 0 except Exception: logger.error("Unable to decrypt pdf") - else: - logger.warning("No Password available to to decrypt pdf") if not decrypt_success: # By user request, keep files that are unreadable 
just so they # can be discoverable by title. - return "" - - return TEXT_SECTION_SEPARATOR.join( - page.extract_text() for page in pdf_reader.pages + return "", metadata + else: + logger.warning("No Password available to to decrypt pdf") + + # Extract metadata from the PDF, removing leading '/' from keys if present + # This standardizes the metadata keys for consistency + metadata = {} + if pdf_reader.metadata is not None: + for key, value in pdf_reader.metadata.items(): + clean_key = key.lstrip("/") + if isinstance(value, str) and value.strip(): + metadata[clean_key] = value + + elif isinstance(value, list) and all( + isinstance(item, str) for item in value + ): + metadata[clean_key] = ", ".join(value) + + return ( + TEXT_SECTION_SEPARATOR.join( + page.extract_text() for page in pdf_reader.pages + ), + metadata, ) except PdfStreamError: logger.exception("PDF file is not a valid PDF") @@ -207,13 +236,47 @@ def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str: # File is still discoverable by title # but the contents are not included as they cannot be parsed - return "" + return "", metadata def docx_to_text(file: IO[Any]) -> str: + def is_simple_table(table: docx.table.Table) -> bool: + for row in table.rows: + # No omitted cells + if row.grid_cols_before > 0 or row.grid_cols_after > 0: + return False + + # No nested tables + if any(cell.tables for cell in row.cells): + return False + + return True + + def extract_cell_text(cell: docx.table._Cell) -> str: + cell_paragraphs = [para.text.strip() for para in cell.paragraphs] + return " ".join(p for p in cell_paragraphs if p) or "N/A" + + paragraphs = [] doc = docx.Document(file) - full_text = [para.text for para in doc.paragraphs] - return TEXT_SECTION_SEPARATOR.join(full_text) + for item in doc.iter_inner_content(): + if isinstance(item, docx.text.paragraph.Paragraph): + paragraphs.append(item.text) + + elif isinstance(item, docx.table.Table): + if not item.rows or not is_simple_table(item): + continue + + # Every row is a new line, joined with a single newline + table_content = "\n".join( + [ + ",\t".join(extract_cell_text(cell) for cell in row.cells) + for row in item.rows + ] + ) + paragraphs.append(table_content) + + # Docx already has good spacing between paragraphs + return "\n".join(paragraphs) def pptx_to_text(file: IO[Any]) -> str: diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py index 4b849f70d96..b71d20bbbb4 100644 --- a/backend/danswer/file_store/utils.py +++ b/backend/danswer/file_store/utils.py @@ -1,4 +1,6 @@ +from collections.abc import Callable from io import BytesIO +from typing import Any from typing import cast from uuid import uuid4 @@ -73,5 +75,7 @@ def save_file_from_url(url: str) -> str: def save_files_from_urls(urls: list[str]) -> list[str]: - funcs = [(save_file_from_url, (url,)) for url in urls] + funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [ + (save_file_from_url, (url,)) for url in urls + ] return run_functions_tuples_in_parallel(funcs) diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index f7d8f4e7400..d25a0659c62 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -32,6 +32,7 @@ def __init__( passage_prefix: str | None, provider_type: EmbeddingProvider | None, api_key: str | None, + api_url: str | None, ): self.model_name = model_name self.normalize = normalize @@ -39,6 +40,7 @@ def __init__( self.passage_prefix = passage_prefix self.provider_type = provider_type 
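# Hedged aside on the new api_url field threaded through here: for self-hosted or
# OpenAI-compatible embedding services, the model server can POST to that URL
# instead of a provider's fixed endpoint. The sketch below is illustrative only;
# the /embeddings path, payload shape, and model name are assumptions, not taken
# from this patch.
import requests


def embed_via_custom_endpoint(
    texts: list[str], api_url: str, api_key: str | None, model: str
) -> list[list[float]]:
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    response = requests.post(
        f"{api_url.rstrip('/')}/embeddings",
        json={"model": model, "input": texts},
        headers=headers,
        timeout=30,
    )
    response.raise_for_status()
    # OpenAI-style response body: {"data": [{"embedding": [...]}, ...]}
    return [item["embedding"] for item in response.json()["data"]]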
self.api_key = api_key + self.api_url = api_url self.embedding_model = EmbeddingModel( model_name=model_name, @@ -47,6 +49,7 @@ def __init__( normalize=normalize, api_key=api_key, provider_type=provider_type, + api_url=api_url, # The below are globally set, this flow always uses the indexing one server_host=INDEXING_MODEL_SERVER_HOST, server_port=INDEXING_MODEL_SERVER_PORT, @@ -70,9 +73,16 @@ def __init__( passage_prefix: str | None, provider_type: EmbeddingProvider | None = None, api_key: str | None = None, + api_url: str | None = None, ): super().__init__( - model_name, normalize, query_prefix, passage_prefix, provider_type, api_key + model_name, + normalize, + query_prefix, + passage_prefix, + provider_type, + api_key, + api_url, ) @log_function_time() @@ -156,7 +166,7 @@ def embed_chunks( title_embed_dict[title] = title_embedding new_embedded_chunk = IndexChunk( - **chunk.model_dump(), + **chunk.dict(), embeddings=ChunkEmbedding( full_embedding=chunk_embeddings[0], mini_chunk_embeddings=chunk_embeddings[1:], @@ -179,6 +189,7 @@ def from_db_search_settings( passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) @@ -202,4 +213,5 @@ def get_embedding_model_from_search_settings( passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index de62133fc09..51cd23e7431 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from danswer.access.access import get_access_for_documents +from danswer.access.models import DocumentAccess from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING from danswer.configs.app_configs import INDEXING_EXCEPTION_LIMIT from danswer.configs.constants import DEFAULT_BOOST @@ -17,7 +18,8 @@ from danswer.connectors.models import IndexAttemptMetadata from danswer.db.document import get_documents_by_ids from danswer.db.document import prepare_to_modify_documents -from danswer.db.document import update_docs_updated_at +from danswer.db.document import update_docs_last_modified__no_commit +from danswer.db.document import update_docs_updated_at__no_commit from danswer.db.document import upsert_documents_complete from danswer.db.document_set import fetch_document_sets_for_documents from danswer.db.index_attempt import create_index_attempt_error @@ -263,6 +265,8 @@ def index_doc_batch( Note that the documents should already be batched at this point so that it does not inflate the memory requirements""" + no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) + ctx = index_doc_batch_prepare( document_batch=document_batch, index_attempt_metadata=index_attempt_metadata, @@ -292,9 +296,6 @@ def index_doc_batch( # NOTE: don't need to acquire till here, since this is when the actual race condition # with Vespa can occur. with prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids): - # Attach the latest status from Postgres (source of truth for access) to each - # chunk. 
This access status will be attached to each chunk in the document index - # TODO: attach document sets to the chunk based on the status of Postgres as well document_id_to_access_info = get_access_for_documents( document_ids=updatable_ids, db_session=db_session ) @@ -304,10 +305,18 @@ def index_doc_batch( document_ids=updatable_ids, db_session=db_session ) } + + # we're concerned about race conditions where multiple simultaneous indexings might result + # in one set of metadata overwriting another one in vespa. + # we still write data here for immediate and most likely correct sync, but + # to resolve this, an update of the last modified field at the end of this loop + # always triggers a final metadata sync access_aware_chunks = [ DocMetadataAwareIndexChunk.from_index_chunk( index_chunk=chunk, - access=document_id_to_access_info[chunk.source_document.id], + access=document_id_to_access_info.get( + chunk.source_document.id, no_access + ), document_sets=set( document_id_to_document_set.get(chunk.source_document.id, []) ), @@ -333,18 +342,24 @@ def index_doc_batch( doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids ] - # Update the time of latest version of the doc successfully indexed - ids_to_new_updated_at = {} - for doc in successful_docs: - if doc.doc_updated_at is None: - continue - ids_to_new_updated_at[doc.id] = doc.doc_updated_at + last_modified_ids = [] + ids_to_new_updated_at = {} + for doc in successful_docs: + last_modified_ids.append(doc.id) + # doc_updated_at is the connector source's idea of when the doc was last modified + if doc.doc_updated_at is None: + continue + ids_to_new_updated_at[doc.id] = doc.doc_updated_at + + update_docs_updated_at__no_commit( + ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session + ) - update_docs_updated_at( - ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session - ) + update_docs_last_modified__no_commit( + document_ids=last_modified_ids, db_session=db_session + ) - db_session.commit() + db_session.commit() return len([r for r in insertion_records if r.already_existed is False]), len( access_aware_chunks diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index b23de0eb477..c789a2b351b 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -61,6 +61,8 @@ class IndexChunk(DocAwareChunk): title_embedding: Embedding | None +# TODO(rkuo): currently, this extra metadata sent during indexing is just for speed, +# but full consistency happens on background sync class DocMetadataAwareIndexChunk(IndexChunk): """An `IndexChunk` that contains all necessary metadata to be indexed. 
This includes the following: @@ -95,10 +97,12 @@ def from_index_chunk( class EmbeddingModelDetail(BaseModel): + id: int | None = None model_name: str normalize: bool query_prefix: str | None passage_prefix: str | None + api_url: str | None = None provider_type: EmbeddingProvider | None = None api_key: str | None = None @@ -111,12 +115,14 @@ def from_db_model( search_settings: "SearchSettings", ) -> "EmbeddingModelDetail": return cls( + id=search_settings.id, model_name=search_settings.model_name, normalize=search_settings.normalize, query_prefix=search_settings.query_prefix, passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 630c0c70229..a036421da2e 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -1,5 +1,6 @@ from collections.abc import Callable from collections.abc import Iterator +from typing import Any from typing import cast from uuid import uuid4 @@ -12,6 +13,8 @@ from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc +from danswer.chat.models import StreamStopInfo +from danswer.chat.models import StreamStopReason from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.file_store.utils import InMemoryChatFile from danswer.llm.answering.models import AnswerStyleConfig @@ -35,7 +38,7 @@ from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping from danswer.llm.answering.stream_processing.utils import map_document_id_order from danswer.llm.interfaces import LLM -from danswer.llm.utils import message_generator_to_string_generator +from danswer.llm.interfaces import ToolChoiceOptions from danswer.natural_language_processing.utils import get_tokenizer from danswer.tools.custom.custom_tool_prompt_builder import ( build_user_message_for_custom_tool_for_non_tool_calling_llm, @@ -49,7 +52,7 @@ from danswer.tools.internet_search.internet_search_tool import InternetSearchTool from danswer.tools.message import build_tool_message from danswer.tools.message import ToolCallSummary -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS +from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID from danswer.tools.search.search_tool import SearchResponseSummary @@ -190,7 +193,9 @@ def _update_prompt_builder_for_search_tool( def _raw_output_for_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]: + ) -> Iterator[ + str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult + ]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) tool_call_chunk: AIMessageChunk | None = None @@ -225,6 +230,7 @@ def _raw_output_for_explicit_tool_calling_llms( self.tools, self.force_use_tool ) ] + for message in self.llm.stream( prompt=prompt, tools=final_tool_definitions if final_tool_definitions else None, @@ -242,6 +248,13 @@ def _raw_output_for_explicit_tool_calling_llms( if self.is_cancelled: return yield cast(str, message.content) + if ( + message.additional_kwargs.get("usage_metadata", {}).get("stop") + == "length" + ): + yield StreamStopInfo( + 
stop_reason=StreamStopReason.CONTEXT_LENGTH + ) if not tool_call_chunk: return # no tool call needed @@ -298,21 +311,41 @@ def _raw_output_for_explicit_tool_calling_llms( yield tool_runner.tool_final_result() prompt = prompt_builder.build(tool_call_summary=tool_call_summary) - for token in message_generator_to_string_generator( - self.llm.stream( - prompt=prompt, - tools=[tool.tool_definition() for tool in self.tools], - ) - ): - if self.is_cancelled: - return - yield token + + yield from self._process_llm_stream( + prompt=prompt, + tools=[tool.tool_definition() for tool in self.tools], + ) return + # This method processes the LLM stream and yields the content or stop information + def _process_llm_stream( + self, + prompt: Any, + tools: list[dict] | None = None, + tool_choice: ToolChoiceOptions | None = None, + ) -> Iterator[str | StreamStopInfo]: + for message in self.llm.stream( + prompt=prompt, tools=tools, tool_choice=tool_choice + ): + if isinstance(message, AIMessageChunk): + if message.content: + if self.is_cancelled: + return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED) + yield cast(str, message.content) + + if ( + message.additional_kwargs.get("usage_metadata", {}).get("stop") + == "length" + ): + yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH) + def _raw_output_for_non_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]: + ) -> Iterator[ + str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult + ]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) chosen_tool_and_args: tuple[Tool, dict] | None = None @@ -387,13 +420,10 @@ def _raw_output_for_non_explicit_tool_calling_llms( ) ) prompt = prompt_builder.build() - for token in message_generator_to_string_generator( - self.llm.stream(prompt=prompt) - ): - if self.is_cancelled: - return - yield token - + yield from self._process_llm_stream( + prompt=prompt, + tools=None, + ) return tool, tool_args = chosen_tool_and_args @@ -403,7 +433,7 @@ def _raw_output_for_non_explicit_tool_calling_llms( if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}: final_context_documents = None for response in tool_runner.tool_responses(): - if response.id == FINAL_CONTEXT_DOCUMENTS: + if response.id == FINAL_CONTEXT_DOCUMENTS_ID: final_context_documents = cast(list[LlmDoc], response.response) yield response @@ -447,12 +477,8 @@ def _raw_output_for_non_explicit_tool_calling_llms( yield final prompt = prompt_builder.build() - for token in message_generator_to_string_generator( - self.llm.stream(prompt=prompt) - ): - if self.is_cancelled: - return - yield token + + yield from self._process_llm_stream(prompt=prompt, tools=None) @property def processed_streamed_output(self) -> AnswerStream: @@ -470,17 +496,15 @@ def processed_streamed_output(self) -> AnswerStream: ) def _process_stream( - stream: Iterator[ToolCallKickoff | ToolResponse | str], + stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo], ) -> AnswerStream: message = None # special things we need to keep track of for the SearchTool - search_results: list[LlmDoc] | None = ( - None # raw results that will be displayed to the user - ) - final_context_docs: list[LlmDoc] | None = ( - None # processed docs to feed into the LLM - ) + # raw results that will be displayed to the user + search_results: list[LlmDoc] | None = None + # processed docs to feed into the LLM + final_context_docs: list[LlmDoc] | None = None for message in stream: if 
isinstance(message, ToolCallKickoff) or isinstance( @@ -499,8 +523,9 @@ def _process_stream( SearchResponseSummary, message.response ).top_sections ] - elif message.id == FINAL_CONTEXT_DOCUMENTS: + elif message.id == FINAL_CONTEXT_DOCUMENTS_ID: final_context_docs = cast(list[LlmDoc], message.response) + yield message elif ( message.id == SEARCH_DOC_CONTENT_ID @@ -524,13 +549,22 @@ def _process_stream( answer_style_configs=self.answer_style_config, ) + stream_stop_info = None + def _stream() -> Iterator[str]: - if message: - yield cast(str, message) - yield from cast(Iterator[str], stream) + nonlocal stream_stop_info + yield cast(str, message) + for item in stream: + if isinstance(item, StreamStopInfo): + stream_stop_info = item + return + yield cast(str, item) yield from process_answer_stream_fn(_stream()) + if stream_stop_info: + yield stream_stop_info + processed_stream = [] for processed_packet in _process_stream(output_generator): processed_stream.append(processed_packet) diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index de80b6f6756..a72fc70a8ff 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -11,7 +11,6 @@ from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.utils.logger import setup_logger - logger = setup_logger() @@ -204,7 +203,9 @@ def extract_citations_from_stream( def build_citation_processor( context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping ) -> StreamProcessor: - def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + def stream_processor( + tokens: Iterator[str], + ) -> AnswerQuestionStreamReturn: yield from extract_citations_from_stream( tokens=tokens, context_docs=context_docs, diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/llm/answering/stream_processing/quotes_processing.py index 74f37b85264..501a56b5aa7 100644 --- a/backend/danswer/llm/answering/stream_processing/quotes_processing.py +++ b/backend/danswer/llm/answering/stream_processing/quotes_processing.py @@ -285,7 +285,9 @@ def process_model_tokens( def build_quotes_processor( context_docs: list[LlmDoc], is_json_prompt: bool ) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]: - def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + def stream_processor( + tokens: Iterator[str], + ) -> AnswerQuestionStreamReturn: yield from process_model_tokens( tokens=tokens, context_docs=context_docs, diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index 33b1cc24c81..08131f581a4 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -25,9 +25,6 @@ from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING -from danswer.configs.model_configs import GEN_AI_API_ENDPOINT -from danswer.configs.model_configs import GEN_AI_API_VERSION -from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.llm.interfaces import LLM from danswer.llm.interfaces import LLMConfig @@ -141,7 +138,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: def _convert_delta_to_message_chunk( - _dict: 
dict[str, Any], curr_msg: BaseMessage | None + _dict: dict[str, Any], + curr_msg: BaseMessage | None, + stop_reason: str | None = None, ) -> BaseMessageChunk: """Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk""" role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None) @@ -166,12 +165,23 @@ def _convert_delta_to_message_chunk( args=tool_call.function.arguments, index=0, # only support a single tool call atm ) + return AIMessageChunk( content=content, - additional_kwargs=additional_kwargs, tool_call_chunks=[tool_call_chunk], + additional_kwargs={ + "usage_metadata": {"stop": stop_reason}, + **additional_kwargs, + }, ) - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + + return AIMessageChunk( + content=content, + additional_kwargs={ + "usage_metadata": {"stop": stop_reason}, + **additional_kwargs, + }, + ) elif role == "system": return SystemMessageChunk(content=content) elif role == "function": @@ -192,10 +202,10 @@ def __init__( timeout: int, model_provider: str, model_name: str, + api_base: str | None = None, + api_version: str | None = None, max_output_tokens: int | None = None, - api_base: str | None = GEN_AI_API_ENDPOINT, - api_version: str | None = GEN_AI_API_VERSION, - custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE, + custom_llm_provider: str | None = None, temperature: float = GEN_AI_TEMPERATURE, custom_config: dict[str, str] | None = None, extra_headers: dict[str, str] | None = None, @@ -209,7 +219,7 @@ def __init__( self._api_version = api_version self._custom_llm_provider = custom_llm_provider - # This can be used to store the maximum output tkoens for this model. + # This can be used to store the maximum output tokens for this model. # self._max_output_tokens = ( # max_output_tokens # if max_output_tokens is not None @@ -352,10 +362,16 @@ def _stream_implementation( ) try: for part in response: - if len(part["choices"]) == 0: + if not part["choices"]: continue - delta = part["choices"][0]["delta"] - message_chunk = _convert_delta_to_message_chunk(delta, output) + + choice = part["choices"][0] + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], + output, + stop_reason=choice["finish_reason"], + ) + if output is None: output = message_chunk else: diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py index 967e014a903..4a5ba7857c3 100644 --- a/backend/danswer/llm/custom_llm.py +++ b/backend/danswer/llm/custom_llm.py @@ -7,7 +7,6 @@ from langchain_core.messages import BaseMessage from requests import Timeout -from danswer.configs.model_configs import GEN_AI_API_ENDPOINT from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS from danswer.llm.interfaces import LLM from danswer.llm.interfaces import ToolChoiceOptions @@ -37,7 +36,7 @@ def __init__( # Not used here but you probably want a model server that isn't completely open api_key: str | None, timeout: int, - endpoint: str | None = GEN_AI_API_ENDPOINT, + endpoint: str, max_output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS, ): if not endpoint: diff --git a/backend/danswer/llm/llm_initialization.py b/backend/danswer/llm/llm_initialization.py deleted file mode 100644 index fef17ca812d..00000000000 --- a/backend/danswer/llm/llm_initialization.py +++ /dev/null @@ -1,80 +0,0 @@ -from sqlalchemy.orm import Session - -from danswer.configs.app_configs import DISABLE_GENERATIVE_AI -from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION -from 
danswer.configs.model_configs import GEN_AI_API_ENDPOINT -from danswer.configs.model_configs import GEN_AI_API_KEY -from danswer.configs.model_configs import GEN_AI_API_VERSION -from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER -from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.db.llm import fetch_existing_llm_providers -from danswer.db.llm import update_default_provider -from danswer.db.llm import upsert_llm_provider -from danswer.llm.llm_provider_options import AZURE_PROVIDER_NAME -from danswer.llm.llm_provider_options import BEDROCK_PROVIDER_NAME -from danswer.llm.llm_provider_options import fetch_available_well_known_llms -from danswer.server.manage.llm.models import LLMProviderUpsertRequest -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - - -def load_llm_providers(db_session: Session) -> None: - existing_providers = fetch_existing_llm_providers(db_session) - if existing_providers: - return - - if not GEN_AI_API_KEY or DISABLE_GENERATIVE_AI: - return - - well_known_provider_name_to_provider = { - provider.name: provider - for provider in fetch_available_well_known_llms() - if provider.name != BEDROCK_PROVIDER_NAME - } - - if GEN_AI_MODEL_PROVIDER not in well_known_provider_name_to_provider: - logger.error(f"Cannot auto-transition LLM provider: {GEN_AI_MODEL_PROVIDER}") - return None - - # Azure provider requires custom model names, - # OpenAI / anthropic can just use the defaults - model_names = ( - [ - name - for name in [ - GEN_AI_MODEL_VERSION, - FAST_GEN_AI_MODEL_VERSION, - ] - if name - ] - if GEN_AI_MODEL_PROVIDER == AZURE_PROVIDER_NAME - else None - ) - - well_known_provider = well_known_provider_name_to_provider[GEN_AI_MODEL_PROVIDER] - llm_provider_request = LLMProviderUpsertRequest( - name=well_known_provider.display_name, - provider=GEN_AI_MODEL_PROVIDER, - api_key=GEN_AI_API_KEY, - api_base=GEN_AI_API_ENDPOINT, - api_version=GEN_AI_API_VERSION, - custom_config={}, - default_model_name=( - GEN_AI_MODEL_VERSION - or well_known_provider.default_model - or well_known_provider.llm_names[0] - ), - fast_default_model_name=( - FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model - ), - model_names=model_names, - is_public=True, - display_model_names=[], - ) - llm_provider = upsert_llm_provider(db_session, llm_provider_request) - update_default_provider(db_session, llm_provider.id) - logger.notice( - f"Migrated LLM provider from env variables for provider '{GEN_AI_MODEL_PROVIDER}'" - ) diff --git a/backend/danswer/llm/llm_provider_options.py b/backend/danswer/llm/llm_provider_options.py index 24feeb2f27c..1bcfdf7e506 100644 --- a/backend/danswer/llm/llm_provider_options.py +++ b/backend/danswer/llm/llm_provider_options.py @@ -95,8 +95,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]: api_version_required=False, custom_config_keys=[], llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME), - default_model="claude-3-opus-20240229", - default_fast_model="claude-3-sonnet-20240229", + default_model="claude-3-5-sonnet-20240620", + default_fast_model="claude-3-5-sonnet-20240620", ), WellKnownLLMProviderDescriptor( name=AZURE_PROVIDER_NAME, @@ -128,8 +128,8 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]: ), ], llm_names=fetch_models_for_provider(BEDROCK_PROVIDER_NAME), - default_model="anthropic.claude-3-sonnet-20240229-v1:0", - default_fast_model="anthropic.claude-3-haiku-20240307-v1:0", + 
default_model="anthropic.claude-3-5-sonnet-20240620-v1:0", + default_fast_model="anthropic.claude-3-5-sonnet-20240620-v1:0", ), ] diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 82617f3f05b..c367f0aa522 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -32,7 +32,6 @@ from danswer.configs.constants import MessageType from danswer.configs.model_configs import GEN_AI_MAX_TOKENS from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS -from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS from danswer.db.models import ChatMessage from danswer.file_store.models import ChatFileType @@ -331,7 +330,7 @@ def test_llm(llm: LLM) -> str | None: def get_llm_max_tokens( model_map: dict, model_name: str, - model_provider: str = GEN_AI_MODEL_PROVIDER, + model_provider: str, ) -> int: """Best effort attempt to get the max tokens for the LLM""" if GEN_AI_MAX_TOKENS: @@ -371,7 +370,7 @@ def get_llm_max_tokens( def get_llm_max_output_tokens( model_map: dict, model_name: str, - model_provider: str = GEN_AI_MODEL_PROVIDER, + model_provider: str, ) -> int: """Best effort attempt to get the max output tokens for the LLM""" try: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 6652e5d3c39..a00826f11c8 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -1,4 +1,5 @@ import time +import traceback from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import Any @@ -7,7 +8,9 @@ import uvicorn from fastapi import APIRouter from fastapi import FastAPI +from fastapi import HTTPException from fastapi import Request +from fastapi import status from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse @@ -36,6 +39,9 @@ from danswer.configs.constants import KV_REINDEX_KEY from danswer.configs.constants import KV_SEARCH_SETTINGS from danswer.configs.constants import POSTGRES_WEB_APP_NAME +from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION +from danswer.configs.model_configs import GEN_AI_API_KEY +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.db.connector import check_connectors_exist from danswer.db.connector import create_initial_default_connector from danswer.db.connector_credential_pair import associate_default_cc_pair @@ -48,6 +54,9 @@ from danswer.db.engine import warm_up_connections from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import expire_index_attempts +from danswer.db.llm import fetch_default_provider +from danswer.db.llm import update_default_provider +from danswer.db.llm import upsert_llm_provider from danswer.db.persona import delete_old_default_personas from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_secondary_search_settings @@ -60,7 +69,6 @@ from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.indexing.models import IndexingSetting -from danswer.llm.llm_initialization import load_llm_providers from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.natural_language_processing.search_nlp_models import 
warm_up_cross_encoder @@ -91,6 +99,7 @@ from danswer.server.manage.get_state import router as state_router from danswer.server.manage.llm.api import admin_router as llm_admin_router from danswer.server.manage.llm.api import basic_router as llm_router +from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.manage.search_settings import router as search_settings_router from danswer.server.manage.slack_bot import router as slack_bot_management_router from danswer.server.manage.standard_answer import router as standard_answer_router @@ -109,7 +118,9 @@ from danswer.tools.built_in_tools import auto_add_search_tool_to_personas from danswer.tools.built_in_tools import load_builtin_tools from danswer.tools.built_in_tools import refresh_built_in_tools_cache +from danswer.utils.gpu_utils import gpu_status_request from danswer.utils.logger import setup_logger +from danswer.utils.telemetry import get_or_generate_uuid from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation @@ -179,9 +190,6 @@ def setup_postgres(db_session: Session) -> None: logger.notice("Verifying default standard answer category exists.") create_initial_default_standard_answer_category(db_session) - logger.notice("Loading LLM providers from env variables") - load_llm_providers(db_session) - logger.notice("Loading default Prompts and Personas") delete_old_default_personas(db_session) load_chat_yamls() @@ -191,6 +199,58 @@ def setup_postgres(db_session: Session) -> None: refresh_built_in_tools_cache(db_session) auto_add_search_tool_to_personas(db_session) + if GEN_AI_API_KEY and fetch_default_provider(db_session) is None: + # Only for dev flows + logger.notice("Setting up default OpenAI LLM for dev.") + llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini" + fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini" + model_req = LLMProviderUpsertRequest( + name="DevEnvPresetOpenAI", + provider="openai", + api_key=GEN_AI_API_KEY, + api_base=None, + api_version=None, + custom_config=None, + default_model_name=llm_model, + fast_default_model_name=fast_model, + is_public=True, + groups=[], + display_model_names=[llm_model, fast_model], + model_names=[llm_model, fast_model], + ) + new_llm_provider = upsert_llm_provider( + llm_provider=model_req, db_session=db_session + ) + update_default_provider(provider_id=new_llm_provider.id, db_session=db_session) + + +def update_default_multipass_indexing(db_session: Session) -> None: + docs_exist = check_docs_exist(db_session) + connectors_exist = check_connectors_exist(db_session) + logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}") + + if not docs_exist and not connectors_exist: + logger.info( + "No existing docs or connectors found. Checking GPU availability for multipass indexing." + ) + gpu_available = gpu_status_request() + logger.info(f"GPU available: {gpu_available}") + + current_settings = get_current_search_settings(db_session) + + logger.notice(f"Updating multipass indexing setting to: {gpu_available}") + updated_settings = SavedSearchSettings.from_db_model(current_settings) + # Enable multipass indexing if GPU is available or if using a cloud provider + updated_settings.multipass_indexing = ( + gpu_available or current_settings.cloud_provider is not None + ) + update_current_search_settings(db_session, updated_settings) + + else: + logger.debug( + "Existing docs or connectors found. 
Skipping multipass indexing update." + ) + def translate_saved_search_settings(db_session: Session) -> None: kv_store = get_dynamic_config_store() @@ -260,21 +320,32 @@ def setup_vespa( document_index: DocumentIndex, index_setting: IndexingSetting, secondary_index_setting: IndexingSetting | None, -) -> None: +) -> bool: # Vespa startup is a bit slow, so give it a few seconds - wait_time = 5 - for _ in range(5): + WAIT_SECONDS = 5 + VESPA_ATTEMPTS = 5 + for x in range(VESPA_ATTEMPTS): try: + logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...") document_index.ensure_indices_exist( index_embedding_dim=index_setting.model_dim, secondary_index_embedding_dim=secondary_index_setting.model_dim if secondary_index_setting else None, ) - break + + logger.notice("Vespa setup complete.") + return True except Exception: - logger.notice(f"Waiting on Vespa, retrying in {wait_time} seconds...") - time.sleep(wait_time) + logger.notice( + f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds." + ) + time.sleep(WAIT_SECONDS) + + logger.error( + f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})" + ) + return False @asynccontextmanager @@ -297,6 +368,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: # fill up Postgres connection pools await warm_up_connections() + # We cache this at the beginning so there is no delay in the first telemetry + get_or_generate_uuid() + with Session(engine) as db_session: check_index_swap(db_session=db_session) search_settings = get_current_search_settings(db_session) @@ -329,8 +403,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.notice( f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}." ) - - if search_settings.rerank_model_name and not search_settings.provider_type: + if ( + search_settings.rerank_model_name + and not search_settings.provider_type + and not search_settings.rerank_provider_type + ): warm_up_cross_encoder(search_settings.rerank_model_name) logger.notice("Verifying query preprocessing (NLTK) data is downloaded") @@ -353,13 +430,18 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: if secondary_search_settings else None, ) - setup_vespa( + + success = setup_vespa( document_index, IndexingSetting.from_db_model(search_settings), IndexingSetting.from_db_model(secondary_search_settings) if secondary_search_settings else None, ) + if not success: + raise RuntimeError( + "Could not connect to Vespa within the specified timeout." 
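# Generic sketch of the bounded-retry pattern setup_vespa now follows: attempt a
# setup action a fixed number of times, report success as a bool, and let the
# caller decide to fail fast. Helper names below are illustrative, not from the
# patch.
import time
from collections.abc import Callable


def retry_until_true(
    action: Callable[[], None], attempts: int = 5, wait_seconds: int = 5
) -> bool:
    for attempt in range(1, attempts + 1):
        try:
            action()
            return True
        except Exception:
            if attempt < attempts:
                time.sleep(wait_seconds)
    return False


# The caller then fails fast rather than starting up without a usable index:
# if not retry_until_true(lambda: document_index.ensure_indices_exist(...)):
#     raise RuntimeError("Could not connect to Vespa within the specified timeout.")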
+ ) logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}") if search_settings.provider_type is None: @@ -371,15 +453,41 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: ), ) + # update multipass indexing setting based on GPU availability + update_default_multipass_indexing(db_session) + optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) yield +def log_http_error(_: Request, exc: Exception) -> JSONResponse: + status_code = getattr(exc, "status_code", 500) + if status_code >= 400: + error_msg = f"{str(exc)}\n" + error_msg += "".join(traceback.format_tb(exc.__traceback__)) + logger.error(error_msg) + + detail = exc.detail if isinstance(exc, HTTPException) else str(exc) + return JSONResponse( + status_code=status_code, + content={"detail": detail}, + ) + + def get_application() -> FastAPI: application = FastAPI( title="Danswer Backend", version=__version__, lifespan=lifespan ) + # Add the custom exception handler + application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error) + application.add_exception_handler(status.HTTP_401_UNAUTHORIZED, log_http_error) + application.add_exception_handler(status.HTTP_403_FORBIDDEN, log_http_error) + application.add_exception_handler(status.HTTP_404_NOT_FOUND, log_http_error) + application.add_exception_handler( + status.HTTP_500_INTERNAL_SERVER_ERROR, log_http_error + ) + include_router_with_global_prefix_prepended(application, chat_router) include_router_with_global_prefix_prepended(application, query_router) include_router_with_global_prefix_prepended(application, document_router) diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index b7835c4e906..6dcec724345 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -24,6 +24,8 @@ from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType from shared_configs.enums import RerankerProvider +from shared_configs.model_server_models import ConnectorClassificationRequest +from shared_configs.model_server_models import ConnectorClassificationResponse from shared_configs.model_server_models import Embedding from shared_configs.model_server_models import EmbedRequest from shared_configs.model_server_models import EmbedResponse @@ -90,6 +92,7 @@ def __init__( query_prefix: str | None, passage_prefix: str | None, api_key: str | None, + api_url: str | None, provider_type: EmbeddingProvider | None, retrim_content: bool = False, ) -> None: @@ -100,6 +103,7 @@ def __init__( self.normalize = normalize self.model_name = model_name self.retrim_content = retrim_content + self.api_url = api_url self.tokenizer = get_tokenizer( model_name=model_name, provider_type=provider_type ) @@ -157,6 +161,7 @@ def _batch_encode_texts( text_type=text_type, manual_query_prefix=self.query_prefix, manual_passage_prefix=self.passage_prefix, + api_url=self.api_url, ) response = self._make_model_server_request(embed_request) @@ -226,6 +231,7 @@ def from_db_model( passage_prefix=search_settings.passage_prefix, api_key=search_settings.api_key, provider_type=search_settings.provider_type, + api_url=search_settings.api_url, retrim_content=retrim_content, ) @@ -236,6 +242,7 @@ def __init__( model_name: str, provider_type: RerankerProvider | None, api_key: str | None, + api_url: str | None, model_server_host: str = MODEL_SERVER_HOST, 
model_server_port: int = MODEL_SERVER_PORT, ) -> None: @@ -244,6 +251,7 @@ def __init__( self.model_name = model_name self.provider_type = provider_type self.api_key = api_key + self.api_url = api_url def predict(self, query: str, passages: list[str]) -> list[float]: rerank_request = RerankRequest( @@ -252,6 +260,7 @@ def predict(self, query: str, passages: list[str]) -> list[float]: model_name=self.model_name, provider_type=self.provider_type, api_key=self.api_key, + api_url=self.api_url, ) response = requests.post( @@ -297,6 +306,37 @@ def predict( return response_model.is_keyword, response_model.keywords +class ConnectorClassificationModel: + def __init__( + self, + model_server_host: str = MODEL_SERVER_HOST, + model_server_port: int = MODEL_SERVER_PORT, + ): + model_server_url = build_model_server_url(model_server_host, model_server_port) + self.connector_classification_endpoint = ( + model_server_url + "/custom/connector-classification" + ) + + def predict( + self, + query: str, + available_connectors: list[str], + ) -> list[str]: + connector_classification_request = ConnectorClassificationRequest( + available_connectors=available_connectors, + query=query, + ) + response = requests.post( + self.connector_classification_endpoint, + json=connector_classification_request.dict(), + ) + response.raise_for_status() + + response_model = ConnectorClassificationResponse(**response.json()) + + return response_model.connectors + + def warm_up_retry( func: Callable[..., Any], tries: int = 20, @@ -312,8 +352,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) except Exception as e: exceptions.append(e) - logger.exception( - f"Attempt {attempt + 1} failed; retrying in {delay} seconds..." + logger.info( + f"Attempt {attempt + 1}/{tries} failed; retrying in {delay} seconds..." 
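# Hedged usage sketch for warm_up_retry (defined here): it simply re-invokes the
# wrapped callable until it stops raising, which is useful while the model server
# is still loading weights. The health-check function and URL below are
# hypothetical placeholders.
import requests


def ping_model_server(url: str) -> None:
    # Connection errors and 5xx responses raise, which triggers another retry.
    requests.get(url, timeout=5).raise_for_status()


# ping_with_retry = warm_up_retry(ping_model_server, tries=20)
# ping_with_retry("http://localhost:9000/api/health")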
) time.sleep(delay) raise Exception(f"All retries failed: {exceptions}") @@ -363,6 +403,7 @@ def warm_up_cross_encoder( reranking_model = RerankingModel( model_name=rerank_model_name, provider_type=None, + api_url=None, api_key=None, ) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index a5a0fe0dad5..3f83ad19551 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -371,7 +371,7 @@ def get_search_answer( elif isinstance(packet, QADocsResponse): qa_response.docs = packet elif isinstance(packet, LLMRelevanceFilterResponse): - qa_response.llm_chunks_indices = packet.relevant_chunk_indices + qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices elif isinstance(packet, DanswerQuotes): qa_response.quotes = packet elif isinstance(packet, CitationInfo): diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py index d7e81975630..fceb78de7aa 100644 --- a/backend/danswer/one_shot_answer/models.py +++ b/backend/danswer/one_shot_answer/models.py @@ -62,7 +62,7 @@ class OneShotQAResponse(BaseModel): quotes: DanswerQuotes | None = None citations: list[CitationInfo] | None = None docs: QADocsResponse | None = None - llm_chunks_indices: list[int] | None = None + llm_selected_doc_indices: list[int] | None = None error_msg: str | None = None answer_valid: bool = True # Reflexion result, default True if Reflexion not run chat_message_id: int | None = None diff --git a/backend/danswer/redis/redis_pool.py b/backend/danswer/redis/redis_pool.py new file mode 100644 index 00000000000..1ca2e07ecd3 --- /dev/null +++ b/backend/danswer/redis/redis_pool.py @@ -0,0 +1,65 @@ +import threading +from typing import Optional + +import redis +from redis.client import Redis +from redis.connection import ConnectionPool + +from danswer.configs.app_configs import REDIS_DB_NUMBER +from danswer.configs.app_configs import REDIS_HOST +from danswer.configs.app_configs import REDIS_PASSWORD +from danswer.configs.app_configs import REDIS_PORT +from danswer.configs.app_configs import REDIS_SSL +from danswer.configs.app_configs import REDIS_SSL_CA_CERTS +from danswer.configs.app_configs import REDIS_SSL_CERT_REQS + +REDIS_POOL_MAX_CONNECTIONS = 10 + + +class RedisPool: + _instance: Optional["RedisPool"] = None + _lock: threading.Lock = threading.Lock() + _pool: ConnectionPool + + def __new__(cls) -> "RedisPool": + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(RedisPool, cls).__new__(cls) + cls._instance._init_pool() + return cls._instance + + def _init_pool(self) -> None: + if REDIS_SSL: + # Examples: https://github.com/redis/redis-py/issues/780 + self._pool = redis.ConnectionPool( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_DB_NUMBER, + password=REDIS_PASSWORD, + max_connections=REDIS_POOL_MAX_CONNECTIONS, + connection_class=redis.SSLConnection, + ssl_ca_certs=REDIS_SSL_CA_CERTS, + ssl_cert_reqs=REDIS_SSL_CERT_REQS, + ) + else: + self._pool = redis.ConnectionPool( + host=REDIS_HOST, + port=REDIS_PORT, + db=REDIS_DB_NUMBER, + password=REDIS_PASSWORD, + max_connections=REDIS_POOL_MAX_CONNECTIONS, + ) + + def get_client(self) -> Redis: + return redis.Redis(connection_pool=self._pool) + + +# # Usage example +# redis_pool = RedisPool() +# redis_client = redis_pool.get_client() + +# # Example of setting and getting a value +# redis_client.set('key', 'value') +# value = redis_client.get('key') +# 
print(value.decode()) # Output: 'value' diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 15387e6c63e..678877812a2 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -26,6 +26,7 @@ class RerankingDetails(BaseModel): # If model is None (or num_rerank is 0), then reranking is turned off rerank_model_name: str | None + rerank_api_url: str | None rerank_provider_type: RerankerProvider | None rerank_api_key: str | None = None @@ -42,6 +43,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "RerankingDetails": rerank_provider_type=search_settings.rerank_provider_type, rerank_api_key=search_settings.rerank_api_key, num_rerank=search_settings.num_rerank, + rerank_api_url=search_settings.rerank_api_url, ) @@ -81,6 +83,7 @@ def from_db_model(cls, search_settings: SearchSettings) -> "SavedSearchSettings" num_rerank=search_settings.num_rerank, # Multilingual Expansion multilingual_expansion=search_settings.multilingual_expansion, + rerank_api_url=search_settings.rerank_api_url, ) diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py index ad3e19e149d..183c8729d67 100644 --- a/backend/danswer/search/pipeline.py +++ b/backend/danswer/search/pipeline.py @@ -209,7 +209,9 @@ def _get_sections(self) -> list[InferenceSection]: if inference_section is not None: expanded_inference_sections.append(inference_section) else: - logger.warning("Skipped creation of section, no chunks found") + logger.warning( + "Skipped creation of section for full docs, no chunks found" + ) self._retrieved_sections = expanded_inference_sections return expanded_inference_sections @@ -270,6 +272,11 @@ def _get_sections(self) -> list[InferenceSection]: (chunk.document_id, chunk.chunk_id): chunk for chunk in inference_chunks } + # In case of failed parallel calls to Vespa, at least we should have the initial retrieved chunks + doc_chunk_ind_to_chunk.update( + {(chunk.document_id, chunk.chunk_id): chunk for chunk in retrieved_chunks} + ) + # Build the surroundings for all of the initial retrieved chunks for chunk in retrieved_chunks: start_ind = max(0, chunk.chunk_id - above) @@ -360,10 +367,10 @@ def section_relevance(self) -> list[SectionRelevancePiece] | None: try: results = run_functions_in_parallel(function_calls=functions) self._section_relevance = list(results.values()) - except Exception: + except Exception as e: raise ValueError( - "An issue occured during the agentic evaluation proecss." - ) + "An issue occured during the agentic evaluation process." 
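# Brief aside on the `from e` being added to this raise: chaining records the
# original exception as __cause__, so the underlying failure surfaces in the
# traceback alongside the ValueError. Minimal, self-contained illustration
# (names are hypothetical):
def evaluate_or_raise() -> None:
    try:
        raise TimeoutError("backend call timed out")  # stand-in for the real failure
    except Exception as e:
        # Both tracebacks are reported; __cause__ points at the TimeoutError.
        raise ValueError(
            "An issue occurred during the agentic evaluation process."
        ) from e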
+ ) from e elif self.search_query.evaluation_type == LLMEvaluationType.BASIC: if DISABLE_LLM_DOC_RELEVANCE: diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py index 6a3d2dc2dcd..b4a1e48bd39 100644 --- a/backend/danswer/search/postprocessing/postprocessing.py +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -100,6 +100,7 @@ def semantic_reranking( model_name=rerank_settings.rerank_model_name, provider_type=rerank_settings.rerank_provider_type, api_key=rerank_settings.rerank_api_key, + api_url=rerank_settings.rerank_api_url, ) passages = [ @@ -253,8 +254,8 @@ def search_postprocessing( if not retrieved_sections: # Avoids trying to rerank an empty list which throws an error - yield [] - yield [] + yield cast(list[InferenceSection], []) + yield cast(list[SectionRelevancePiece], []) return rerank_task_id = None diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index 31582f90819..30347464ff8 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -3,7 +3,6 @@ import nltk # type:ignore from nltk.corpus import stopwords # type:ignore -from nltk.stem import WordNetLemmatizer # type:ignore from nltk.tokenize import word_tokenize # type:ignore from sqlalchemy.orm import Session @@ -40,7 +39,7 @@ def download_nltk_data() -> None: resources = { "stopwords": "corpora/stopwords", - "wordnet": "corpora/wordnet", + # "wordnet": "corpora/wordnet", # Not in use "punkt": "tokenizers/punkt", } @@ -58,15 +57,16 @@ def download_nltk_data() -> None: def lemmatize_text(keywords: list[str]) -> list[str]: - try: - query = " ".join(keywords) - lemmatizer = WordNetLemmatizer() - word_tokens = word_tokenize(query) - lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens] - combined_keywords = list(set(keywords + lemmatized_words)) - return combined_keywords - except Exception: - return keywords + raise NotImplementedError("Lemmatization should not be used currently") + # try: + # query = " ".join(keywords) + # lemmatizer = WordNetLemmatizer() + # word_tokens = word_tokenize(query) + # lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens] + # combined_keywords = list(set(keywords + lemmatized_words)) + # return combined_keywords + # except Exception: + # return keywords def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]: diff --git a/backend/danswer/secondary_llm_flows/agentic_evaluation.py b/backend/danswer/secondary_llm_flows/agentic_evaluation.py index 3de9db00be6..03121e3cf1d 100644 --- a/backend/danswer/secondary_llm_flows/agentic_evaluation.py +++ b/backend/danswer/secondary_llm_flows/agentic_evaluation.py @@ -58,25 +58,30 @@ def _get_metadata_str(metadata: dict[str, str | list[str]]) -> str: center_metadata=center_metadata_str, ) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - model_output = message_to_string(llm.invoke(filled_llm_prompt)) + try: + model_output = message_to_string(llm.invoke(filled_llm_prompt)) - # Search for the "Useful Analysis" section in the model output - # This regex looks for "2. Useful Analysis" (case-insensitive) followed by an optional colon, - # then any text up to "3. 
Final Relevance" - # The (?i) flag makes it case-insensitive, and re.DOTALL allows the dot to match newlines - # If no match is found, the entire model output is used as the analysis - analysis_match = re.search( - r"(?i)2\.\s*useful analysis:?\s*(.+?)\n\n3\.\s*final relevance", - model_output, - re.DOTALL, - ) - analysis = analysis_match.group(1).strip() if analysis_match else model_output + # Search for the "Useful Analysis" section in the model output + # This regex looks for "2. Useful Analysis" (case-insensitive) followed by an optional colon, + # then any text up to "3. Final Relevance" + # The (?i) flag makes it case-insensitive, and re.DOTALL allows the dot to match newlines + # If no match is found, the entire model output is used as the analysis + analysis_match = re.search( + r"(?i)2\.\s*useful analysis:?\s*(.+?)\n\n3\.\s*final relevance", + model_output, + re.DOTALL, + ) + analysis = analysis_match.group(1).strip() if analysis_match else model_output - # Get the last non-empty line - last_line = next( - (line for line in reversed(model_output.split("\n")) if line.strip()), "" - ) - relevant = last_line.strip().lower().startswith("true") + # Get the last non-empty line + last_line = next( + (line for line in reversed(model_output.split("\n")) if line.strip()), "" + ) + relevant = last_line.strip().lower().startswith("true") + except Exception as e: + logger.exception(f"An issue occurred during the agentic evaluation process. {e}") + relevant = False + analysis = "" return SectionRelevancePiece( document_id=document_id, diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py index 802a14f42fa..f58a91016e0 100644 --- a/backend/danswer/secondary_llm_flows/source_filter.py +++ b/backend/danswer/secondary_llm_flows/source_filter.py @@ -3,12 +3,16 @@ from sqlalchemy.orm import Session +from danswer.configs.chat_configs import ENABLE_CONNECTOR_CLASSIFIER from danswer.configs.constants import DocumentSource from danswer.db.connector import fetch_unique_document_sources from danswer.db.engine import get_sqlalchemy_engine from danswer.llm.interfaces import LLM from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.llm.utils import message_to_string +from danswer.natural_language_processing.search_nlp_models import ( + ConnectorClassificationModel, +) from danswer.prompts.constants import SOURCES_KEY from danswer.prompts.filter_extration import FILE_SOURCE_WARNING from danswer.prompts.filter_extration import SOURCE_FILTER_PROMPT @@ -42,11 +46,38 @@ def _sample_document_sources( return random.sample(valid_sources, num_sample) +def _sample_documents_using_custom_connector_classifier( + query: str, + valid_sources: list[DocumentSource], +) -> list[DocumentSource] | None: + query_joined = "".join(ch for ch in query.lower() if ch.isalnum()) + available_connectors = list( + filter( + lambda conn: conn.lower() in query_joined, + [item.value for item in valid_sources], + ) + ) + + if not available_connectors: + return None + + connectors = ConnectorClassificationModel().predict(query, available_connectors) + + return strings_to_document_sources(connectors) if connectors else None + + def extract_source_filter( query: str, llm: LLM, db_session: Session ) -> list[DocumentSource] | None: """Returns a list of valid sources for search or None if no specific sources were detected""" + valid_sources = fetch_unique_document_sources(db_session) + if not valid_sources: + return None + + if ENABLE_CONNECTOR_CLASSIFIER:
+ return _sample_documents_using_custom_connector_classifier(query, valid_sources) + def _get_source_filter_messages( query: str, valid_sources: list[DocumentSource], @@ -146,10 +177,6 @@ def _extract_source_filters_from_llm_out( logger.warning("LLM failed to provide a valid Source Filter output") return None - valid_sources = fetch_unique_document_sources(db_session) - if not valid_sources: - return None - messages = _get_source_filter_messages(query=query, valid_sources=valid_sources) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) model_output = message_to_string(llm.invoke(filled_llm_prompt)) diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 69ae9916348..97ed3a82812 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -1,7 +1,9 @@ +import math + from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException -from pydantic import BaseModel +from fastapi import Query from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -19,20 +21,56 @@ from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair from danswer.db.index_attempt import cancel_indexing_attempts_past_model -from danswer.db.index_attempt import get_index_attempts_for_connector +from danswer.db.index_attempt import count_index_attempts_for_connector +from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id +from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id from danswer.db.models import User -from danswer.db.models import UserRole from danswer.server.documents.models import CCPairFullInfo +from danswer.server.documents.models import CCStatusUpdateRequest from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorCredentialPairMetadata +from danswer.server.documents.models import PaginatedIndexAttempts from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import validate_user_creation_permissions logger = setup_logger() router = APIRouter(prefix="/manage") +@router.get("/admin/cc-pair/{cc_pair_id}/index-attempts") +def get_cc_pair_index_attempts( + cc_pair_id: int, + page: int = Query(1, ge=1), + page_size: int = Query(10, ge=1, le=1000), + user: User | None = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), +) -> PaginatedIndexAttempts: + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id, db_session, user, get_editable=False + ) + if not cc_pair: + raise HTTPException( + status_code=400, detail="CC Pair not found for current user permissions" + ) + total_count = count_index_attempts_for_connector( + db_session=db_session, + connector_id=cc_pair.connector_id, + ) + index_attempts = get_paginated_index_attempts_for_cc_pair_id( + db_session=db_session, + connector_id=cc_pair.connector_id, + page=page, + page_size=page_size, + ) + return PaginatedIndexAttempts.from_models( + index_attempt_models=index_attempts, + page=page, + total_pages=math.ceil(total_count / page_size), + ) + + @router.get("/admin/cc-pair/{cc_pair_id}") def get_cc_pair_full_info( cc_pair_id: int, @@ -56,11 +94,6 @@ def get_cc_pair_full_info( credential_id=cc_pair.credential_id, ) - index_attempts = get_index_attempts_for_connector( - db_session, - 
cc_pair.connector_id, - ) - document_count_info_list = list( get_document_cnts_for_cc_pairs( db_session=db_session, @@ -71,9 +104,20 @@ def get_cc_pair_full_info( document_count_info_list[0][-1] if document_count_info_list else 0 ) + latest_attempt = get_latest_index_attempt_for_cc_pair_id( + db_session=db_session, + connector_credential_pair_id=cc_pair.id, + secondary_index=False, + only_finished=False, + ) + return CCPairFullInfo.from_models( cc_pair_model=cc_pair, - index_attempt_models=list(index_attempts), + number_of_index_attempts=count_index_attempts_for_connector( + db_session=db_session, + connector_id=cc_pair.connector_id, + ), + last_index_attempt=latest_attempt, latest_deletion_attempt=get_deletion_attempt_snapshot( connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, @@ -84,10 +128,6 @@ def get_cc_pair_full_info( ) -class CCStatusUpdateRequest(BaseModel): - status: ConnectorCredentialPairStatus - - @router.put("/admin/cc-pair/{cc_pair_id}/status") def update_cc_pair_status( cc_pair_id: int, @@ -157,11 +197,12 @@ def associate_credential_to_connector( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: - if user and user.role != UserRole.ADMIN and metadata.is_public: - raise HTTPException( - status_code=400, - detail="Public connections cannot be created by non-admin users", - ) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=metadata.groups, + object_is_public=metadata.is_public, + ) try: response = add_credential_to_connector( @@ -170,7 +211,7 @@ def associate_credential_to_connector( connector_id=connector_id, credential_id=credential_id, cc_pair_name=metadata.name, - is_public=metadata.is_public or True, + is_public=True if metadata.is_public is None else metadata.is_public, groups=metadata.groups, ) diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 8d6b0ffc773..129a901ab0d 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -66,8 +66,10 @@ from danswer.db.engine import get_session from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempts_for_cc_pair -from danswer.db.index_attempt import get_latest_finished_index_attempt_for_cc_pair +from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from danswer.db.index_attempt import get_latest_index_attempts +from danswer.db.index_attempt import get_latest_index_attempts_by_status +from danswer.db.models import IndexingStatus from danswer.db.models import User from danswer.db.models import UserRole from danswer.db.search_settings import get_current_search_settings @@ -75,13 +77,13 @@ from danswer.file_store.file_store import get_default_file_store from danswer.server.documents.models import AuthStatus from danswer.server.documents.models import AuthUrl -from danswer.server.documents.models import ConnectorBase from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorIndexingStatus from danswer.server.documents.models import ConnectorSnapshot from danswer.server.documents.models import ConnectorUpdateRequest from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import CredentialSnapshot +from danswer.server.documents.models import FailedConnectorIndexingStatus from 
danswer.server.documents.models import FileUploadResponse from danswer.server.documents.models import GDriveCallback from danswer.server.documents.models import GmailCallback @@ -93,6 +95,7 @@ from danswer.server.documents.models import RunConnectorRequest from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import validate_user_creation_permissions logger = setup_logger() @@ -376,6 +379,95 @@ def upload_files( return FileUploadResponse(file_paths=deduped_file_paths) +# Retrieves most recent failure cases for connectors that are currently failing +@router.get("/admin/connector/failed-indexing-status") +def get_currently_failed_indexing_status( + secondary_index: bool = False, + user: User = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), + get_editable: bool = Query( + False, description="If true, return editable document sets" + ), +) -> list[FailedConnectorIndexingStatus]: + # Get the latest failed indexing attempts + latest_failed_indexing_attempts = get_latest_index_attempts_by_status( + secondary_index=secondary_index, + db_session=db_session, + status=IndexingStatus.FAILED, + ) + + # Get the latest successful indexing attempts + latest_successful_indexing_attempts = get_latest_index_attempts_by_status( + secondary_index=secondary_index, + db_session=db_session, + status=IndexingStatus.SUCCESS, + ) + + # Get all connector credential pairs + cc_pairs = get_connector_credential_pairs( + db_session=db_session, + user=user, + get_editable=get_editable, + ) + + # Filter out failed attempts that have a more recent successful attempt + filtered_failed_attempts = [ + failed_attempt + for failed_attempt in latest_failed_indexing_attempts + if not any( + success_attempt.connector_credential_pair_id + == failed_attempt.connector_credential_pair_id + and success_attempt.time_updated > failed_attempt.time_updated + for success_attempt in latest_successful_indexing_attempts + ) + ] + + # Filter cc_pairs to include only those with failed attempts + cc_pairs = [ + cc_pair + for cc_pair in cc_pairs + if any( + attempt.connector_credential_pair == cc_pair + for attempt in filtered_failed_attempts + ) + ] + + # Create a mapping of cc_pair_id to its latest failed index attempt + cc_pair_to_latest_index_attempt = { + attempt.connector_credential_pair_id: attempt + for attempt in filtered_failed_attempts + } + + indexing_statuses = [] + + for cc_pair in cc_pairs: + # Skip DefaultCCPair + if cc_pair.name == "DefaultCCPair": + continue + + latest_index_attempt = cc_pair_to_latest_index_attempt.get(cc_pair.id) + + indexing_statuses.append( + FailedConnectorIndexingStatus( + cc_pair_id=cc_pair.id, + name=cc_pair.name, + error_msg=( + latest_index_attempt.error_msg if latest_index_attempt else None + ), + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + is_deletable=check_deletion_attempt_is_allowed( + connector_credential_pair=cc_pair, + db_session=db_session, + allow_scheduled=True, + ) + is None, + ) + ) + + return indexing_statuses + + @router.get("/admin/connector/indexing-status") def get_connector_indexing_status( secondary_index: bool = False, @@ -387,7 +479,12 @@ def get_connector_indexing_status( ) -> list[ConnectorIndexingStatus]: indexing_statuses: list[ConnectorIndexingStatus] = [] - # TODO: make this one query + # NOTE: If the connector is deleting behind the scenes, + # accessing cc_pairs can be inconsistent and members like + # connector or credential may be 
None. + # Additional checks are done to make sure the connector and credential still exists. + # TODO: make this one query ... possibly eager load or wrap in a read transaction + # to avoid the complexity of trying to error check throughout the function cc_pairs = get_connector_credential_pairs( db_session=db_session, user=user, @@ -440,14 +537,19 @@ def get_connector_indexing_status( connector = cc_pair.connector credential = cc_pair.credential + if not connector or not credential: + # This may happen if background deletion is happening + continue + latest_index_attempt = cc_pair_to_latest_index_attempt.get( (connector.id, credential.id) ) - latest_finished_attempt = get_latest_finished_index_attempt_for_cc_pair( + latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id( + db_session=db_session, connector_credential_pair_id=cc_pair.id, secondary_index=secondary_index, - db_session=db_session, + only_finished=True, ) indexing_statuses.append( @@ -514,35 +616,6 @@ def _validate_connector_allowed(source: DocumentSource) -> None: ) -def _check_connector_permissions( - connector_data: ConnectorUpdateRequest, user: User | None -) -> ConnectorBase: - """ - This is not a proper permission check, but this should prevent curators creating bad situations - until a long-term solution is implemented (Replacing CC pairs/Connectors with Connections) - """ - if user and user.role != UserRole.ADMIN: - if connector_data.is_public: - raise HTTPException( - status_code=400, - detail="Public connectors can only be created by admins", - ) - if not connector_data.groups: - raise HTTPException( - status_code=400, - detail="Connectors created by curators must have groups", - ) - return ConnectorBase( - name=connector_data.name, - source=connector_data.source, - input_type=connector_data.input_type, - connector_specific_config=connector_data.connector_specific_config, - refresh_freq=connector_data.refresh_freq, - prune_freq=connector_data.prune_freq, - indexing_start=connector_data.indexing_start, - ) - - @router.post("/admin/connector") def create_connector_from_model( connector_data: ConnectorUpdateRequest, @@ -551,12 +624,19 @@ def create_connector_from_model( ) -> ObjectCreationIdResponse: try: _validate_connector_allowed(connector_data.source) - connector_base = _check_connector_permissions(connector_data, user) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=connector_data.groups, + object_is_public=connector_data.is_public, + ) + connector_base = connector_data.to_connector_base() return create_connector( db_session=db_session, connector_data=connector_base, ) except ValueError as e: + logger.error(f"Error creating connector: {e}") raise HTTPException(status_code=400, detail=str(e)) @@ -607,12 +687,18 @@ def create_connector_with_mock_credential( def update_connector_from_model( connector_id: int, connector_data: ConnectorUpdateRequest, - user: User = Depends(current_admin_user), + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ConnectorSnapshot | StatusResponse[int]: try: _validate_connector_allowed(connector_data.source) - connector_base = _check_connector_permissions(connector_data, user) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=connector_data.groups, + object_is_public=connector_data.is_public, + ) + connector_base = connector_data.to_connector_base() except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -642,7 
+728,7 @@ def update_connector_from_model( @router.delete("/admin/connector/{connector_id}", response_model=StatusResponse[int]) def delete_connector_by_id( connector_id: int, - _: User = Depends(current_admin_user), + _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: try: diff --git a/backend/danswer/server/documents/credential.py b/backend/danswer/server/documents/credential.py index ba30b65f2f9..3d965481bf5 100644 --- a/backend/danswer/server/documents/credential.py +++ b/backend/danswer/server/documents/credential.py @@ -7,7 +7,6 @@ from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user -from danswer.auth.users import validate_curator_request from danswer.db.credentials import alter_credential from danswer.db.credentials import create_credential from danswer.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE @@ -20,7 +19,6 @@ from danswer.db.engine import get_session from danswer.db.models import DocumentSource from danswer.db.models import User -from danswer.db.models import UserRole from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import CredentialDataUpdateRequest from danswer.server.documents.models import CredentialSnapshot @@ -28,6 +26,7 @@ from danswer.server.documents.models import ObjectCreationIdResponse from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import validate_user_creation_permissions logger = setup_logger() @@ -80,7 +79,7 @@ def get_cc_source_full_info( ] -@router.get("/credentials/{id}") +@router.get("/credential/{id}") def list_credentials_by_id( user: User | None = Depends(current_user), db_session: Session = Depends(get_session), @@ -105,7 +104,7 @@ def delete_credential_by_id_admin( ) -@router.put("/admin/credentials/swap") +@router.put("/admin/credential/swap") def swap_credentials_for_connector( credential_swap_req: CredentialSwapRequest, user: User | None = Depends(current_user), @@ -131,14 +130,12 @@ def create_credential_from_model( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: - if ( - user - and user.role != UserRole.ADMIN - and not _ignore_credential_permissions(credential_info.source) - ): - validate_curator_request( - groups=credential_info.groups, - is_public=credential_info.curator_public, + if not _ignore_credential_permissions(credential_info.source): + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=credential_info.groups, + object_is_public=credential_info.curator_public, ) credential = create_credential(credential_info, user, db_session) @@ -179,7 +176,7 @@ def get_credential_by_id( return CredentialSnapshot.from_credential_db_model(credential) -@router.put("/admin/credentials/{credential_id}") +@router.put("/admin/credential/{credential_id}") def update_credential_data( credential_id: int, credential_update: CredentialDataUpdateRequest, diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index ba011afc196..517813892b8 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -4,6 +4,7 @@ from pydantic import BaseModel from pydantic import Field +from pydantic import model_validator from danswer.configs.app_configs import 
MASK_CREDENTIAL_PREFIX from danswer.configs.constants import DocumentSource @@ -48,9 +49,12 @@ class ConnectorBase(BaseModel): class ConnectorUpdateRequest(ConnectorBase): - is_public: bool | None = None + is_public: bool = True groups: list[int] = Field(default_factory=list) + def to_connector_base(self) -> ConnectorBase: + return ConnectorBase(**self.model_dump(exclude={"is_public", "groups"})) + class ConnectorSnapshot(ConnectorBase): id: int @@ -103,11 +107,6 @@ class CredentialSnapshot(CredentialBase): user_id: UUID | None time_created: datetime time_updated: datetime - name: str | None - source: DocumentSource - credential_json: dict[str, Any] - admin_public: bool - curator_public: bool @classmethod def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot": @@ -187,6 +186,28 @@ def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError": ) +class PaginatedIndexAttempts(BaseModel): + index_attempts: list[IndexAttemptSnapshot] + page: int + total_pages: int + + @classmethod + def from_models( + cls, + index_attempt_models: list[IndexAttempt], + page: int, + total_pages: int, + ) -> "PaginatedIndexAttempts": + return cls( + index_attempts=[ + IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt_model) + for index_attempt_model in index_attempt_models + ], + page=page, + total_pages=total_pages, + ) + + class CCPairFullInfo(BaseModel): id: int name: str @@ -194,20 +215,38 @@ class CCPairFullInfo(BaseModel): num_docs_indexed: int connector: ConnectorSnapshot credential: CredentialSnapshot - index_attempts: list[IndexAttemptSnapshot] + number_of_index_attempts: int + last_index_attempt_status: IndexingStatus | None latest_deletion_attempt: DeletionAttemptSnapshot | None is_public: bool is_editable_for_current_user: bool + deletion_failure_message: str | None @classmethod def from_models( cls, cc_pair_model: ConnectorCredentialPair, - index_attempt_models: list[IndexAttempt], latest_deletion_attempt: DeletionAttemptSnapshot | None, + number_of_index_attempts: int, + last_index_attempt: IndexAttempt | None, num_docs_indexed: int, # not ideal, but this must be computed separately is_editable_for_current_user: bool, ) -> "CCPairFullInfo": + # figure out if we need to artificially deflate the number of docs indexed. + # This is required since the total number of docs indexed by a CC Pair is + # updated before the new docs for an indexing attempt. If we don't do this, + # there is a mismatch between these two numbers which may confuse users. 
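+        # Hypothetical illustration of the clamp applied below (values are made up
+        # for clarity, not taken from this change): if this pair's only attempt
+        # succeeded with new_docs_indexed=42 while the pair-level total already
+        # reads 100, num_docs_indexed is reported as 42 so the two counts users
+        # see agree.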
+ last_indexing_status = last_index_attempt.status if last_index_attempt else None + if ( + last_indexing_status == IndexingStatus.SUCCESS + and number_of_index_attempts == 1 + and last_index_attempt + and last_index_attempt.new_docs_indexed + ): + num_docs_indexed = ( + last_index_attempt.new_docs_indexed if last_index_attempt else 0 + ) + return cls( id=cc_pair_model.id, name=cc_pair_model.name, @@ -219,16 +258,26 @@ def from_models( credential=CredentialSnapshot.from_credential_db_model( cc_pair_model.credential ), - index_attempts=[ - IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt_model) - for index_attempt_model in index_attempt_models - ], + number_of_index_attempts=number_of_index_attempts, + last_index_attempt_status=last_indexing_status, latest_deletion_attempt=latest_deletion_attempt, is_public=cc_pair_model.is_public, is_editable_for_current_user=is_editable_for_current_user, + deletion_failure_message=cc_pair_model.deletion_failure_message, ) +class FailedConnectorIndexingStatus(BaseModel): + """Simplified version of ConnectorIndexingStatus for failed indexing attempts""" + + cc_pair_id: int + name: str | None + error_msg: str | None + is_deletable: bool + connector_id: int + credential_id: int + + class ConnectorIndexingStatus(BaseModel): """Represents the latest indexing status of a connector""" @@ -261,6 +310,10 @@ class ConnectorCredentialPairMetadata(BaseModel): groups: list[int] = Field(default_factory=list) +class CCStatusUpdateRequest(BaseModel): + status: ConnectorCredentialPairStatus + + class ConnectorCredentialPairDescriptor(BaseModel): id: int name: str | None = None @@ -307,8 +360,18 @@ class GoogleServiceAccountKey(BaseModel): class GoogleServiceAccountCredentialRequest(BaseModel): - google_drive_delegated_user: str | None # email of user to impersonate - gmail_delegated_user: str | None # email of user to impersonate + google_drive_delegated_user: str | None = None # email of user to impersonate + gmail_delegated_user: str | None = None # email of user to impersonate + + @model_validator(mode="after") + def check_user_delegation(self) -> "GoogleServiceAccountCredentialRequest": + if (self.google_drive_delegated_user is None) == ( + self.gmail_delegated_user is None + ): + raise ValueError( + "Exactly one of google_drive_delegated_user or gmail_delegated_user must be set" + ) + return self class FileUploadResponse(BaseModel): diff --git a/backend/danswer/server/features/document_set/api.py b/backend/danswer/server/features/document_set/api.py index d1eff082891..c9cea2cf2a2 100644 --- a/backend/danswer/server/features/document_set/api.py +++ b/backend/danswer/server/features/document_set/api.py @@ -6,7 +6,6 @@ from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user -from danswer.auth.users import validate_curator_request from danswer.db.document_set import check_document_sets_are_public from danswer.db.document_set import fetch_all_document_sets_for_user from danswer.db.document_set import insert_document_set @@ -14,12 +13,12 @@ from danswer.db.document_set import update_document_set from danswer.db.engine import get_session from danswer.db.models import User -from danswer.db.models import UserRole from danswer.server.features.document_set.models import CheckDocSetPublicRequest from danswer.server.features.document_set.models import CheckDocSetPublicResponse from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.document_set.models import 
DocumentSetCreationRequest from danswer.server.features.document_set.models import DocumentSetUpdateRequest +from ee.danswer.db.user_group import validate_user_creation_permissions router = APIRouter(prefix="/manage") @@ -31,11 +30,12 @@ def create_document_set( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> int: - if user and user.role != UserRole.ADMIN: - validate_curator_request( - groups=document_set_creation_request.groups, - is_public=document_set_creation_request.is_public, - ) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=document_set_creation_request.groups, + object_is_public=document_set_creation_request.is_public, + ) try: document_set_db_model, _ = insert_document_set( document_set_creation_request=document_set_creation_request, @@ -53,11 +53,12 @@ def patch_document_set( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> None: - if user and user.role != UserRole.ADMIN: - validate_curator_request( - groups=document_set_update_request.groups, - is_public=document_set_update_request.is_public, - ) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=document_set_update_request.groups, + object_is_public=document_set_update_request.is_public, + ) try: update_document_set( document_set_update_request=document_set_update_request, diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index 0ac90ba8d11..f45d5c38529 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -77,16 +77,10 @@ def document_boost_update( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - update_document_boost( db_session=db_session, document_id=boost_update.document_id, boost=boost_update.boost, - document_index=document_index, user=user, ) return StatusResponse(success=True, message="Updated document boost") @@ -166,10 +160,14 @@ def create_deletion_attempt_for_connector_id( get_editable=True, ) if cc_pair is None: + error = ( + f"Connector with ID '{connector_id}' and credential ID " + f"'{credential_id}' does not exist. Has it already been deleted?" + ) + logger.error(error) raise HTTPException( status_code=404, - detail=f"Connector with ID '{connector_id}' and credential ID " - f"'{credential_id}' does not exist. 
Has it already been deleted?", + detail=error, ) # Cancel any scheduled indexing attempts diff --git a/backend/danswer/server/manage/embedding/api.py b/backend/danswer/server/manage/embedding/api.py index 90fa69401c2..eac872810ef 100644 --- a/backend/danswer/server/manage/embedding/api.py +++ b/backend/danswer/server/manage/embedding/api.py @@ -9,7 +9,9 @@ from danswer.db.llm import remove_embedding_provider from danswer.db.llm import upsert_cloud_embedding_provider from danswer.db.models import User +from danswer.db.search_settings import get_all_search_settings from danswer.db.search_settings import get_current_db_embedding_provider +from danswer.indexing.models import EmbeddingModelDetail from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.server.manage.embedding.models import CloudEmbeddingProvider from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest @@ -20,6 +22,7 @@ from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType + logger = setup_logger() @@ -37,11 +40,12 @@ def test_embedding_configuration( server_host=MODEL_SERVER_HOST, server_port=MODEL_SERVER_PORT, api_key=test_llm_request.api_key, + api_url=test_llm_request.api_url, provider_type=test_llm_request.provider_type, + model_name=test_llm_request.model_name, normalize=False, query_prefix=None, passage_prefix=None, - model_name=None, ) test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY) @@ -56,6 +60,15 @@ def test_embedding_configuration( raise HTTPException(status_code=400, detail=error_msg) +@admin_router.get("", response_model=list[EmbeddingModelDetail]) +def list_embedding_models( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[EmbeddingModelDetail]: + search_settings = get_all_search_settings(db_session) + return [EmbeddingModelDetail.from_db_model(setting) for setting in search_settings] + + @admin_router.get("/embedding-provider") def list_embedding_providers( _: User | None = Depends(current_admin_user), diff --git a/backend/danswer/server/manage/embedding/models.py b/backend/danswer/server/manage/embedding/models.py index 132d311413c..b4ca7862b55 100644 --- a/backend/danswer/server/manage/embedding/models.py +++ b/backend/danswer/server/manage/embedding/models.py @@ -8,14 +8,21 @@ from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel +class SearchSettingsDeleteRequest(BaseModel): + search_settings_id: int + + class TestEmbeddingRequest(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None + model_name: str | None = None class CloudEmbeddingProvider(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None @classmethod def from_request( @@ -24,9 +31,11 @@ def from_request( return cls( provider_type=cloud_provider_model.provider_type, api_key=cloud_provider_model.api_key, + api_url=cloud_provider_model.api_url, ) class CloudEmbeddingProviderCreationRequest(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None diff --git a/backend/danswer/server/manage/llm/api.py b/backend/danswer/server/manage/llm/api.py index 9ea9fe927db..4e57ec7bc35 100644 --- a/backend/danswer/server/manage/llm/api.py +++ b/backend/danswer/server/manage/llm/api.py @@ -121,7 +121,7 @@ def put_llm_provider( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), 
) -> FullLLMProvider: - return upsert_llm_provider(db_session, llm_provider) + return upsert_llm_provider(llm_provider=llm_provider, db_session=db_session) @admin_router.delete("/provider/{provider_id}") @@ -139,7 +139,7 @@ def set_provider_as_default( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> None: - update_default_provider(db_session, provider_id) + update_default_provider(provider_id=provider_id, db_session=db_session) """Endpoints for all""" diff --git a/backend/danswer/server/manage/search_settings.py b/backend/danswer/server/manage/search_settings.py index db483eff5da..c8433467f6c 100644 --- a/backend/danswer/server/manage/search_settings.py +++ b/backend/danswer/server/manage/search_settings.py @@ -14,6 +14,7 @@ from danswer.db.models import IndexModelStatus from danswer.db.models import User from danswer.db.search_settings import create_search_settings +from danswer.db.search_settings import delete_search_settings from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_embedding_provider_from_provider_type from danswer.db.search_settings import get_secondary_search_settings @@ -23,6 +24,7 @@ from danswer.natural_language_processing.search_nlp_models import clean_model_name from danswer.search.models import SavedSearchSettings from danswer.search.models import SearchSettingsCreationRequest +from danswer.server.manage.embedding.models import SearchSettingsDeleteRequest from danswer.server.manage.models import FullModelVersionResponse from danswer.server.models import IdReturn from danswer.utils.logger import setup_logger @@ -45,7 +47,7 @@ def set_new_search_settings( if search_settings_new.index_name: logger.warning("Index name was specified by request, this is not suggested") - # Validate cloud provider exists + # Validate cloud provider exists or create new LiteLLM provider if search_settings_new.provider_type is not None: cloud_provider = get_embedding_provider_from_provider_type( db_session, provider_type=search_settings_new.provider_type @@ -97,6 +99,7 @@ def set_new_search_settings( primary_index_name=search_settings.index_name, secondary_index_name=new_search_settings.index_name, ) + document_index.ensure_indices_exist( index_embedding_dim=search_settings.model_dim, secondary_index_embedding_dim=new_search_settings.model_dim, @@ -132,8 +135,23 @@ def cancel_new_embedding( ) +@router.delete("/delete-search-settings") +def delete_search_settings_endpoint( + deletion_request: SearchSettingsDeleteRequest, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + try: + delete_search_settings( + db_session=db_session, + search_settings_id=deletion_request.search_settings_id, + ) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + + @router.get("/get-current-search-settings") -def get_curr_search_settings( +def get_current_search_settings_endpoint( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> SavedSearchSettings: @@ -142,7 +160,7 @@ def get_curr_search_settings( @router.get("/get-secondary-search-settings") -def get_sec_search_settings( +def get_secondary_search_settings_endpoint( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> SavedSearchSettings | None: diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index d2fd981b5b5..96c79b4cbe7 100644 --- 
a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -213,6 +213,52 @@ def deactivate_user( db_session.commit() +@router.delete("/manage/admin/delete-user") +async def delete_user( + user_email: UserByEmail, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + user_to_delete = get_user_by_email( + email=user_email.user_email, db_session=db_session + ) + if not user_to_delete: + raise HTTPException(status_code=404, detail="User not found") + + if user_to_delete.is_active is True: + logger.warning( + "{} must be deactivated before deleting".format(user_to_delete.email) + ) + raise HTTPException( + status_code=400, detail="User must be deactivated before deleting" + ) + + # Detach the user from the current session + db_session.expunge(user_to_delete) + + try: + # Delete related OAuthAccounts first + for oauth_account in user_to_delete.oauth_accounts: + db_session.delete(oauth_account) + + db_session.delete(user_to_delete) + db_session.commit() + + # NOTE: edge case may exist with race conditions + # with this `invited user` scheme generally. + user_emails = get_invited_users() + remaining_users = [ + user for user in user_emails if user != user_email.user_email + ] + write_invited_users(remaining_users) + + logger.info(f"Deleted user {user_to_delete.email}") + except Exception as e: + db_session.rollback() + logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}") + raise HTTPException(status_code=500, detail="Error deleting user") + + @router.patch("/manage/admin/activate-user") def activate_user( user_email: UserByEmail, diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index a37758336a2..20ae7124fa1 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -269,7 +269,10 @@ def delete_chat_session_by_id( db_session: Session = Depends(get_session), ) -> None: user_id = user.id if user is not None else None - delete_chat_session(user_id, session_id, db_session) + try: + delete_chat_session(user_id, session_id, db_session) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) async def is_disconnected(request: Request) -> Callable[[], bool]: diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 704b16d5eaa..e20de5a3027 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -11,8 +11,8 @@ from danswer.db.chat import get_chat_messages_by_session from danswer.db.chat import get_chat_session_by_id from danswer.db.chat import get_chat_sessions_by_user -from danswer.db.chat import get_first_messages_for_chat_sessions from danswer.db.chat import get_search_docs_for_chat_message +from danswer.db.chat import get_valid_messages_from_query_sessions from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.engine import get_session @@ -142,18 +142,20 @@ def get_user_search_sessions( raise HTTPException( status_code=404, detail="Chat session does not exist or has been deleted" ) - + # Extract IDs from search sessions search_session_ids = [chat.id for chat in search_sessions] - first_messages = get_first_messages_for_chat_sessions( + # Fetch first messages for each session, only 
including those with documents + sessions_with_documents = get_valid_messages_from_query_sessions( search_session_ids, db_session ) - first_messages_dict = dict(first_messages) + sessions_with_documents_dict = dict(sessions_with_documents) + # Prepare response with detailed information for each valid search session response = ChatSessionsResponse( sessions=[ ChatSessionDetails( id=search.id, - name=first_messages_dict.get(search.id, search.description), + name=sessions_with_documents_dict[search.id], persona_id=search.persona_id, time_created=search.time_created.isoformat(), shared_status=search.shared_status, @@ -161,8 +163,11 @@ def get_user_search_sessions( current_alternate_model=search.current_alternate_model, ) for search in search_sessions + if search.id + in sessions_with_documents_dict # Only include sessions with documents ] ) + return response diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 3330f6cc5ff..5b8564c3d3a 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -66,7 +66,7 @@ def fetch_settings( return UserSettings( **general_settings.model_dump(), notifications=user_notifications, - needs_reindexing=needs_reindexing + needs_reindexing=needs_reindexing, ) diff --git a/backend/danswer/server/utils.py b/backend/danswer/server/utils.py index bf535661878..53ed5b426ba 100644 --- a/backend/danswer/server/utils.py +++ b/backend/danswer/server/utils.py @@ -1,9 +1,31 @@ import json +from datetime import datetime from typing import Any -def get_json_line(json_dict: dict) -> str: - return json.dumps(json_dict) + "\n" +class DateTimeEncoder(json.JSONEncoder): + """Custom JSON encoder that converts datetime objects to ISO format strings.""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.isoformat() + return super().default(obj) + + +def get_json_line( + json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder +) -> str: + """ + Convert a dictionary to a JSON string with datetime handling, and add a newline. + + Args: + json_dict: The dictionary to be converted to JSON. + encoder: JSON encoder class to use, defaults to DateTimeEncoder. + + Returns: + A JSON string representation of the input dictionary with a newline character. 
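+    Example (illustrative; relies on the module-level datetime import):
+        >>> get_json_line({"ts": datetime(2024, 1, 1)})
+        '{"ts": "2024-01-01T00:00:00"}\n'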
+ """ + return json.dumps(json_dict, cls=encoder) + "\n" def mask_string(sensitive_str: str) -> str: diff --git a/backend/danswer/tools/built_in_tools.py b/backend/danswer/tools/built_in_tools.py index 99b2ae3bbb6..1bfecef7ce4 100644 --- a/backend/danswer/tools/built_in_tools.py +++ b/backend/danswer/tools/built_in_tools.py @@ -146,7 +146,6 @@ def auto_add_search_tool_to_personas(db_session: Session) -> None: db_session.commit() logger.notice("Completed adding SearchTool to relevant Personas.") - _built_in_tools_cache: dict[int, Type[Tool]] | None = None diff --git a/backend/danswer/tools/internet_search/internet_search_tool.py b/backend/danswer/tools/internet_search/internet_search_tool.py index 2640afcdf83..3012eb465f4 100644 --- a/backend/danswer/tools/internet_search/internet_search_tool.py +++ b/backend/danswer/tools/internet_search/internet_search_tool.py @@ -20,7 +20,7 @@ from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.tools.internet_search.models import InternetSearchResponse from danswer.tools.internet_search.models import InternetSearchResult -from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS +from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS_ID from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse from danswer.utils.logger import setup_logger @@ -224,7 +224,7 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: ] yield ToolResponse( - id=FINAL_CONTEXT_DOCUMENTS, + id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs, ) diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py index 13d3a304b06..cbfaf4f3d92 100644 --- a/backend/danswer/tools/search/search_tool.py +++ b/backend/danswer/tools/search/search_tool.py @@ -45,7 +45,7 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary" SEARCH_DOC_CONTENT_ID = "search_doc_content" SECTION_RELEVANCE_LIST_ID = "section_relevance_list" -FINAL_CONTEXT_DOCUMENTS = "final_context_documents" +FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents" SEARCH_EVALUATION_ID = "llm_doc_eval" @@ -179,7 +179,7 @@ def build_tool_message_content( self, *args: ToolResponse ) -> str | list[str | dict[str, Any]]: final_context_docs_response = next( - response for response in args if response.id == FINAL_CONTEXT_DOCUMENTS + response for response in args if response.id == FINAL_CONTEXT_DOCUMENTS_ID ) final_context_docs = cast(list[LlmDoc], final_context_docs_response.response) @@ -260,7 +260,7 @@ def _build_response_for_specified_sections( for section in final_context_sections ] - yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs) + yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: query = cast(str, kwargs["query"]) @@ -343,12 +343,12 @@ def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: llm_doc_from_inference_section(section) for section in pruned_sections ] - yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs) + yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) def final_result(self, *args: ToolResponse) -> JSON_ro: final_docs = cast( list[LlmDoc], - next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS), + next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS_ID), ) # NOTE: need to do this json.loads(doc.json()) stuff because there are some # subfields that are not serializable by default 
(datetime) diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index f962c214a03..58b94bdb0c8 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from collections.abc import Generator from typing import Any @@ -47,7 +48,7 @@ def tool_final_result(self) -> ToolCallFinalResult: def check_which_tools_should_run_for_non_tool_calling_llm( tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM ) -> list[dict[str, Any] | None]: - tool_args_list = [ + tool_args_list: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [ (tool.get_args_for_non_tool_calling_llm, (query, history, llm)) for tool in tools ] diff --git a/backend/danswer/utils/gpu_utils.py b/backend/danswer/utils/gpu_utils.py new file mode 100644 index 00000000000..70a3dbc2c95 --- /dev/null +++ b/backend/danswer/utils/gpu_utils.py @@ -0,0 +1,30 @@ +import requests +from retry import retry + +from danswer.utils.logger import setup_logger +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import INDEXING_MODEL_SERVER_PORT +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT + +logger = setup_logger() + + +@retry(tries=5, delay=5) +def gpu_status_request(indexing: bool = True) -> bool: + if indexing: + model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}" + else: + model_server_url = f"{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}" + + if "http" not in model_server_url: + model_server_url = f"http://{model_server_url}" + + try: + response = requests.get(f"{model_server_url}/api/gpu-status", timeout=10) + response.raise_for_status() + gpu_status = response.json() + return gpu_status["gpu_available"] + except requests.RequestException as e: + logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}") + raise # Re-raise exception to trigger a retry diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py index a7751ca3dc7..9489a6244ff 100644 --- a/backend/danswer/utils/logger.py +++ b/backend/danswer/utils/logger.py @@ -19,14 +19,22 @@ class IndexAttemptSingleton: main background job (scheduler), etc. 
this will not be used.""" _INDEX_ATTEMPT_ID: None | int = None + _CONNECTOR_CREDENTIAL_PAIR_ID: None | int = None @classmethod def get_index_attempt_id(cls) -> None | int: return cls._INDEX_ATTEMPT_ID @classmethod - def set_index_attempt_id(cls, index_attempt_id: int) -> None: + def get_connector_credential_pair_id(cls) -> None | int: + return cls._CONNECTOR_CREDENTIAL_PAIR_ID + + @classmethod + def set_cc_and_index_id( + cls, index_attempt_id: int, connector_credential_pair_id: int + ) -> None: cls._INDEX_ATTEMPT_ID = index_attempt_id + cls._CONNECTOR_CREDENTIAL_PAIR_ID = connector_credential_pair_id def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int: @@ -50,9 +58,14 @@ def process( # If this is an indexing job, add the attempt ID to the log message # This helps filter the logs for this specific indexing attempt_id = IndexAttemptSingleton.get_index_attempt_id() + cc_pair_id = IndexAttemptSingleton.get_connector_credential_pair_id() + if attempt_id is not None: msg = f"[Attempt ID: {attempt_id}] {msg}" + if cc_pair_id is not None: + msg = f"[CC Pair ID: {cc_pair_id}] {msg}" + # For Slack Bot, logs the channel relevant to the request channel_id = self.extra.get(SLACK_CHANNEL_ID) if self.extra else None if channel_id: diff --git a/backend/danswer/utils/telemetry.py b/backend/danswer/utils/telemetry.py index 80fcba65a16..d8a021877e6 100644 --- a/backend/danswer/utils/telemetry.py +++ b/backend/danswer/utils/telemetry.py @@ -4,13 +4,20 @@ from typing import cast import requests +from sqlalchemy.orm import Session from danswer.configs.app_configs import DISABLE_TELEMETRY +from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED from danswer.configs.constants import KV_CUSTOMER_UUID_KEY +from danswer.configs.constants import KV_INSTANCE_DOMAIN_KEY +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import User from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError -DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.danswer.ai/anonymous_telemetry" +_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.danswer.ai/anonymous_telemetry" +_CACHED_UUID: str | None = None +_CACHED_INSTANCE_DOMAIN: str | None = None class RecordType(str, Enum): @@ -22,13 +29,42 @@ class RecordType(str, Enum): def get_or_generate_uuid() -> str: + global _CACHED_UUID + + if _CACHED_UUID is not None: + return _CACHED_UUID + + kv_store = get_dynamic_config_store() + + try: + _CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY)) + except ConfigNotFoundError: + _CACHED_UUID = str(uuid.uuid4()) + kv_store.store(KV_CUSTOMER_UUID_KEY, _CACHED_UUID, encrypt=True) + + return _CACHED_UUID + + +def _get_or_generate_instance_domain() -> str | None: + global _CACHED_INSTANCE_DOMAIN + + if _CACHED_INSTANCE_DOMAIN is not None: + return _CACHED_INSTANCE_DOMAIN + kv_store = get_dynamic_config_store() + try: - return cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY)) + _CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY)) except ConfigNotFoundError: - customer_id = str(uuid.uuid4()) - kv_store.store(KV_CUSTOMER_UUID_KEY, customer_id, encrypt=True) - return customer_id + with Session(get_sqlalchemy_engine()) as db_session: + first_user = db_session.query(User).first() + if first_user: + _CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1] + kv_store.store( + KV_INSTANCE_DOMAIN_KEY, _CACHED_INSTANCE_DOMAIN, encrypt=True + ) + + return _CACHED_INSTANCE_DOMAIN def optional_telemetry( @@ -41,16 
+77,19 @@ def optional_telemetry( def telemetry_logic() -> None: try: + customer_uuid = get_or_generate_uuid() payload = { "data": data, "record": record_type, # If None then it's a flow that doesn't include a user # For cases where the User itself is None, a string is provided instead "user_id": user_id, - "customer_uuid": get_or_generate_uuid(), + "customer_uuid": customer_uuid, } + if ENTERPRISE_EDITION_ENABLED: + payload["instance_domain"] = _get_or_generate_instance_domain() requests.post( - DANSWER_TELEMETRY_ENDPOINT, + _DANSWER_TELEMETRY_ENDPOINT, headers={"Content-Type": "application/json"}, json=payload, ) diff --git a/backend/danswer/utils/text_processing.py b/backend/danswer/utils/text_processing.py index b0fbcdfa1e9..134859d4e74 100644 --- a/backend/danswer/utils/text_processing.py +++ b/backend/danswer/utils/text_processing.py @@ -43,6 +43,35 @@ def replace_whitespaces_w_space(s: str) -> str: return re.sub(r"\s", " ", s) +# Function to remove punctuation from a string +def remove_punctuation(s: str) -> str: + return s.translate(str.maketrans("", "", string.punctuation)) + + +def escape_quotes(original_json_str: str) -> str: + result = [] + in_string = False + for i, char in enumerate(original_json_str): + if char == '"': + if not in_string: + in_string = True + result.append(char) + else: + next_char = ( + original_json_str[i + 1] if i + 1 < len(original_json_str) else None + ) + if result and result[-1] == "\\": + result.append(char) + elif next_char not in [",", ":", "}", "\n"]: + result.append("\\" + char) + else: + result.append(char) + in_string = False + else: + result.append(char) + return "".join(result) + + def extract_embedded_json(s: str) -> dict: first_brace_index = s.find("{") last_brace_index = s.rfind("}") @@ -50,7 +79,15 @@ def extract_embedded_json(s: str) -> dict: if first_brace_index == -1 or last_brace_index == -1: raise ValueError("No valid json found") - return json.loads(s[first_brace_index : last_brace_index + 1], strict=False) + json_str = s[first_brace_index : last_brace_index + 1] + try: + return json.loads(json_str, strict=False) + + except json.JSONDecodeError: + try: + return json.loads(escape_quotes(json_str), strict=False) + except json.JSONDecodeError as e: + raise ValueError("Failed to parse JSON, even after escaping quotes") from e def clean_up_code_blocks(model_out_raw: str) -> str: diff --git a/backend/danswer/utils/variable_functionality.py b/backend/danswer/utils/variable_functionality.py index 97c6592601e..55f296aa8e7 100644 --- a/backend/danswer/utils/variable_functionality.py +++ b/backend/danswer/utils/variable_functionality.py @@ -31,6 +31,28 @@ def set_is_ee_based_on_env_variable() -> None: @functools.lru_cache(maxsize=128) def fetch_versioned_implementation(module: str, attribute: str) -> Any: + """ + Fetches a versioned implementation of a specified attribute from a given module. + This function first checks if the application is running in an Enterprise Edition (EE) + context. If so, it attempts to import the attribute from the EE-specific module. + If the module or attribute is not found, it falls back to the default module or + raises the appropriate exception depending on the context. + + Args: + module (str): The name of the module from which to fetch the attribute. + attribute (str): The name of the attribute to fetch from the module. + + Returns: + Any: The fetched implementation of the attribute. 
+ + Raises: + ModuleNotFoundError: If the module cannot be found and the error is not related to + the Enterprise Edition fallback logic. + + Logs: + Logs debug information about the fetching process and warnings if the versioned + implementation cannot be found or loaded. + """ logger.debug("Fetching versioned implementation for %s.%s", module, attribute) is_ee = global_version.get_is_ee_version() @@ -66,6 +88,19 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any: def fetch_versioned_implementation_with_fallback( module: str, attribute: str, fallback: T ) -> T: + """ + Attempts to fetch a versioned implementation of a specified attribute from a given module. + If the attempt fails (e.g., due to an import error or missing attribute), the function logs + a warning and returns the provided fallback implementation. + + Args: + module (str): The name of the module from which to fetch the attribute. + attribute (str): The name of the attribute to fetch from the module. + fallback (T): The fallback implementation to return if fetching the attribute fails. + + Returns: + T: The fetched implementation if successful, otherwise the provided fallback. + """ try: return fetch_versioned_implementation(module, attribute) except Exception: @@ -73,4 +108,14 @@ def fetch_versioned_implementation_with_fallback( def noop_fallback(*args: Any, **kwargs: Any) -> None: - pass + """ + A no-op (no operation) fallback function that accepts any arguments but does nothing. + This is often used as a default or placeholder callback function. + + Args: + *args (Any): Positional arguments, which are ignored. + **kwargs (Any): Keyword arguments, which are ignored. + + Returns: + None + """ diff --git a/backend/ee/danswer/access/access.py b/backend/ee/danswer/access/access.py index c2b05ee881f..2b3cdb7a9dc 100644 --- a/backend/ee/danswer/access/access.py +++ b/backend/ee/danswer/access/access.py @@ -11,6 +11,17 @@ from ee.danswer.db.user_group import fetch_user_groups_for_user +def _get_access_for_document( + document_id: str, + db_session: Session, +) -> DocumentAccess: + id_to_access = _get_access_for_documents([document_id], db_session) + if len(id_to_access) == 0: + return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) + + return next(iter(id_to_access.values())) + + def _get_access_for_documents( document_ids: list[str], db_session: Session, diff --git a/backend/ee/danswer/background/celery/celery_app.py b/backend/ee/danswer/background/celery/celery_app.py index 403adbd74e1..2b4c96ccb1e 100644 --- a/backend/ee/danswer/background/celery/celery_app.py +++ b/backend/ee/danswer/background/celery/celery_app.py @@ -1,28 +1,18 @@ from datetime import timedelta -from typing import Any -from celery.signals import beat_init -from celery.signals import worker_init from sqlalchemy.orm import Session from danswer.background.celery.celery_app import celery_app from danswer.background.task_utils import build_celery_task_wrapper from danswer.configs.app_configs import JOB_TIMEOUT -from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME -from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME from danswer.db.chat import delete_chat_sessions_older_than from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.engine import init_sqlalchemy_engine from danswer.server.settings.store import load_settings from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version from ee.danswer.background.celery_utils import 
should_perform_chat_ttl_check -from ee.danswer.background.celery_utils import should_sync_user_groups from ee.danswer.background.task_name_builders import name_chat_ttl_task -from ee.danswer.background.task_name_builders import name_user_group_sync_task -from ee.danswer.db.user_group import fetch_user_groups from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report -from ee.danswer.user_groups.sync import sync_user_groups logger = setup_logger() @@ -30,17 +20,6 @@ global_version.set_ee() -@build_celery_task_wrapper(name_user_group_sync_task) -@celery_app.task(soft_time_limit=JOB_TIMEOUT) -def sync_user_group_task(user_group_id: int) -> None: - with Session(get_sqlalchemy_engine()) as db_session: - # actual sync logic - try: - sync_user_groups(user_group_id=user_group_id, db_session=db_session) - except Exception as e: - logger.exception(f"Failed to sync user group - {e}") - - @build_celery_task_wrapper(name_chat_ttl_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) def perform_ttl_management_task(retention_limit_days: int) -> None: @@ -51,8 +30,6 @@ def perform_ttl_management_task(retention_limit_days: int) -> None: ##### # Periodic Tasks ##### - - @celery_app.task( name="check_ttl_management_task", soft_time_limit=JOB_TIMEOUT, @@ -69,24 +46,6 @@ def check_ttl_management_task() -> None: ) -@celery_app.task( - name="check_for_user_groups_sync_task", - soft_time_limit=JOB_TIMEOUT, -) -def check_for_user_groups_sync_task() -> None: - """Runs periodically to check if any user groups are out of sync - Creates a task to sync the user group if needed""" - with Session(get_sqlalchemy_engine()) as db_session: - # check if any document sets are not synced - user_groups = fetch_user_groups(db_session=db_session, only_current=False) - for user_group in user_groups: - if should_sync_user_groups(user_group, db_session): - logger.info(f"User Group {user_group.id} is not synced. 
Syncing now!") - sync_user_group_task.apply_async( - kwargs=dict(user_group_id=user_group.id), - ) - - @celery_app.task( name="autogenerate_usage_report_task", soft_time_limit=JOB_TIMEOUT, @@ -101,25 +60,11 @@ def autogenerate_usage_report_task() -> None: ) -@beat_init.connect -def on_beat_init(sender: Any, **kwargs: Any) -> None: - init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME) - - -@worker_init.connect -def on_worker_init(sender: Any, **kwargs: Any) -> None: - init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME) - - ##### # Celery Beat (Periodic Tasks) Settings ##### celery_app.conf.beat_schedule = { - "check-for-user-group-sync": { - "task": "check_for_user_groups_sync_task", - "schedule": timedelta(seconds=5), - }, - "autogenerate_usage_report": { + "autogenerate-usage-report": { "task": "autogenerate_usage_report_task", "schedule": timedelta(days=30), # TODO: change this to config flag }, diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py index 0134f6642f7..34190255f5a 100644 --- a/backend/ee/danswer/background/celery_utils.py +++ b/backend/ee/danswer/background/celery_utils.py @@ -1,27 +1,13 @@ from sqlalchemy.orm import Session -from danswer.db.models import UserGroup from danswer.db.tasks import check_task_is_live_and_not_timed_out from danswer.db.tasks import get_latest_task from danswer.utils.logger import setup_logger from ee.danswer.background.task_name_builders import name_chat_ttl_task -from ee.danswer.background.task_name_builders import name_user_group_sync_task logger = setup_logger() -def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool: - if user_group.is_up_to_date: - return False - task_name = name_user_group_sync_task(user_group.id) - latest_sync = get_latest_task(task_name, db_session) - - if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session): - logger.info("TTL check is already being performed. Skipping.") - return False - return True - - def should_perform_chat_ttl_check( retention_limit_days: int | None, db_session: Session ) -> bool: diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index 9d172c5d716..ab666f747b5 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -2,8 +2,10 @@ from operator import and_ from uuid import UUID +from fastapi import HTTPException from sqlalchemy import delete from sqlalchemy import func +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session @@ -30,16 +32,75 @@ logger = setup_logger() +def validate_user_creation_permissions( + db_session: Session, + user: User | None, + target_group_ids: list[int] | None, + object_is_public: bool | None, +) -> None: + """ + All admin actions are allowed. 
+ Prevents non-admins from creating/editing: + - public objects + - objects with no groups + - objects that belong to a group they don't curate + """ + if not user or user.role == UserRole.ADMIN: + return + + if object_is_public: + detail = "User does not have permission to create public credentials" + logger.error(detail) + raise HTTPException( + status_code=400, + detail=detail, + ) + if not target_group_ids: + detail = "Curators must specify 1+ groups" + logger.error(detail) + raise HTTPException( + status_code=400, + detail=detail, + ) + user_curated_groups = fetch_user_groups_for_user( + db_session=db_session, user_id=user.id, only_curator_groups=True + ) + user_curated_group_ids = set([group.id for group in user_curated_groups]) + target_group_ids_set = set(target_group_ids) + if not target_group_ids_set.issubset(user_curated_group_ids): + detail = "Curators cannot control groups they don't curate" + logger.error(detail) + raise HTTPException( + status_code=400, + detail=detail, + ) + + def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None: stmt = select(UserGroup).where(UserGroup.id == user_group_id) return db_session.scalar(stmt) def fetch_user_groups( - db_session: Session, only_current: bool = True + db_session: Session, only_up_to_date: bool = True ) -> Sequence[UserGroup]: + """ + Fetches user groups from the database. + + This function retrieves a sequence of `UserGroup` objects from the database. + If `only_up_to_date` is set to `True`, it filters the user groups to return only those + that are marked as up-to-date (`is_up_to_date` is `True`). + + Args: + db_session (Session): The SQLAlchemy session used to query the database. + only_up_to_date (bool, optional): Flag to determine whether to filter the results + to include only up to date user groups. Defaults to `True`. + + Returns: + Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria. + """ stmt = select(UserGroup) - if only_current: + if only_up_to_date: stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712 return db_session.scalars(stmt).all() @@ -58,6 +119,42 @@ def fetch_user_groups_for_user( return db_session.scalars(stmt).all() +def construct_document_select_by_usergroup( + user_group_id: int, +) -> Select: + """This returns a statement that should be executed using + .yield_per() to minimize overhead. The primary consumers of this function + are background processing task generators.""" + stmt = ( + select(Document) + .join( + DocumentByConnectorCredentialPair, + Document.id == DocumentByConnectorCredentialPair.id, + ) + .join( + ConnectorCredentialPair, + and_( + DocumentByConnectorCredentialPair.connector_id + == ConnectorCredentialPair.connector_id, + DocumentByConnectorCredentialPair.credential_id + == ConnectorCredentialPair.credential_id, + ), + ) + .join( + UserGroup__ConnectorCredentialPair, + UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id, + ) + .join( + UserGroup, + UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id, + ) + .where(UserGroup.id == user_group_id) + .order_by(Document.id) + ) + stmt = stmt.distinct() + return stmt + + def fetch_documents_for_user_group_paginated( db_session: Session, user_group_id: int, @@ -316,6 +413,10 @@ def update_user_group( user_group_id: int, user_group_update: UserGroupUpdate, ) -> UserGroup: + """If successful, this can set db_user_group.is_up_to_date = False. 
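For reference, a minimal consumer sketch for `construct_document_select_by_usergroup`, following the docstring's advice to stream results rather than load them all at once; the group id and batch size are placeholder values, and batching is done via the `yield_per` execution option.

```python
from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from ee.danswer.db.user_group import construct_document_select_by_usergroup

BATCH_SIZE = 100  # illustrative batch size

# Hypothetical background-task consumer: stream documents for user group 1 in
# server-side batches instead of materializing the whole result set.
with Session(get_sqlalchemy_engine()) as db_session:
    stmt = construct_document_select_by_usergroup(user_group_id=1)
    for document in db_session.scalars(stmt.execution_options(yield_per=BATCH_SIZE)):
        ...  # e.g. queue document.id for a Vespa access update
```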
+ That will be processed by check_for_vespa_user_groups_sync_task and trigger + a long running background sync to Vespa. + """ stmt = select(UserGroup).where(UserGroup.id == user_group_id) db_user_group = db_session.scalar(stmt) if db_user_group is None: diff --git a/backend/ee/danswer/server/enterprise_settings/api.py b/backend/ee/danswer/server/enterprise_settings/api.py index 736296517db..8590fd6c5e7 100644 --- a/backend/ee/danswer/server/enterprise_settings/api.py +++ b/backend/ee/danswer/server/enterprise_settings/api.py @@ -1,14 +1,24 @@ +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +import httpx from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Response +from fastapi import status from fastapi import UploadFile from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_user +from danswer.auth.users import get_user_manager +from danswer.auth.users import UserManager from danswer.db.engine import get_session from danswer.db.models import User from danswer.file_store.file_store import get_default_file_store +from danswer.utils.logger import setup_logger from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload from ee.danswer.server.enterprise_settings.models import EnterpriseSettings from ee.danswer.server.enterprise_settings.store import _LOGO_FILENAME @@ -18,10 +28,117 @@ from ee.danswer.server.enterprise_settings.store import store_analytics_script from ee.danswer.server.enterprise_settings.store import store_settings from ee.danswer.server.enterprise_settings.store import upload_logo +from shared_configs.configs import CUSTOM_REFRESH_URL admin_router = APIRouter(prefix="/admin/enterprise-settings") basic_router = APIRouter(prefix="/enterprise-settings") +logger = setup_logger() + + +def mocked_refresh_token() -> dict: + """ + This function mocks the response from a token refresh endpoint. + It generates a mock access token, refresh token, and user information + with an expiration time set to 1 hour from now. + This is useful for testing or development when the actual refresh endpoint is not available. 
+ """ + mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000) + data = { + "access_token": "asdf Mock access token", + "refresh_token": "asdf Mock refresh token", + "session": {"exp": mock_exp}, + "userinfo": { + "sub": "Mock email", + "familyName": "Mock name", + "givenName": "Mock name", + "fullName": "Mock name", + "userId": "Mock User ID", + "email": "test_email@danswer.ai", + }, + } + return data + + +@basic_router.get("/refresh-token") +async def refresh_access_token( + user: User = Depends(current_user), + user_manager: UserManager = Depends(get_user_manager), +) -> None: + # return + if CUSTOM_REFRESH_URL is None: + logger.error( + "Custom refresh URL is not set and client is attempting to custom refresh" + ) + raise HTTPException( + status_code=500, + detail="Custom refresh URL is not set", + ) + + try: + async with httpx.AsyncClient() as client: + logger.debug(f"Sending request to custom refresh URL for user {user.id}") + access_token = user.oauth_accounts[0].access_token + + response = await client.get( + CUSTOM_REFRESH_URL, + params={"info": "json", "access_token_refresh_interval": 3600}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + response.raise_for_status() + data = response.json() + + # NOTE: Here is where we can mock the response + # data = mocked_refresh_token() + + logger.debug(f"Received response from Meechum auth URL for user {user.id}") + + # Extract new tokens + new_access_token = data["access_token"] + new_refresh_token = data["refresh_token"] + + new_expiry = datetime.fromtimestamp( + data["session"]["exp"] / 1000, tz=timezone.utc + ) + expires_at_timestamp = int(new_expiry.timestamp()) + + logger.debug(f"Access token has been refreshed for user {user.id}") + + await user_manager.oauth_callback( + oauth_name="custom", + access_token=new_access_token, + account_id=data["userinfo"]["userId"], + account_email=data["userinfo"]["email"], + expires_at=expires_at_timestamp, + refresh_token=new_refresh_token, + associate_by_email=True, + ) + + logger.info(f"Successfully refreshed tokens for user {user.id}") + + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + logger.warning(f"Full authentication required for user {user.id}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Full authentication required", + ) + logger.error( + f"HTTP error occurred while refreshing token for user {user.id}: {str(e)}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to refresh token", + ) + except Exception as e: + logger.error( + f"Unexpected error occurred while refreshing token for user {user.id}: {str(e)}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred", + ) + @admin_router.put("") def put_settings( diff --git a/backend/ee/danswer/server/enterprise_settings/models.py b/backend/ee/danswer/server/enterprise_settings/models.py index c9831d87aeb..c770fbd73e7 100644 --- a/backend/ee/danswer/server/enterprise_settings/models.py +++ b/backend/ee/danswer/server/enterprise_settings/models.py @@ -1,4 +1,13 @@ +from typing import List + from pydantic import BaseModel +from pydantic import Field + + +class NavigationItem(BaseModel): + link: str + icon: str + title: str class EnterpriseSettings(BaseModel): @@ -10,11 +19,16 @@ class EnterpriseSettings(BaseModel): use_custom_logo: bool = False use_custom_logotype: bool = False + # custom navigation + custom_nav_items: List[NavigationItem] = 
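For reference, a self-contained sketch of the expiry handling used by `refresh_access_token` (and by `mocked_refresh_token`): `session.exp` arrives in milliseconds and is converted to an integer seconds timestamp before being handed to `user_manager.oauth_callback`. The sample value is illustrative.

```python
from datetime import datetime, timezone

data = {"session": {"exp": 1735689600000}}  # illustrative: 2025-01-01T00:00:00Z in ms

# Milliseconds -> UTC datetime -> integer seconds timestamp, mirroring the
# conversion performed in refresh_access_token().
new_expiry = datetime.fromtimestamp(data["session"]["exp"] / 1000, tz=timezone.utc)
expires_at_timestamp = int(new_expiry.timestamp())

print(new_expiry.isoformat())  # 2025-01-01T00:00:00+00:00
print(expires_at_timestamp)    # 1735689600
```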
Field(default_factory=list) + # custom Chat components + two_lines_for_chat_header: bool | None = None custom_lower_disclaimer_content: str | None = None custom_header_content: str | None = None custom_popup_header: str | None = None custom_popup_content: str | None = None + enable_consent_screen: bool | None = None def check_validity(self) -> None: return diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index 0d5d1987f34..55561982325 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -7,7 +7,10 @@ from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain +from danswer.chat.models import AllCitations from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import FinalUsedContextDocsResponse +from danswer.chat.models import LlmDoc from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError @@ -41,7 +44,7 @@ router = APIRouter(prefix="/chat") -def translate_doc_response_to_simple_doc( +def _translate_doc_response_to_simple_doc( doc_response: QADocsResponse, ) -> list[SimpleDoc]: return [ @@ -60,6 +63,23 @@ def translate_doc_response_to_simple_doc( ] +def _get_final_context_doc_indices( + final_context_docs: list[LlmDoc] | None, + simple_search_docs: list[SimpleDoc] | None, +) -> list[int] | None: + """ + this function returns a list of indices of the simple search docs + that were actually fed to the LLM. + """ + if final_context_docs is None or simple_search_docs is None: + return None + + final_context_doc_ids = {doc.document_id for doc in final_context_docs} + return [ + i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids + ] + + def remove_answer_citations(answer: str) -> str: pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)" @@ -120,17 +140,29 @@ def handle_simplified_chat_message( ) response = ChatBasicResponse() + final_context_docs: list[LlmDoc] = [] answer = "" for packet in packets: if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: answer += packet.answer_piece elif isinstance(packet, QADocsResponse): - response.simple_search_docs = translate_doc_response_to_simple_doc(packet) + response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): response.message_id = packet.message_id + elif isinstance(packet, FinalUsedContextDocsResponse): + final_context_docs = packet.final_context_docs + elif isinstance(packet, AllCitations): + response.cited_documents = { + citation.citation_num: citation.document_id + for citation in packet.citations + } + + response.final_context_doc_indices = _get_final_context_doc_indices( + final_context_docs, response.simple_search_docs + ) response.answer = answer if answer: @@ -152,6 +184,8 @@ def handle_send_message_simple_with_history( if len(req.messages) == 0: raise HTTPException(status_code=400, detail="Messages cannot be zero length") + # This is a sanity check to make sure the chat history is valid + # It must start with a user message and alternate between user and assistant expected_role = MessageType.USER for msg in req.messages: if not msg.message: @@ -225,14 +259,22 @@ def handle_send_message_simple_with_history( history_str=history_str, ) + if 
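For reference, a small worked example of what `final_context_doc_indices` means for API consumers. The stand-in dataclasses below are hypothetical, carrying only the fields the helper reads (`LlmDoc.document_id` and `SimpleDoc.id`); the logic mirrors `_get_final_context_doc_indices`.

```python
from dataclasses import dataclass


@dataclass
class FakeLlmDoc:  # stand-in for LlmDoc
    document_id: str


@dataclass
class FakeSimpleDoc:  # stand-in for SimpleDoc
    id: str


final_context_docs = [FakeLlmDoc(document_id="doc-b")]
simple_search_docs = [
    FakeSimpleDoc(id="doc-a"),
    FakeSimpleDoc(id="doc-b"),
    FakeSimpleDoc(id="doc-c"),
]

# Which retrieved docs were actually fed to the LLM, as indices into
# simple_search_docs -- the same set logic as _get_final_context_doc_indices().
final_ids = {doc.document_id for doc in final_context_docs}
indices = [i for i, doc in enumerate(simple_search_docs) if doc.id in final_ids]
assert indices == [1]
```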
req.retrieval_options is None and req.search_doc_ids is None: + retrieval_options: RetrievalDetails | None = RetrievalDetails( + run_search=OptionalSearchSetting.ALWAYS, + real_time=False, + ) + else: + retrieval_options = req.retrieval_options + full_chat_msg_info = CreateChatMessageRequest( chat_session_id=chat_session.id, parent_message_id=chat_message.id, message=query, file_descriptors=[], prompt_id=req.prompt_id, - search_doc_ids=None, - retrieval_options=req.retrieval_options, + search_doc_ids=req.search_doc_ids, + retrieval_options=retrieval_options, query_override=rephrased_query, chunks_above=0, chunks_below=0, @@ -246,19 +288,31 @@ def handle_send_message_simple_with_history( ) response = ChatBasicResponse() + final_context_docs: list[LlmDoc] = [] answer = "" for packet in packets: if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: answer += packet.answer_piece elif isinstance(packet, QADocsResponse): - response.simple_search_docs = translate_doc_response_to_simple_doc(packet) + response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): response.message_id = packet.message_id elif isinstance(packet, LLMRelevanceFilterResponse): - response.llm_chunks_indices = packet.relevant_chunk_indices + response.llm_selected_doc_indices = packet.llm_selected_doc_indices + elif isinstance(packet, FinalUsedContextDocsResponse): + final_context_docs = packet.final_context_docs + elif isinstance(packet, AllCitations): + response.cited_documents = { + citation.citation_num: citation.document_id + for citation in packet.citations + } + + response.final_context_doc_indices = _get_final_context_doc_indices( + final_context_docs, response.simple_search_docs + ) response.answer = answer if answer: diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py index b0ce553ebe0..b1ea648c8f0 100644 --- a/backend/ee/danswer/server/query_and_chat/models.py +++ b/backend/ee/danswer/server/query_and_chat/models.py @@ -52,9 +52,11 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext): messages: list[ThreadMessage] prompt_id: int | None persona_id: int - retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails) + retrieval_options: RetrievalDetails | None = None query_override: str | None = None skip_rerank: bool | None = None + # If search_doc_ids provided, then retrieval options are unused + search_doc_ids: list[int] | None = None class SimpleDoc(BaseModel): @@ -74,4 +76,7 @@ class ChatBasicResponse(BaseModel): simple_search_docs: list[SimpleDoc] | None = None error_msg: str | None = None message_id: int | None = None - llm_chunks_indices: list[int] | None = None + llm_selected_doc_indices: list[int] | None = None + final_context_doc_indices: list[int] | None = None + # this is a map of the citation number to the document id + cited_documents: dict[int, str] | None = None diff --git a/backend/ee/danswer/server/saml.py b/backend/ee/danswer/server/saml.py index 5bc62e98d61..38966c15756 100644 --- a/backend/ee/danswer/server/saml.py +++ b/backend/ee/danswer/server/saml.py @@ -65,6 +65,7 @@ async def upsert_saml_user(email: str) -> User: password=hashed_pass, is_verified=True, role=role, + has_web_login=True, ) ) diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index bbca5acc20a..10dc1afb972 100644 --- a/backend/ee/danswer/server/seeding.py 
+++ b/backend/ee/danswer/server/seeding.py @@ -51,10 +51,12 @@ def _seed_llms( if llm_upsert_requests: logger.notice("Seeding LLMs") seeded_providers = [ - upsert_llm_provider(db_session, llm_upsert_request) + upsert_llm_provider(llm_upsert_request, db_session) for llm_upsert_request in llm_upsert_requests ] - update_default_provider(db_session, seeded_providers[0].id) + update_default_provider( + provider_id=seeded_providers[0].id, db_session=db_session + ) def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None: diff --git a/backend/ee/danswer/server/user_group/api.py b/backend/ee/danswer/server/user_group/api.py index e18487d5491..355e59fff1d 100644 --- a/backend/ee/danswer/server/user_group/api.py +++ b/backend/ee/danswer/server/user_group/api.py @@ -9,6 +9,7 @@ from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.models import UserRole +from danswer.utils.logger import setup_logger from ee.danswer.db.user_group import fetch_user_groups from ee.danswer.db.user_group import fetch_user_groups_for_user from ee.danswer.db.user_group import insert_user_group @@ -20,6 +21,8 @@ from ee.danswer.server.user_group.models import UserGroupCreate from ee.danswer.server.user_group.models import UserGroupUpdate +logger = setup_logger() + router = APIRouter(prefix="/manage") @@ -29,7 +32,7 @@ def list_user_groups( db_session: Session = Depends(get_session), ) -> list[UserGroup]: if user is None or user.role == UserRole.ADMIN: - user_groups = fetch_user_groups(db_session, only_current=False) + user_groups = fetch_user_groups(db_session, only_up_to_date=False) else: user_groups = fetch_user_groups_for_user( db_session=db_session, @@ -90,6 +93,7 @@ def set_user_curator( set_curator_request=set_curator_request, ) except ValueError as e: + logger.error(f"Error setting user curator: {e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/backend/ee/danswer/user_groups/sync.py b/backend/ee/danswer/user_groups/sync.py deleted file mode 100644 index e3bea192670..00000000000 --- a/backend/ee/danswer/user_groups/sync.py +++ /dev/null @@ -1,87 +0,0 @@ -from sqlalchemy.orm import Session - -from danswer.access.access import get_access_for_documents -from danswer.db.document import prepare_to_modify_documents -from danswer.db.search_settings import get_current_search_settings -from danswer.db.search_settings import get_secondary_search_settings -from danswer.document_index.factory import get_default_document_index -from danswer.document_index.interfaces import DocumentIndex -from danswer.document_index.interfaces import UpdateRequest -from danswer.utils.logger import setup_logger -from ee.danswer.db.user_group import delete_user_group -from ee.danswer.db.user_group import fetch_documents_for_user_group_paginated -from ee.danswer.db.user_group import fetch_user_group -from ee.danswer.db.user_group import mark_user_group_as_synced - -logger = setup_logger() - -_SYNC_BATCH_SIZE = 100 - - -def _sync_user_group_batch( - document_ids: list[str], document_index: DocumentIndex, db_session: Session -) -> None: - logger.debug(f"Syncing document sets for: {document_ids}") - - # Acquires a lock on the documents so that no other process can modify them - with prepare_to_modify_documents(db_session=db_session, document_ids=document_ids): - # get current state of document sets for these documents - document_id_to_access = get_access_for_documents( - document_ids=document_ids, db_session=db_session - ) - - # update Vespa - document_index.update( - 
update_requests=[ - UpdateRequest( - document_ids=[document_id], - access=document_id_to_access[document_id], - ) - for document_id in document_ids - ] - ) - - # Finish the transaction and release the locks - db_session.commit() - - -def sync_user_groups(user_group_id: int, db_session: Session) -> None: - """Sync the status of Postgres for the specified user group""" - search_settings = get_current_search_settings(db_session) - secondary_search_settings = get_secondary_search_settings(db_session) - - document_index = get_default_document_index( - primary_index_name=search_settings.index_name, - secondary_index_name=secondary_search_settings.index_name - if secondary_search_settings - else None, - ) - - user_group = fetch_user_group(db_session=db_session, user_group_id=user_group_id) - if user_group is None: - raise ValueError(f"User group '{user_group_id}' does not exist") - - cursor = None - while True: - # NOTE: this may miss some documents, but that is okay. Any new documents added - # will be added with the correct group membership - document_batch, cursor = fetch_documents_for_user_group_paginated( - db_session=db_session, - user_group_id=user_group_id, - last_document_id=cursor, - limit=_SYNC_BATCH_SIZE, - ) - - _sync_user_group_batch( - document_ids=[document.id for document in document_batch], - document_index=document_index, - db_session=db_session, - ) - - if cursor is None: - break - - if user_group.is_up_for_deletion: - delete_user_group(db_session=db_session, user_group=user_group) - else: - mark_user_group_as_synced(db_session=db_session, user_group=user_group) diff --git a/backend/model_server/custom_models.py b/backend/model_server/custom_models.py index 38bf4b077fa..fde3c8d0dc9 100644 --- a/backend/model_server/custom_models.py +++ b/backend/model_server/custom_models.py @@ -3,15 +3,21 @@ from fastapi import APIRouter from huggingface_hub import snapshot_download # type: ignore from transformers import AutoTokenizer # type: ignore -from transformers import BatchEncoding +from transformers import BatchEncoding # type: ignore +from transformers import PreTrainedTokenizer # type: ignore from danswer.utils.logger import setup_logger from model_server.constants import MODEL_WARM_UP_STRING +from model_server.danswer_torch_model import ConnectorClassifier from model_server.danswer_torch_model import HybridClassifier from model_server.utils import simple_log_function_time +from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO +from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG from shared_configs.configs import INDEXING_ONLY from shared_configs.configs import INTENT_MODEL_TAG from shared_configs.configs import INTENT_MODEL_VERSION +from shared_configs.model_server_models import ConnectorClassificationRequest +from shared_configs.model_server_models import ConnectorClassificationResponse from shared_configs.model_server_models import IntentRequest from shared_configs.model_server_models import IntentResponse @@ -19,10 +25,55 @@ router = APIRouter(prefix="/custom") +_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None +_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None + _INTENT_TOKENIZER: AutoTokenizer | None = None _INTENT_MODEL: HybridClassifier | None = None +def get_connector_classifier_tokenizer() -> AutoTokenizer: + global _CONNECTOR_CLASSIFIER_TOKENIZER + if _CONNECTOR_CLASSIFIER_TOKENIZER is None: + # The tokenizer details are not uploaded to the HF hub since it's just the + # unmodified distilbert tokenizer. 
+ _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained( + "distilbert-base-uncased" + ) + return _CONNECTOR_CLASSIFIER_TOKENIZER + + +def get_local_connector_classifier( + model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO, + tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG, +) -> ConnectorClassifier: + global _CONNECTOR_CLASSIFIER_MODEL + if _CONNECTOR_CLASSIFIER_MODEL is None: + try: + # Calculate where the cache should be, then load from local if available + local_path = snapshot_download( + repo_id=model_name_or_path, revision=tag, local_files_only=True + ) + _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained( + local_path + ) + except Exception as e: + logger.warning(f"Failed to load model directly: {e}") + try: + # Attempt to download the model snapshot + logger.info(f"Downloading model snapshot for {model_name_or_path}") + local_path = snapshot_download(repo_id=model_name_or_path, revision=tag) + _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained( + local_path + ) + except Exception as e: + logger.error( + f"Failed to load model even after attempted snapshot download: {e}" + ) + raise + return _CONNECTOR_CLASSIFIER_MODEL + + def get_intent_model_tokenizer() -> AutoTokenizer: global _INTENT_TOKENIZER if _INTENT_TOKENIZER is None: @@ -61,6 +112,74 @@ def get_local_intent_model( return _INTENT_MODEL +def tokenize_connector_classification_query( + connectors: list[str], + query: str, + tokenizer: PreTrainedTokenizer, + connector_token_end_id: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models + + The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end + token and then the user query. 
+ """ + + input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long) + + for connector in connectors: + connector_token_ids = tokenizer( + connector, + add_special_tokens=False, + return_tensors="pt", + ) + + input_ids = torch.cat( + ( + input_ids, + connector_token_ids["input_ids"].squeeze(dim=0), + torch.tensor([connector_token_end_id], dtype=torch.long), + ), + dim=-1, + ) + query_token_ids = tokenizer( + query, + add_special_tokens=False, + return_tensors="pt", + ) + + input_ids = torch.cat( + ( + input_ids, + query_token_ids["input_ids"].squeeze(dim=0), + torch.tensor([tokenizer.sep_token_id], dtype=torch.long), + ), + dim=-1, + ) + attention_mask = torch.ones(input_ids.numel(), dtype=torch.long) + + return input_ids.unsqueeze(0), attention_mask.unsqueeze(0) + + +def warm_up_connector_classifier_model() -> None: + logger.info( + f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}" + ) + connector_classifier_tokenizer = get_connector_classifier_tokenizer() + connector_classifier = get_local_connector_classifier() + + input_ids, attention_mask = tokenize_connector_classification_query( + ["GitHub"], + "danswer classifier query google doc", + connector_classifier_tokenizer, + connector_classifier.connector_end_token_id, + ) + input_ids = input_ids.to(connector_classifier.device) + attention_mask = attention_mask.to(connector_classifier.device) + + connector_classifier(input_ids, attention_mask) + + def warm_up_intent_model() -> None: logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}") intent_tokenizer = get_intent_model_tokenizer() @@ -157,6 +276,35 @@ def clean_keywords(keywords: list[str]) -> list[str]: return cleaned_words +def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]: + tokenizer = get_connector_classifier_tokenizer() + model = get_local_connector_classifier() + + connector_names = req.available_connectors + + input_ids, attention_mask = tokenize_connector_classification_query( + connector_names, + req.query, + tokenizer, + model.connector_end_token_id, + ) + input_ids = input_ids.to(model.device) + attention_mask = attention_mask.to(model.device) + + global_confidence, classifier_confidence = model(input_ids, attention_mask) + + if global_confidence.item() < 0.5: + return [] + + passed_connectors = [] + + for i, connector_name in enumerate(connector_names): + if classifier_confidence.view(-1)[i].item() > 0.5: + passed_connectors.append(connector_name) + + return passed_connectors + + def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]: tokenizer = get_intent_model_tokenizer() model_input = tokenizer( @@ -189,6 +337,22 @@ def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]: return is_keyword_sequence, cleaned_keywords +@router.post("/connector-classification") +async def process_connector_classification_request( + classification_request: ConnectorClassificationRequest, +) -> ConnectorClassificationResponse: + if INDEXING_ONLY: + raise RuntimeError( + "Indexing model server should not call connector classification endpoint" + ) + + if len(classification_request.available_connectors) == 0: + return ConnectorClassificationResponse(connectors=[]) + + connectors = run_connector_classification(classification_request) + return ConnectorClassificationResponse(connectors=connectors) + + @router.post("/query-analysis") async def process_analysis_request( intent_request: IntentRequest, diff --git a/backend/model_server/danswer_torch_model.py 
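For reference, a hedged sketch of calling the new `/custom/connector-classification` route directly; the model-server host and port are assumptions, and the connector names and query are made up.

```python
import requests

MODEL_SERVER_URL = "http://localhost:9000"  # assumed model-server address

response = requests.post(
    f"{MODEL_SERVER_URL}/custom/connector-classification",
    json={
        "available_connectors": ["github", "confluence", "jira"],
        "query": "where is the design doc for the new login flow?",
    },
)
response.raise_for_status()

# Connectors are returned only when the global confidence exceeds 0.5, and each
# connector must also clear the 0.5 per-connector threshold.
print(response.json()["connectors"])
```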
b/backend/model_server/danswer_torch_model.py index 28554a4fd2d..7390a97e049 100644 --- a/backend/model_server/danswer_torch_model.py +++ b/backend/model_server/danswer_torch_model.py @@ -4,7 +4,8 @@ import torch import torch.nn as nn from transformers import DistilBertConfig # type: ignore -from transformers import DistilBertModel +from transformers import DistilBertModel # type: ignore +from transformers import DistilBertTokenizer # type: ignore class HybridClassifier(nn.Module): @@ -21,7 +22,6 @@ def __init__(self) -> None: self.distilbert.config.dim, self.distilbert.config.dim ) self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2) - self.dropout = nn.Dropout(self.distilbert.config.seq_classif_dropout) self.device = torch.device("cpu") @@ -36,8 +36,7 @@ def forward( # Intent classification on the CLS token cls_token_state = sequence_output[:, 0, :] pre_classifier_out = self.pre_classifier(cls_token_state) - dropout_out = self.dropout(pre_classifier_out) - intent_logits = self.intent_classifier(dropout_out) + intent_logits = self.intent_classifier(pre_classifier_out) # Keyword classification on all tokens token_logits = self.keyword_classifier(sequence_output) @@ -72,3 +71,70 @@ def from_pretrained(cls, load_directory: str) -> "HybridClassifier": param.requires_grad = False return model + + +class ConnectorClassifier(nn.Module): + def __init__(self, config: DistilBertConfig) -> None: + super().__init__() + + self.config = config + self.distilbert = DistilBertModel(config) + self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1) + self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1) + self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") + + # Token indicating end of connector name, and on which classifier is used + self.connector_end_token_id = self.tokenizer.get_vocab()[ + self.config.connector_end_token + ] + + self.device = torch.device("cpu") + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states = self.distilbert( + input_ids=input_ids, attention_mask=attention_mask + ).last_hidden_state + + cls_hidden_states = hidden_states[ + :, 0, : + ] # Take leap of faith that first token is always [CLS] + global_logits = self.connector_global_classifier(cls_hidden_states).view(-1) + global_confidence = torch.sigmoid(global_logits).view(-1) + + connector_end_position_ids = input_ids == self.connector_end_token_id + connector_end_hidden_states = hidden_states[connector_end_position_ids] + classifier_output = self.connector_match_classifier(connector_end_hidden_states) + classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1) + + return global_confidence, classifier_confidence + + @classmethod + def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier": + config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")) + device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("mps") + if torch.backends.mps.is_available() + else torch.device("cpu") + ) + state_dict = torch.load( + os.path.join(repo_dir, "pytorch_model.pt"), + map_location=device, + weights_only=True, + ) + + model = cls(config) + model.load_state_dict(state_dict) + model.to(device) + model.device = device + model.eval() + + for param in model.parameters(): + param.requires_grad = False + + return model diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 
4e97bd00f27..860151b3dc4 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -2,6 +2,7 @@ from typing import Any from typing import Optional +import httpx import openai import vertexai # type: ignore import voyageai # type: ignore @@ -83,7 +84,7 @@ def __init__( self.client = _initialize_client(api_key, self.provider, model) def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: - if model is None: + if not model: model = DEFAULT_OPENAI_MODEL # OpenAI does not seem to provide truncation option, however @@ -110,7 +111,7 @@ def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: def _embed_cohere( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: - if model is None: + if not model: model = DEFAULT_COHERE_MODEL final_embeddings: list[Embedding] = [] @@ -129,7 +130,7 @@ def _embed_cohere( def _embed_voyage( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: - if model is None: + if not model: model = DEFAULT_VOYAGE_MODEL # Similar to Cohere, the API server will do approximate size chunking @@ -145,7 +146,7 @@ def _embed_voyage( def _embed_vertex( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: - if model is None: + if not model: model = DEFAULT_VERTEX_MODEL embeddings = self.client.get_embeddings( @@ -171,7 +172,6 @@ def embed( try: if self.provider == EmbeddingProvider.OPENAI: return self._embed_openai(texts, model_name) - embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) if self.provider == EmbeddingProvider.COHERE: return self._embed_cohere(texts, model_name, embedding_type) @@ -235,6 +235,25 @@ def get_local_reranking_model( return _RERANK_MODEL +def embed_with_litellm_proxy( + texts: list[str], api_url: str, model_name: str, api_key: str | None +) -> list[Embedding]: + headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"} + + with httpx.Client() as client: + response = client.post( + api_url, + json={ + "model": model_name, + "input": texts, + }, + headers=headers, + ) + response.raise_for_status() + result = response.json() + return [embedding["embedding"] for embedding in result["data"]] + + @simple_log_function_time() def embed_text( texts: list[str], @@ -245,21 +264,42 @@ def embed_text( api_key: str | None, provider_type: EmbeddingProvider | None, prefix: str | None, + api_url: str | None, ) -> list[Embedding]: + logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}") + if not all(texts): + logger.error("Empty strings provided for embedding") raise ValueError("Empty strings are not allowed for embedding.") - # Third party API based embedding model if not texts: + logger.error("No texts provided for embedding") raise ValueError("No texts provided for embedding.") + + if provider_type == EmbeddingProvider.LITELLM: + logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}") + if not api_url: + logger.error("API URL not provided for LiteLLM proxy") + raise ValueError("API URL is required for LiteLLM proxy embedding.") + try: + return embed_with_litellm_proxy( + texts=texts, + api_url=api_url, + model_name=model_name or "", + api_key=api_key, + ) + except Exception as e: + logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}") + raise + elif provider_type is not None: - logger.debug(f"Embedding text with provider: {provider_type}") + logger.debug(f"Using cloud provider {provider_type} for embedding") if api_key 
is None: + logger.error("API key not provided for cloud model") raise RuntimeError("API key not provided for cloud model") if prefix: - # This may change in the future if some providers require the user - # to manually append a prefix but this is not the case currently + logger.warning("Prefix provided for cloud model, which is not supported") raise ValueError( "Prefix string is not valid for cloud models. " "Cloud models take an explicit text type instead." @@ -274,14 +314,15 @@ def embed_text( text_type=text_type, ) - # Check for None values in embeddings if any(embedding is None for embedding in embeddings): error_message = "Embeddings contain None values\n" error_message += "Corresponding texts:\n" error_message += "\n".join(texts) + logger.error(error_message) raise ValueError(error_message) elif model_name is not None: + logger.debug(f"Using local model {model_name} for embedding") prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts local_model = get_embedding_model( @@ -296,10 +337,12 @@ def embed_text( ] else: + logger.error("Neither model name nor provider specified for embedding") raise ValueError( "Either model name or provider must be provided to run embeddings." ) + logger.info(f"Successfully embedded {len(texts)} texts") return embeddings @@ -319,6 +362,28 @@ def cohere_rerank( return [result.relevance_score for result in sorted_results] +def litellm_rerank( + query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None +) -> list[float]: + headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"} + with httpx.Client() as client: + response = client.post( + api_url, + json={ + "model": model_name, + "query": query, + "documents": docs, + }, + headers=headers, + ) + response.raise_for_status() + result = response.json() + return [ + item["relevance_score"] + for item in sorted(result["results"], key=lambda x: x["index"]) + ] + + @router.post("/bi-encoder-embed") async def process_embed_request( embed_request: EmbedRequest, @@ -344,6 +409,7 @@ async def process_embed_request( api_key=embed_request.api_key, provider_type=embed_request.provider_type, text_type=embed_request.text_type, + api_url=embed_request.api_url, prefix=prefix, ) return EmbedResponse(embeddings=embeddings) @@ -374,6 +440,20 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons model_name=rerank_request.model_name, ) return RerankResponse(scores=sim_scores) + elif rerank_request.provider_type == RerankerProvider.LITELLM: + if rerank_request.api_url is None: + raise ValueError("API URL is required for LiteLLM reranking.") + + sim_scores = litellm_rerank( + query=rerank_request.query, + docs=rerank_request.documents, + api_url=rerank_request.api_url, + model_name=rerank_request.model_name, + api_key=rerank_request.api_key, + ) + + return RerankResponse(scores=sim_scores) + elif rerank_request.provider_type == RerankerProvider.COHERE: if rerank_request.api_key is None: raise RuntimeError("Cohere Rerank Requires an API Key") diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 37e603f9b46..82a1ee320c9 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -1,10 +1,10 @@ -aiohttp==3.9.4 +aiohttp==3.10.2 alembic==1.10.4 asyncpg==0.27.0 atlassian-python-api==3.37.0 beautifulsoup4==4.12.2 boto3==1.34.84 -celery[redis]==5.3.4 +celery==5.3.4 boto3==1.34.84 chardet==5.2.0 dask==2023.8.1 @@ -14,7 +14,7 @@ fastapi==0.109.2 fastapi-health==0.4.0 fastapi-users==12.1.3 
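For reference, a minimal sketch of calling the new LiteLLM proxy embedding helper directly; the proxy URL, model name, and key are assumptions, and the proxy is expected to expose an OpenAI-compatible embeddings route returning `{"data": [{"embedding": [...]}, ...]}`.

```python
from model_server.encoders import embed_with_litellm_proxy

embeddings = embed_with_litellm_proxy(
    texts=["What is Danswer?", "How do I add a connector?"],
    api_url="http://localhost:4000/embeddings",  # assumed LiteLLM proxy endpoint
    model_name="text-embedding-3-small",         # assumed model name
    api_key=None,                                # or a bearer token for the proxy
)
print(len(embeddings), len(embeddings[0]))
```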
fastapi-users-db-sqlalchemy==5.0.0 -filelock==3.12.0 +filelock==3.15.4 google-api-python-client==2.86.0 google-auth-httplib2==0.1.0 google-auth-oauthlib==1.0.0 @@ -28,13 +28,12 @@ huggingface-hub==0.20.1 jira==3.5.1 jsonref==1.1.0 langchain==0.1.17 -langchain-community==0.0.36 langchain-core==0.1.50 langchain-text-splitters==0.0.1 litellm==1.43.18 llama-index==0.9.45 Mako==1.2.4 -msal==1.26.0 +msal==1.28.0 nltk==3.8.1 Office365-REST-Python-Client==2.5.9 oauthlib==3.2.2 @@ -52,10 +51,11 @@ python-pptx==0.6.23 pypdf==3.17.0 pytest-mock==3.12.0 pytest-playwright==0.3.2 -python-docx==1.1.0 +python-docx==1.1.2 python-dotenv==1.0.0 python-multipart==0.0.7 pywikibot==9.0.0 +redis==5.0.8 requests==2.32.2 requests-oauthlib==1.3.1 retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index 0fb0e74b67b..18c2cefed28 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -8,7 +8,7 @@ pydantic==2.8.2 retry==0.9.2 safetensors==0.4.2 sentence-transformers==2.6.1 -torch==2.0.1 +torch==2.2.0 transformers==4.39.2 uvicorn==0.21.1 voyageai==0.2.3 diff --git a/backend/scripts/force_delete_connector_by_id.py b/backend/scripts/force_delete_connector_by_id.py index 118a4dfa4b4..0a9857304c8 100755 --- a/backend/scripts/force_delete_connector_by_id.py +++ b/backend/scripts/force_delete_connector_by_id.py @@ -83,8 +83,7 @@ def _unsafe_deletion( # Delete index attempts delete_index_attempts( db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, + cc_pair_id=cc_pair.id, ) # Delete document sets diff --git a/backend/scripts/restart_containers.sh b/backend/scripts/restart_containers.sh index c60d1905eb5..838df5b5c79 100755 --- a/backend/scripts/restart_containers.sh +++ b/backend/scripts/restart_containers.sh @@ -1,15 +1,16 @@ #!/bin/bash # Usage of the script with optional volume arguments -# ./restart_containers.sh [vespa_volume] [postgres_volume] +# ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume] VESPA_VOLUME=${1:-""} # Default is empty if not provided POSTGRES_VOLUME=${2:-""} # Default is empty if not provided +REDIS_VOLUME=${3:-""} # Default is empty if not provided # Stop and remove the existing containers echo "Stopping and removing existing containers..." -docker stop danswer_postgres danswer_vespa -docker rm danswer_postgres danswer_vespa +docker stop danswer_postgres danswer_vespa danswer_redis +docker rm danswer_postgres danswer_vespa danswer_redis # Start the PostgreSQL container with optional volume echo "Starting PostgreSQL container..." @@ -27,6 +28,14 @@ else docker run --detach --name danswer_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8 fi +# Start the Redis container with optional volume +echo "Starting Redis container..." 
+if [[ -n "$REDIS_VOLUME" ]]; then + docker run --detach --name danswer_redis --publish 6379:6379 -v $REDIS_VOLUME:/data redis +else + docker run --detach --name danswer_redis --publish 6379:6379 redis +fi + # Ensure alembic runs in the correct directory SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" PARENT_DIR="$(dirname "$SCRIPT_DIR")" diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 5ad36cc93c4..fe933227009 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -16,9 +16,12 @@ ) # Danswer custom Deep Learning Models +CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model" +CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0" INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier" INTENT_MODEL_TAG = "v1.0.3" + # Bi-Encoder, other details DOC_EMBEDDING_CONTEXT_SIZE = 512 @@ -58,9 +61,11 @@ # Fields which should only be set on new search setting PRESERVED_SEARCH_FIELDS = [ + "id", "provider_type", "api_key", "model_name", + "api_url", "index_name", "multipass_indexing", "model_dim", @@ -68,3 +73,5 @@ "passage_prefix", "query_prefix", ] + +CUSTOM_REFRESH_URL = os.environ.get("CUSTOM_REFRESH_URL") or "/settings/refresh-token" diff --git a/backend/shared_configs/enums.py b/backend/shared_configs/enums.py index 918872d44b3..b58ac0a8928 100644 --- a/backend/shared_configs/enums.py +++ b/backend/shared_configs/enums.py @@ -6,10 +6,12 @@ class EmbeddingProvider(str, Enum): COHERE = "cohere" VOYAGE = "voyage" GOOGLE = "google" + LITELLM = "litellm" class RerankerProvider(str, Enum): COHERE = "cohere" + LITELLM = "litellm" class EmbedTextType(str, Enum): diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index 3014616c620..dd846ed6bad 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -7,6 +7,15 @@ Embedding = list[float] +class ConnectorClassificationRequest(BaseModel): + available_connectors: list[str] + query: str + + +class ConnectorClassificationResponse(BaseModel): + connectors: list[str] + + class EmbedRequest(BaseModel): texts: list[str] # Can be none for cloud embedding model requests, error handling logic exists for other cases @@ -18,6 +27,7 @@ class EmbedRequest(BaseModel): text_type: EmbedTextType manual_query_prefix: str | None = None manual_passage_prefix: str | None = None + api_url: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} @@ -33,6 +43,7 @@ class RerankRequest(BaseModel): model_name: str provider_type: RerankerProvider | None = None api_key: str | None = None + api_url: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} diff --git a/backend/supervisord.conf b/backend/supervisord.conf index b56c763b94f..697866b6c0a 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -24,14 +24,21 @@ autorestart=true # relatively compute-light (e.g. 
they tend to just make a bunch of requests to # Vespa / Postgres) [program:celery_worker] -command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --concurrency=6 --loglevel=INFO --logfile=/var/log/celery_worker_supervisor.log +command=celery -A danswer.background.celery.celery_run:celery_app worker + --pool=threads + --concurrency=6 + --loglevel=INFO + --logfile=/var/log/celery_worker_supervisor.log + -Q celery,vespa_metadata_sync environment=LOG_FILE_NAME=celery_worker redirect_stderr=true autorestart=true # Job scheduler for periodic tasks [program:celery_beat] -command=celery -A danswer.background.celery.celery_run:celery_app beat --loglevel=INFO --logfile=/var/log/celery_beat_supervisor.log +command=celery -A danswer.background.celery.celery_run:celery_app beat + --loglevel=INFO + --logfile=/var/log/celery_beat_supervisor.log environment=LOG_FILE_NAME=celery_beat redirect_stderr=true autorestart=true diff --git a/backend/tests/api/test_api.py b/backend/tests/api/test_api.py index 059c40824d5..9a3571ef585 100644 --- a/backend/tests/api/test_api.py +++ b/backend/tests/api/test_api.py @@ -101,4 +101,4 @@ def test_handle_send_message_simple_with_history(client: TestClient) -> None: resp_json = response.json() # persona must have LLM relevance enabled for this to pass - assert len(resp_json["llm_chunks_indices"]) > 0 + assert len(resp_json["llm_selected_doc_indices"]) > 0 diff --git a/backend/tests/daily/connectors/confluence/test_confluence_basic.py b/backend/tests/daily/connectors/confluence/test_confluence_basic.py index 7f05242c50b..4eb25207814 100644 --- a/backend/tests/daily/connectors/confluence/test_confluence_basic.py +++ b/backend/tests/daily/connectors/confluence/test_confluence_basic.py @@ -8,7 +8,13 @@ @pytest.fixture def confluence_connector() -> ConfluenceConnector: - connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"]) + connector = ConfluenceConnector( + wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"], + space=os.environ["CONFLUENCE_TEST_SPACE"], + is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true", + page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""), + ) + connector.load_credentials( { "confluence_username": os.environ["CONFLUENCE_USER_NAME"], diff --git a/backend/tests/daily/embedding/test_embeddings.py b/backend/tests/daily/embedding/test_embeddings.py index a9c12b236cf..b736f374741 100644 --- a/backend/tests/daily/embedding/test_embeddings.py +++ b/backend/tests/daily/embedding/test_embeddings.py @@ -32,6 +32,7 @@ def openai_embedding_model() -> EmbeddingModel: passage_prefix=None, api_key=os.getenv("OPENAI_API_KEY"), provider_type=EmbeddingProvider.OPENAI, + api_url=None, ) @@ -51,6 +52,7 @@ def cohere_embedding_model() -> EmbeddingModel: passage_prefix=None, api_key=os.getenv("COHERE_API_KEY"), provider_type=EmbeddingProvider.COHERE, + api_url=None, ) @@ -70,6 +72,7 @@ def local_nomic_embedding_model() -> EmbeddingModel: passage_prefix="search_document: ", api_key=None, provider_type=None, + api_url=None, ) diff --git a/backend/tests/integration/README.md b/backend/tests/integration/README.md new file mode 100644 index 00000000000..bc5e388082f --- /dev/null +++ b/backend/tests/integration/README.md @@ -0,0 +1,70 @@ +# Integration Tests + +## General Testing Overview +The integration tests are designed with a "manager" class and a "test" class for each type of object being manipulated (e.g., user, persona, credential): +- **Manager Class**: Contains methods for each type of API call. 
Responsible for creating, deleting, and verifying the existence of an entity. +- **Test Class**: Stores data for each entity being tested. This is our "expected state" of the object. + +The idea is that each test can use the manager class to create (.create()) a "test_" object. It can then perform an operation on the object (e.g., send a request to the API) and then check if the "test_" object is in the expected state by using the manager class (.verify()) function. + +## Instructions for Running Integration Tests Locally +1. Launch danswer (using Docker or running with a debugger), ensuring the API server is running on port 8080. + a. If you'd like to set environment variables, you can do so by creating a `.env` file in the danswer/backend/tests/integration/ directory. +2. Navigate to `danswer/backend`. +3. Run the following command in the terminal: + ```sh + pytest -s tests/integration/tests/ + ``` + or to run all tests in a file: + ```sh + pytest -s tests/integration/tests/path_to/test_file.py + ``` + or to run a single test: + ```sh + pytest -s tests/integration/tests/path_to/test_file.py::test_function_name + ``` + +## Guidelines for Writing Integration Tests +- As authentication is currently required for all tests, each test should start by creating a user. +- Each test should ideally focus on a single API flow. +- The test writer should try to consider failure cases and edge cases for the flow and write the tests to check for these cases. +- Every step of the test should be commented describing what is being done and what the expected behavior is. +- A summary of the test should be given at the top of the test function as well! +- When writing new tests, manager classes, manager functions, and test classes, try to copy the style of the other ones that have already been written. +- Be careful for scope creep! + - No need to overcomplicate every test by verifying after every single API call so long as the case you would be verifying is covered elsewhere (ideally in a test focused on covering that case). + - An example of this is: Creating an admin user is done at the beginning of nearly every test, but we only need to verify that the user is actually an admin in the test focused on checking admin permissions. For every other test, we can just create the admin user and assume that the permissions are working as expected. + +## Current Testing Limitations +### Test coverage +- All tests are probably not as high coverage as they could be. +- The "connector" tests in particular are super bare bones because we will be reworking connector/cc_pair sometime soon. +- Global Curator role is not thoroughly tested. +- No auth is not tested at all. +### Failure checking +- While we test expected auth failures, we only check that it failed at all. +- We dont check that the return codes are what we expect. +- This means that a test could be failing for a different reason than expected. +- We should ensure that the proper codes are being returned for each failure case. +- We should also query the db after each failure to ensure that the db is in the expected state. +### Scope/focus +- The tests may be scoped sub-optimally. +- The scoping of each test may be overlapping. + +## Current Testing Coverage +The current testing coverage should be checked by reading the comments at the top of each test file. 
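To make the manager/test-class pattern above concrete, here is a minimal sketch. `LLMProviderManager` and `TestUser` come from `common_utils` (see the changes later in this diff); the `admin_user` fixture and the `verify()` call are assumptions that follow the described `.create()`/`.verify()` convention rather than guaranteed APIs.

```python
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.test_models import TestUser


def test_admin_can_create_llm_provider(admin_user: TestUser) -> None:
    """Assumes an `admin_user` fixture/helper that creates the first user
    (who becomes the admin), per the guidelines above."""
    # Create the "test_" object through its manager class...
    test_llm_provider = LLMProviderManager.create(user_performing_action=admin_user)

    # ...then confirm the backend state matches the expected state. verify() is
    # assumed here, following the manager-class convention described above.
    LLMProviderManager.verify(test_llm_provider, user_performing_action=admin_user)
```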
+ + ## TODO: Testing Coverage +- Persona permissions testing +- Read only (and/or basic) user permissions + - Ensuring proper permission enforcement using the chat/doc_search endpoints +- No auth +
+## Ideas for integration testing design +### Combine the "test" and "manager" classes +This could make test writing a bit cleaner by preventing test writers from having to pass around objects into functions that the objects have a 1:1 relationship with. +
+### Rework VespaClient +Right now, it's used as a fixture and has to be passed around between manager classes. +It could just be built where it's used.
diff --git a/backend/tests/integration/common_utils/connectors.py b/backend/tests/integration/common_utils/connectors.py deleted file mode 100644 index e7734cec3c8..00000000000 --- a/backend/tests/integration/common_utils/connectors.py +++ /dev/null @@ -1,114 +0,0 @@ -import uuid -from typing import cast - -import requests -from pydantic import BaseModel - -from danswer.configs.constants import DocumentSource -from danswer.db.enums import ConnectorCredentialPairStatus -from tests.integration.common_utils.constants import API_SERVER_URL - - -class ConnectorCreationDetails(BaseModel): - connector_id: int - credential_id: int - cc_pair_id: int - - -class ConnectorClient: - @staticmethod - def create_connector( - name_prefix: str = "test_connector", credential_id: int | None = None - ) -> ConnectorCreationDetails: - unique_id = uuid.uuid4() - - connector_name = f"{name_prefix}_{unique_id}" - connector_data = { - "name": connector_name, - "source": DocumentSource.NOT_APPLICABLE, - "input_type": "load_state", - "connector_specific_config": {}, - "refresh_freq": 60, - "disabled": True, - } - response = requests.post( - f"{API_SERVER_URL}/manage/admin/connector", - json=connector_data, - ) - response.raise_for_status() - connector_id = response.json()["id"] - - # associate the credential with the connector - if not credential_id: - print("ID not specified, creating new credential") - # Create a new credential - credential_data = { - "credential_json": {}, - "admin_public": True, - "source": DocumentSource.NOT_APPLICABLE, - } - response = requests.post( - f"{API_SERVER_URL}/manage/credential", - json=credential_data, - ) - response.raise_for_status() - credential_id = cast(int, response.json()["id"]) - - cc_pair_metadata = {"name": f"test_cc_pair_{unique_id}", "is_public": True} - response = requests.put( - f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}", - json=cc_pair_metadata, - ) - response.raise_for_status() - - # fetch the conenector credential pair id using the indexing status API - response = requests.get( - f"{API_SERVER_URL}/manage/admin/connector/indexing-status" - ) - response.raise_for_status() - indexing_statuses = response.json() - - cc_pair_id = None - for status in indexing_statuses: - if ( - status["connector"]["id"] == connector_id - and status["credential"]["id"] == credential_id - ): - cc_pair_id = status["cc_pair_id"] - break - - if cc_pair_id is None: - raise ValueError("Could not find the connector credential pair id") - - print( - f"Created connector with connector_id: {connector_id}, credential_id: {credential_id}, cc_pair_id: {cc_pair_id}" - ) - return ConnectorCreationDetails( - connector_id=int(connector_id), - credential_id=int(credential_id), - cc_pair_id=int(cc_pair_id), - ) - - @staticmethod - def update_connector_status( - cc_pair_id: int, status: ConnectorCredentialPairStatus - ) -> None: - response = requests.put( -
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/status", - json={"status": status}, - ) - response.raise_for_status() - - @staticmethod - def delete_connector(connector_id: int, credential_id: int) -> None: - response = requests.post( - f"{API_SERVER_URL}/manage/admin/deletion-attempt", - json={"connector_id": connector_id, "credential_id": credential_id}, - ) - response.raise_for_status() - - @staticmethod - def get_connectors() -> list[dict]: - response = requests.get(f"{API_SERVER_URL}/manage/connector") - response.raise_for_status() - return response.json() diff --git a/backend/tests/integration/common_utils/constants.py b/backend/tests/integration/common_utils/constants.py index efc98dde7de..7d729191cf6 100644 --- a/backend/tests/integration/common_utils/constants.py +++ b/backend/tests/integration/common_utils/constants.py @@ -5,3 +5,7 @@ API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080" API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}" MAX_DELAY = 30 + +GENERAL_HEADERS = {"Content-Type": "application/json"} + +NUM_DOCS = 5 diff --git a/backend/tests/integration/common_utils/document_sets.py b/backend/tests/integration/common_utils/document_sets.py deleted file mode 100644 index dc898611108..00000000000 --- a/backend/tests/integration/common_utils/document_sets.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import cast - -import requests - -from danswer.server.features.document_set.models import DocumentSet -from danswer.server.features.document_set.models import DocumentSetCreationRequest -from tests.integration.common_utils.constants import API_SERVER_URL - - -class DocumentSetClient: - @staticmethod - def create_document_set( - doc_set_creation_request: DocumentSetCreationRequest, - ) -> int: - response = requests.post( - f"{API_SERVER_URL}/manage/admin/document-set", - json=doc_set_creation_request.model_dump(), - ) - response.raise_for_status() - return cast(int, response.json()) - - @staticmethod - def fetch_document_sets() -> list[DocumentSet]: - response = requests.get(f"{API_SERVER_URL}/manage/document-set") - response.raise_for_status() - - document_sets = [ - DocumentSet.parse_obj(doc_set_data) for doc_set_data in response.json() - ] - return document_sets diff --git a/backend/tests/integration/common_utils/llm.py b/backend/tests/integration/common_utils/llm.py index ba8b89d6b4d..f74b40073c9 100644 --- a/backend/tests/integration/common_utils/llm.py +++ b/backend/tests/integration/common_utils/llm.py @@ -1,62 +1,88 @@ import os -from typing import cast +from uuid import uuid4 import requests -from pydantic import BaseModel -from pydantic import PrivateAttr from danswer.server.manage.llm.models import LLMProviderUpsertRequest from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestLLMProvider +from tests.integration.common_utils.test_models import TestUser -class LLMProvider(BaseModel): - provider: str - api_key: str - default_model_name: str - api_base: str | None = None - api_version: str | None = None - is_default: bool = True +class LLMProviderManager: + @staticmethod + def create( + name: str | None = None, + provider: str | None = None, + api_key: str | None = None, + default_model_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + groups: list[int] | None = None, + is_public: bool | None = None, + user_performing_action: TestUser | None = None, + 
) -> TestLLMProvider: + print("Seeding LLM Providers...") - # only populated after creation - _provider_id: int | None = PrivateAttr() - - def create(self) -> int: llm_provider = LLMProviderUpsertRequest( - name=self.provider, - provider=self.provider, - default_model_name=self.default_model_name, - api_key=self.api_key, - api_base=self.api_base, - api_version=self.api_version, + name=name or f"test-provider-{uuid4()}", + provider=provider or "openai", + default_model_name=default_model_name or "gpt-4o-mini", + api_key=api_key or os.environ["OPENAI_API_KEY"], + api_base=api_base, + api_version=api_version, custom_config=None, - fast_default_model_name=None, - is_public=True, - groups=[], + fast_default_model_name=default_model_name or "gpt-4o-mini", + is_public=is_public or True, + groups=groups or [], display_model_names=None, model_names=None, ) - response = requests.put( + llm_response = requests.put( f"{API_SERVER_URL}/admin/llm/provider", - json=llm_provider.dict(), + json=llm_provider.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + llm_response.raise_for_status() + response_data = llm_response.json() + result_llm = TestLLMProvider( + id=response_data["id"], + name=response_data["name"], + provider=response_data["provider"], + api_key=response_data["api_key"], + default_model_name=response_data["default_model_name"], + is_public=response_data["is_public"], + groups=response_data["groups"], + api_base=response_data["api_base"], + api_version=response_data["api_version"], ) - response.raise_for_status() - self._provider_id = cast(int, response.json()["id"]) - return self._provider_id + set_default_response = requests.post( + f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + set_default_response.raise_for_status() - def delete(self) -> None: + return result_llm + + @staticmethod + def delete( + llm_provider: TestLLMProvider, + user_performing_action: TestUser | None = None, + ) -> bool: + if not llm_provider.id: + raise ValueError("LLM Provider ID is required to delete a provider") response = requests.delete( - f"{API_SERVER_URL}/admin/llm/provider/{self._provider_id}" + f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, ) response.raise_for_status() - - -def seed_default_openai_provider() -> LLMProvider: - llm = LLMProvider( - provider="openai", - default_model_name="gpt-4o-mini", - api_key=os.environ["OPENAI_API_KEY"], - ) - llm.create() - return llm + return True diff --git a/backend/tests/integration/common_utils/managers/api_key.py b/backend/tests/integration/common_utils/managers/api_key.py new file mode 100644 index 00000000000..b6d2c29b732 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/api_key.py @@ -0,0 +1,92 @@ +from uuid import uuid4 + +import requests + +from danswer.db.models import UserRole +from ee.danswer.server.api_key.models import APIKeyArgs +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestUser + + +class APIKeyManager: + @staticmethod + def create( + name: str | None = None, + api_key_role: UserRole = UserRole.ADMIN, + user_performing_action: 
TestUser | None = None, + ) -> TestAPIKey: + name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}" + api_key_request = APIKeyArgs( + name=name, + role=api_key_role, + ) + api_key_response = requests.post( + f"{API_SERVER_URL}/admin/api-key", + json=api_key_request.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + api_key_response.raise_for_status() + api_key = api_key_response.json() + result_api_key = TestAPIKey( + api_key_id=api_key["api_key_id"], + api_key_display=api_key["api_key_display"], + api_key=api_key["api_key"], + api_key_name=name, + api_key_role=api_key_role, + user_id=api_key["user_id"], + headers=GENERAL_HEADERS, + ) + result_api_key.headers["Authorization"] = f"Bearer {result_api_key.api_key}" + return result_api_key + + @staticmethod + def delete( + api_key: TestAPIKey, + user_performing_action: TestUser | None = None, + ) -> None: + api_key_response = requests.delete( + f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + api_key_response.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[TestAPIKey]: + api_key_response = requests.get( + f"{API_SERVER_URL}/admin/api-key", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + api_key_response.raise_for_status() + return [TestAPIKey(**api_key) for api_key in api_key_response.json()] + + @staticmethod + def verify( + api_key: TestAPIKey, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + retrieved_keys = APIKeyManager.get_all( + user_performing_action=user_performing_action + ) + for key in retrieved_keys: + if key.api_key_id == api_key.api_key_id: + if verify_deleted: + raise ValueError("API Key found when it should have been deleted") + if ( + key.api_key_name == api_key.api_key_name + and key.api_key_role == api_key.api_key_role + ): + return + + if not verify_deleted: + raise Exception("API Key not found") diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py new file mode 100644 index 00000000000..6498252bbe8 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -0,0 +1,202 @@ +import time +from typing import Any +from uuid import uuid4 + +import requests + +from danswer.connectors.models import InputType +from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.server.documents.models import ConnectorCredentialPairIdentifier +from danswer.server.documents.models import ConnectorIndexingStatus +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.test_models import TestCCPair +from tests.integration.common_utils.test_models import TestUser + + +def _cc_pair_creator( + connector_id: int, + credential_id: int, + name: str | None = None, + is_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, +) -> TestCCPair: + name = 
f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}" + + request = { + "name": name, + "is_public": is_public, + "groups": groups or [], + } + + response = requests.put( + url=f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}", + json=request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return TestCCPair( + id=response.json()["data"], + name=name, + connector_id=connector_id, + credential_id=credential_id, + is_public=is_public, + groups=groups or [], + ) + + +class CCPairManager: + @staticmethod + def create_from_scratch( + name: str | None = None, + is_public: bool = True, + groups: list[int] | None = None, + source: DocumentSource = DocumentSource.FILE, + input_type: InputType = InputType.LOAD_STATE, + connector_specific_config: dict[str, Any] | None = None, + credential_json: dict[str, Any] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestCCPair: + connector = ConnectorManager.create( + name=name, + source=source, + input_type=input_type, + connector_specific_config=connector_specific_config, + is_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + credential = CredentialManager.create( + credential_json=credential_json, + name=name, + source=source, + curator_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + return _cc_pair_creator( + connector_id=connector.id, + credential_id=credential.id, + name=name, + is_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + + @staticmethod + def create( + connector_id: int, + credential_id: int, + name: str | None = None, + is_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestCCPair: + return _cc_pair_creator( + connector_id=connector_id, + credential_id=credential_id, + name=name, + is_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + + @staticmethod + def pause_cc_pair( + cc_pair: TestCCPair, + user_performing_action: TestUser | None = None, + ) -> None: + result = requests.put( + url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status", + json={"status": "PAUSED"}, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + result.raise_for_status() + + @staticmethod + def delete( + cc_pair: TestCCPair, + user_performing_action: TestUser | None = None, + ) -> None: + cc_pair_identifier = ConnectorCredentialPairIdentifier( + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + ) + result = requests.post( + url=f"{API_SERVER_URL}/manage/admin/deletion-attempt", + json=cc_pair_identifier.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + result.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[ConnectorIndexingStatus]: + response = requests.get( + f"{API_SERVER_URL}/manage/admin/connector/indexing-status", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [ConnectorIndexingStatus(**cc_pair) for cc_pair in response.json()] + + @staticmethod + def verify( + cc_pair: TestCCPair, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + all_cc_pairs = 
CCPairManager.get_all(user_performing_action) + for retrieved_cc_pair in all_cc_pairs: + if retrieved_cc_pair.cc_pair_id == cc_pair.id: + if verify_deleted: + # We assume that this check will be performed after the deletion is + # already waited for + raise ValueError( + f"CC pair {cc_pair.id} found but should be deleted" + ) + if ( + retrieved_cc_pair.name == cc_pair.name + and retrieved_cc_pair.connector.id == cc_pair.connector_id + and retrieved_cc_pair.credential.id == cc_pair.credential_id + and retrieved_cc_pair.public_doc == cc_pair.is_public + and set(retrieved_cc_pair.groups) == set(cc_pair.groups) + ): + return + + if not verify_deleted: + raise ValueError(f"CC pair {cc_pair.id} not found") + + @staticmethod + def wait_for_deletion_completion( + user_performing_action: TestUser | None = None, + ) -> None: + start = time.time() + while True: + cc_pairs = CCPairManager.get_all(user_performing_action) + if all( + cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING + for cc_pair in cc_pairs + ): + return + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"CC pairs deletion was not completed within the {MAX_DELAY} seconds" + ) + else: + print("Some CC pairs are still being deleted, waiting...") + time.sleep(2) diff --git a/backend/tests/integration/common_utils/managers/connector.py b/backend/tests/integration/common_utils/managers/connector.py new file mode 100644 index 00000000000..f72d079683b --- /dev/null +++ b/backend/tests/integration/common_utils/managers/connector.py @@ -0,0 +1,124 @@ +from typing import Any +from uuid import uuid4 + +import requests + +from danswer.connectors.models import InputType +from danswer.server.documents.models import ConnectorUpdateRequest +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestConnector +from tests.integration.common_utils.test_models import TestUser + + +class ConnectorManager: + @staticmethod + def create( + name: str | None = None, + source: DocumentSource = DocumentSource.FILE, + input_type: InputType = InputType.LOAD_STATE, + connector_specific_config: dict[str, Any] | None = None, + is_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestConnector: + name = f"{name}-connector" if name else f"test-connector-{uuid4()}" + + connector_update_request = ConnectorUpdateRequest( + name=name, + source=source, + input_type=input_type, + connector_specific_config=connector_specific_config or {}, + is_public=is_public, + groups=groups or [], + ) + + response = requests.post( + url=f"{API_SERVER_URL}/manage/admin/connector", + json=connector_update_request.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + response_data = response.json() + return TestConnector( + id=response_data.get("id"), + name=name, + source=source, + input_type=input_type, + connector_specific_config=connector_specific_config or {}, + groups=groups, + is_public=is_public, + ) + + @staticmethod + def edit( + connector: TestConnector, + user_performing_action: TestUser | None = None, + ) -> None: + response = requests.patch( + url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}", + json=connector.model_dump(exclude={"id"}), + headers=user_performing_action.headers + if 
user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def delete( + connector: TestConnector, + user_performing_action: TestUser | None = None, + ) -> None: + response = requests.delete( + url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[TestConnector]: + response = requests.get( + url=f"{API_SERVER_URL}/manage/connector", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [ + TestConnector( + id=conn.get("id"), + name=conn.get("name", ""), + source=conn.get("source", DocumentSource.FILE), + input_type=conn.get("input_type", InputType.LOAD_STATE), + connector_specific_config=conn.get("connector_specific_config", {}), + ) + for conn in response.json() + ] + + @staticmethod + def get( + connector_id: int, user_performing_action: TestUser | None = None + ) -> TestConnector: + response = requests.get( + url=f"{API_SERVER_URL}/manage/connector/{connector_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + conn = response.json() + return TestConnector( + id=conn.get("id"), + name=conn.get("name", ""), + source=conn.get("source", DocumentSource.FILE), + input_type=conn.get("input_type", InputType.LOAD_STATE), + connector_specific_config=conn.get("connector_specific_config", {}), + ) diff --git a/backend/tests/integration/common_utils/managers/credential.py b/backend/tests/integration/common_utils/managers/credential.py new file mode 100644 index 00000000000..c05cd1b5a3e --- /dev/null +++ b/backend/tests/integration/common_utils/managers/credential.py @@ -0,0 +1,129 @@ +from typing import Any +from uuid import uuid4 + +import requests + +from danswer.server.documents.models import CredentialSnapshot +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestCredential +from tests.integration.common_utils.test_models import TestUser + + +class CredentialManager: + @staticmethod + def create( + credential_json: dict[str, Any] | None = None, + admin_public: bool = True, + name: str | None = None, + source: DocumentSource = DocumentSource.FILE, + curator_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestCredential: + name = f"{name}-credential" if name else f"test-credential-{uuid4()}" + + credential_request = { + "name": name, + "credential_json": credential_json or {}, + "admin_public": admin_public, + "source": source, + "curator_public": curator_public, + "groups": groups or [], + } + response = requests.post( + url=f"{API_SERVER_URL}/manage/credential", + json=credential_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + + response.raise_for_status() + return TestCredential( + id=response.json()["id"], + name=name, + credential_json=credential_json or {}, + admin_public=admin_public, + source=source, + curator_public=curator_public, + groups=groups or [], + ) + + @staticmethod + def edit( + credential: TestCredential, + 
user_performing_action: TestUser | None = None, + ) -> None: + request = credential.model_dump(include={"name", "credential_json"}) + response = requests.put( + url=f"{API_SERVER_URL}/manage/admin/credential/{credential.id}", + json=request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def delete( + credential: TestCredential, + user_performing_action: TestUser | None = None, + ) -> None: + response = requests.delete( + url=f"{API_SERVER_URL}/manage/credential/{credential.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def get( + credential_id: int, user_performing_action: TestUser | None = None + ) -> CredentialSnapshot: + response = requests.get( + url=f"{API_SERVER_URL}/manage/credential/{credential_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return CredentialSnapshot(**response.json()) + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[CredentialSnapshot]: + response = requests.get( + f"{API_SERVER_URL}/manage/credential", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [CredentialSnapshot(**cred) for cred in response.json()] + + @staticmethod + def verify( + credential: TestCredential, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + all_credentials = CredentialManager.get_all(user_performing_action) + for fetched_credential in all_credentials: + if credential.id == fetched_credential.id: + if verify_deleted: + raise ValueError( + f"Credential {credential.id} found but should be deleted" + ) + if ( + credential.name == fetched_credential.name + and credential.admin_public == fetched_credential.admin_public + and credential.source == fetched_credential.source + and credential.curator_public == fetched_credential.curator_public + ): + return + if not verify_deleted: + raise ValueError(f"Credential {credential.id} not found") diff --git a/backend/tests/integration/common_utils/managers/document.py b/backend/tests/integration/common_utils/managers/document.py new file mode 100644 index 00000000000..3f691eca8f9 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/document.py @@ -0,0 +1,153 @@ +from uuid import uuid4 + +import requests + +from danswer.configs.constants import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.managers.api_key import TestAPIKey +from tests.integration.common_utils.managers.cc_pair import TestCCPair +from tests.integration.common_utils.test_models import SimpleTestDocument +from tests.integration.common_utils.test_models import TestUser +from tests.integration.common_utils.vespa import TestVespaClient + + +def _verify_document_permissions( + retrieved_doc: dict, + cc_pair: TestCCPair, + doc_set_names: list[str] | None = None, + group_names: list[str] | None = None, + doc_creating_user: TestUser | None = None, +) -> None: + acl_keys = set(retrieved_doc["access_control_list"].keys()) + print(f"ACL keys: {acl_keys}") + if cc_pair.is_public: + if "PUBLIC" not in acl_keys: + raise 
ValueError( + f"Document {retrieved_doc['document_id']} is public but" + " does not have the PUBLIC ACL key" + ) + + if doc_creating_user is not None: + if f"user_id:{doc_creating_user.id}" not in acl_keys: + raise ValueError( + f"Document {retrieved_doc['document_id']} was created by user" + f" {doc_creating_user.id} but does not have the user_id:{doc_creating_user.id} ACL key" + ) + + if group_names is not None: + expected_group_keys = {f"group:{group_name}" for group_name in group_names} + found_group_keys = {key for key in acl_keys if key.startswith("group:")} + if found_group_keys != expected_group_keys: + raise ValueError( + f"Document {retrieved_doc['document_id']} has incorrect group ACL keys. Found: {found_group_keys}, \n" + f"Expected: {expected_group_keys}" + ) + + if doc_set_names is not None: + found_doc_set_names = set(retrieved_doc.get("document_sets", {}).keys()) + if found_doc_set_names != set(doc_set_names): + raise ValueError( + f"Document set names mismatch. \nFound: {found_doc_set_names}, \n" + f"Expected: {set(doc_set_names)}" + ) + + +def _generate_dummy_document(document_id: str, cc_pair_id: int) -> dict: + return { + "document": { + "id": document_id, + "sections": [ + { + "text": f"This is test document {document_id}", + "link": f"{document_id}", + } + ], + "source": DocumentSource.NOT_APPLICABLE, + # just for testing metadata + "metadata": {"document_id": document_id}, + "semantic_identifier": f"Test Document {document_id}", + "from_ingestion_api": True, + }, + "cc_pair_id": cc_pair_id, + } + + +class DocumentManager: + @staticmethod + def seed_and_attach_docs( + cc_pair: TestCCPair, + num_docs: int = NUM_DOCS, + document_ids: list[str] | None = None, + api_key: TestAPIKey | None = None, + ) -> TestCCPair: + # Use provided document_ids if available, otherwise generate random UUIDs + if document_ids is None: + document_ids = [f"test-doc-{uuid4()}" for _ in range(num_docs)] + else: + num_docs = len(document_ids) + # Create and ingest some documents + documents: list[dict] = [] + for document_id in document_ids: + document = _generate_dummy_document(document_id, cc_pair.id) + documents.append(document) + response = requests.post( + f"{API_SERVER_URL}/danswer-api/ingestion", + json=document, + headers=api_key.headers if api_key else GENERAL_HEADERS, + ) + response.raise_for_status() + + print("Seeding completed successfully.") + cc_pair.documents = [ + SimpleTestDocument( + id=document["document"]["id"], + content=document["document"]["sections"][0]["text"], + ) + for document in documents + ] + return cc_pair + + @staticmethod + def verify( + vespa_client: TestVespaClient, + cc_pair: TestCCPair, + # If None, will not check doc sets or groups + # If empty list, will check for empty doc sets or groups + doc_set_names: list[str] | None = None, + group_names: list[str] | None = None, + doc_creating_user: TestUser | None = None, + verify_deleted: bool = False, + ) -> None: + doc_ids = [document.id for document in cc_pair.documents] + retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"] + retrieved_docs = { + doc["fields"]["document_id"]: doc["fields"] for doc in retrieved_docs_dict + } + # Left this here for debugging purposes. 
+ # import json + # for doc in retrieved_docs.values(): + # printable_doc = doc.copy() + # print(printable_doc.keys()) + # printable_doc.pop("embeddings") + # printable_doc.pop("title_embedding") + # print(json.dumps(printable_doc, indent=2)) + + for document in cc_pair.documents: + retrieved_doc = retrieved_docs.get(document.id) + if not retrieved_doc: + if not verify_deleted: + raise ValueError(f"Document not found: {document.id}") + continue + if verify_deleted: + raise ValueError( + f"Document found when it should be deleted: {document.id}" + ) + _verify_document_permissions( + retrieved_doc, + cc_pair, + doc_set_names, + group_names, + doc_creating_user, + ) diff --git a/backend/tests/integration/common_utils/managers/document_set.py b/backend/tests/integration/common_utils/managers/document_set.py new file mode 100644 index 00000000000..8133ccc8712 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/document_set.py @@ -0,0 +1,171 @@ +import time +from uuid import uuid4 + +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY +from tests.integration.common_utils.test_models import TestDocumentSet +from tests.integration.common_utils.test_models import TestUser + + +class DocumentSetManager: + @staticmethod + def create( + name: str | None = None, + description: str | None = None, + cc_pair_ids: list[int] | None = None, + is_public: bool = True, + users: list[str] | None = None, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestDocumentSet: + if name is None: + name = f"test_doc_set_{str(uuid4())}" + + doc_set_creation_request = { + "name": name, + "description": description or name, + "cc_pair_ids": cc_pair_ids or [], + "is_public": is_public, + "users": users or [], + "groups": groups or [], + } + + response = requests.post( + f"{API_SERVER_URL}/manage/admin/document-set", + json=doc_set_creation_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + return TestDocumentSet( + id=int(response.json()), + name=name, + description=description or name, + cc_pair_ids=cc_pair_ids or [], + is_public=is_public, + is_up_to_date=True, + users=users or [], + groups=groups or [], + ) + + @staticmethod + def edit( + document_set: TestDocumentSet, + user_performing_action: TestUser | None = None, + ) -> bool: + doc_set_update_request = { + "id": document_set.id, + "description": document_set.description, + "cc_pair_ids": document_set.cc_pair_ids, + "is_public": document_set.is_public, + "users": document_set.users, + "groups": document_set.groups, + } + response = requests.patch( + f"{API_SERVER_URL}/manage/admin/document-set", + json=doc_set_update_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return True + + @staticmethod + def delete( + document_set: TestDocumentSet, + user_performing_action: TestUser | None = None, + ) -> bool: + response = requests.delete( + f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return True + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[TestDocumentSet]: + response = 
requests.get( + f"{API_SERVER_URL}/manage/document-set", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [ + TestDocumentSet( + id=doc_set["id"], + name=doc_set["name"], + description=doc_set["description"], + cc_pair_ids=[ + cc_pair["id"] for cc_pair in doc_set["cc_pair_descriptors"] + ], + is_public=doc_set["is_public"], + is_up_to_date=doc_set["is_up_to_date"], + users=doc_set["users"], + groups=doc_set["groups"], + ) + for doc_set in response.json() + ] + + @staticmethod + def wait_for_sync( + document_sets_to_check: list[TestDocumentSet] | None = None, + user_performing_action: TestUser | None = None, + ) -> None: + # wait for document sets to be synced + start = time.time() + while True: + doc_sets = DocumentSetManager.get_all(user_performing_action) + if document_sets_to_check: + check_ids = {doc_set.id for doc_set in document_sets_to_check} + doc_set_ids = {doc_set.id for doc_set in doc_sets} + if not check_ids.issubset(doc_set_ids): + raise RuntimeError("Document set not found") + doc_sets = [doc_set for doc_set in doc_sets if doc_set.id in check_ids] + all_up_to_date = all(doc_set.is_up_to_date for doc_set in doc_sets) + + if all_up_to_date: + break + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"Document sets were not synced within the {MAX_DELAY} seconds" + ) + else: + print("Document sets were not synced yet, waiting...") + + time.sleep(2) + + @staticmethod + def verify( + document_set: TestDocumentSet, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + doc_sets = DocumentSetManager.get_all(user_performing_action) + for doc_set in doc_sets: + if doc_set.id == document_set.id: + if verify_deleted: + raise ValueError( + f"Document set {document_set.id} found but should have been deleted" + ) + if ( + doc_set.name == document_set.name + and set(doc_set.cc_pair_ids) == set(document_set.cc_pair_ids) + and doc_set.is_public == document_set.is_public + and set(doc_set.users) == set(document_set.users) + and set(doc_set.groups) == set(document_set.groups) + ): + return + if not verify_deleted: + raise ValueError(f"Document set {document_set.id} not found") diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py new file mode 100644 index 00000000000..41ff43edb6f --- /dev/null +++ b/backend/tests/integration/common_utils/managers/persona.py @@ -0,0 +1,206 @@ +from uuid import uuid4 + +import requests + +from danswer.search.enums import RecencyBiasSetting +from danswer.server.features.persona.models import PersonaSnapshot +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestPersona +from tests.integration.common_utils.test_models import TestUser + + +class PersonaManager: + @staticmethod + def create( + name: str | None = None, + description: str | None = None, + num_chunks: float = 5, + llm_relevance_filter: bool = True, + is_public: bool = True, + llm_filter_extraction: bool = True, + recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO, + prompt_ids: list[int] | None = None, + document_set_ids: list[int] | None = None, + tool_ids: list[int] | None = None, + llm_model_provider_override: str | None = None, + llm_model_version_override: str | None = None, + users: list[str] | None = None, + groups: list[int] | 
None = None, + user_performing_action: TestUser | None = None, + ) -> TestPersona: + name = name or f"test-persona-{uuid4()}" + description = description or f"Description for {name}" + + persona_creation_request = { + "name": name, + "description": description, + "num_chunks": num_chunks, + "llm_relevance_filter": llm_relevance_filter, + "is_public": is_public, + "llm_filter_extraction": llm_filter_extraction, + "recency_bias": recency_bias, + "prompt_ids": prompt_ids or [], + "document_set_ids": document_set_ids or [], + "tool_ids": tool_ids or [], + "llm_model_provider_override": llm_model_provider_override, + "llm_model_version_override": llm_model_version_override, + "users": users or [], + "groups": groups or [], + } + + response = requests.post( + f"{API_SERVER_URL}/persona", + json=persona_creation_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + persona_data = response.json() + + return TestPersona( + id=persona_data["id"], + name=name, + description=description, + num_chunks=num_chunks, + llm_relevance_filter=llm_relevance_filter, + is_public=is_public, + llm_filter_extraction=llm_filter_extraction, + recency_bias=recency_bias, + prompt_ids=prompt_ids or [], + document_set_ids=document_set_ids or [], + tool_ids=tool_ids or [], + llm_model_provider_override=llm_model_provider_override, + llm_model_version_override=llm_model_version_override, + users=users or [], + groups=groups or [], + ) + + @staticmethod + def edit( + persona: TestPersona, + name: str | None = None, + description: str | None = None, + num_chunks: float | None = None, + llm_relevance_filter: bool | None = None, + is_public: bool | None = None, + llm_filter_extraction: bool | None = None, + recency_bias: RecencyBiasSetting | None = None, + prompt_ids: list[int] | None = None, + document_set_ids: list[int] | None = None, + tool_ids: list[int] | None = None, + llm_model_provider_override: str | None = None, + llm_model_version_override: str | None = None, + users: list[str] | None = None, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestPersona: + persona_update_request = { + "name": name or persona.name, + "description": description or persona.description, + "num_chunks": num_chunks or persona.num_chunks, + "llm_relevance_filter": llm_relevance_filter + or persona.llm_relevance_filter, + "is_public": is_public or persona.is_public, + "llm_filter_extraction": llm_filter_extraction + or persona.llm_filter_extraction, + "recency_bias": recency_bias or persona.recency_bias, + "prompt_ids": prompt_ids or persona.prompt_ids, + "document_set_ids": document_set_ids or persona.document_set_ids, + "tool_ids": tool_ids or persona.tool_ids, + "llm_model_provider_override": llm_model_provider_override + or persona.llm_model_provider_override, + "llm_model_version_override": llm_model_version_override + or persona.llm_model_version_override, + "users": users or persona.users, + "groups": groups or persona.groups, + } + + response = requests.patch( + f"{API_SERVER_URL}/persona/{persona.id}", + json=persona_update_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + updated_persona_data = response.json() + + return TestPersona( + id=updated_persona_data["id"], + name=updated_persona_data["name"], + description=updated_persona_data["description"], + num_chunks=updated_persona_data["num_chunks"], + 
llm_relevance_filter=updated_persona_data["llm_relevance_filter"], + is_public=updated_persona_data["is_public"], + llm_filter_extraction=updated_persona_data["llm_filter_extraction"], + recency_bias=updated_persona_data["recency_bias"], + prompt_ids=updated_persona_data["prompts"], + document_set_ids=updated_persona_data["document_sets"], + tool_ids=updated_persona_data["tools"], + llm_model_provider_override=updated_persona_data[ + "llm_model_provider_override" + ], + llm_model_version_override=updated_persona_data[ + "llm_model_version_override" + ], + users=[user["email"] for user in updated_persona_data["users"]], + groups=updated_persona_data["groups"], + ) + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[PersonaSnapshot]: + response = requests.get( + f"{API_SERVER_URL}/admin/persona", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [PersonaSnapshot(**persona) for persona in response.json()] + + @staticmethod + def verify( + test_persona: TestPersona, + user_performing_action: TestUser | None = None, + ) -> bool: + all_personas = PersonaManager.get_all(user_performing_action) + for persona in all_personas: + if persona.id == test_persona.id: + return ( + persona.name == test_persona.name + and persona.description == test_persona.description + and persona.num_chunks == test_persona.num_chunks + and persona.llm_relevance_filter + == test_persona.llm_relevance_filter + and persona.is_public == test_persona.is_public + and persona.llm_filter_extraction + == test_persona.llm_filter_extraction + and persona.llm_model_provider_override + == test_persona.llm_model_provider_override + and persona.llm_model_version_override + == test_persona.llm_model_version_override + and set(persona.prompts) == set(test_persona.prompt_ids) + and set(persona.document_sets) == set(test_persona.document_set_ids) + and set(persona.tools) == set(test_persona.tool_ids) + and set(user.email for user in persona.users) + == set(test_persona.users) + and set(persona.groups) == set(test_persona.groups) + ) + return False + + @staticmethod + def delete( + persona: TestPersona, + user_performing_action: TestUser | None = None, + ) -> bool: + response = requests.delete( + f"{API_SERVER_URL}/persona/{persona.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + return response.ok diff --git a/backend/tests/integration/common_utils/managers/user.py b/backend/tests/integration/common_utils/managers/user.py new file mode 100644 index 00000000000..0946b8b1fca --- /dev/null +++ b/backend/tests/integration/common_utils/managers/user.py @@ -0,0 +1,122 @@ +from copy import deepcopy +from urllib.parse import urlencode +from uuid import uuid4 + +import requests + +from danswer.db.models import UserRole +from danswer.server.manage.models import AllUsersResponse +from danswer.server.models import FullUserSnapshot +from danswer.server.models import InvitedUserSnapshot +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestUser + + +class UserManager: + @staticmethod + def create( + name: str | None = None, + ) -> TestUser: + if name is None: + name = f"test{str(uuid4())}" + + email = f"{name}@test.com" + password = "test" + + body = { + "email": email, + "username": email, + "password": password, + } + response 
= requests.post( + url=f"{API_SERVER_URL}/auth/register", + json=body, + headers=GENERAL_HEADERS, + ) + response.raise_for_status() + + test_user = TestUser( + id=response.json()["id"], + email=email, + password=password, + headers=deepcopy(GENERAL_HEADERS), + ) + print(f"Created user {test_user.email}") + + test_user.headers["Cookie"] = UserManager.login_as_user(test_user) + + return test_user + + @staticmethod + def login_as_user(test_user: TestUser) -> str: + data = urlencode( + { + "username": test_user.email, + "password": test_user.password, + } + ) + headers = test_user.headers.copy() + headers["Content-Type"] = "application/x-www-form-urlencoded" + + response = requests.post( + url=f"{API_SERVER_URL}/auth/login", + data=data, + headers=headers, + ) + response.raise_for_status() + result_cookie = next(iter(response.cookies), None) + + if not result_cookie: + raise Exception("Failed to login") + + print(f"Logged in as {test_user.email}") + return f"{result_cookie.name}={result_cookie.value}" + + @staticmethod + def verify_role(user_to_verify: TestUser, target_role: UserRole) -> bool: + response = requests.get( + url=f"{API_SERVER_URL}/me", + headers=user_to_verify.headers, + ) + response.raise_for_status() + return target_role == UserRole(response.json().get("role", "")) + + @staticmethod + def set_role( + user_to_set: TestUser, + target_role: UserRole, + user_to_perform_action: TestUser | None = None, + ) -> None: + if user_to_perform_action is None: + user_to_perform_action = user_to_set + response = requests.patch( + url=f"{API_SERVER_URL}/manage/set-user-role", + json={"user_email": user_to_set.email, "new_role": target_role.value}, + headers=user_to_perform_action.headers, + ) + response.raise_for_status() + + @staticmethod + def verify(user: TestUser, user_to_perform_action: TestUser | None = None) -> None: + if user_to_perform_action is None: + user_to_perform_action = user + response = requests.get( + url=f"{API_SERVER_URL}/manage/users", + headers=user_to_perform_action.headers + if user_to_perform_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + data = response.json() + all_users = AllUsersResponse( + accepted=[FullUserSnapshot(**user) for user in data["accepted"]], + invited=[InvitedUserSnapshot(**user) for user in data["invited"]], + accepted_pages=data["accepted_pages"], + invited_pages=data["invited_pages"], + ) + for accepted_user in all_users.accepted: + if accepted_user.email == user.email and accepted_user.id == user.id: + return + raise ValueError(f"User {user.email} not found") diff --git a/backend/tests/integration/common_utils/managers/user_group.py b/backend/tests/integration/common_utils/managers/user_group.py new file mode 100644 index 00000000000..5f5ac6b0e30 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/user_group.py @@ -0,0 +1,148 @@ +import time +from uuid import uuid4 + +import requests + +from ee.danswer.server.user_group.models import UserGroup +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY +from tests.integration.common_utils.test_models import TestUser +from tests.integration.common_utils.test_models import TestUserGroup + + +class UserGroupManager: + @staticmethod + def create( + name: str | None = None, + user_ids: list[str] | None = None, + cc_pair_ids: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestUserGroup: + 
name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}" + + request = { + "name": name, + "user_ids": user_ids or [], + "cc_pair_ids": cc_pair_ids or [], + } + response = requests.post( + f"{API_SERVER_URL}/manage/admin/user-group", + json=request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + test_user_group = TestUserGroup( + id=response.json()["id"], + name=response.json()["name"], + user_ids=[user["id"] for user in response.json()["users"]], + cc_pair_ids=[cc_pair["id"] for cc_pair in response.json()["cc_pairs"]], + ) + return test_user_group + + @staticmethod + def edit( + user_group: TestUserGroup, + user_performing_action: TestUser | None = None, + ) -> None: + if not user_group.id: + raise ValueError("User group has no ID") + response = requests.patch( + f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}", + json=user_group.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def set_curator_status( + test_user_group: TestUserGroup, + user_to_set_as_curator: TestUser, + is_curator: bool = True, + user_performing_action: TestUser | None = None, + ) -> None: + if not user_to_set_as_curator.id: + raise ValueError("User has no ID") + set_curator_request = { + "user_id": user_to_set_as_curator.id, + "is_curator": is_curator, + } + response = requests.post( + f"{API_SERVER_URL}/manage/admin/user-group/{test_user_group.id}/set-curator", + json=set_curator_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[UserGroup]: + response = requests.get( + f"{API_SERVER_URL}/manage/admin/user-group", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [UserGroup(**ug) for ug in response.json()] + + @staticmethod + def verify( + user_group: TestUserGroup, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + all_user_groups = UserGroupManager.get_all(user_performing_action) + for fetched_user_group in all_user_groups: + if user_group.id == fetched_user_group.id: + if verify_deleted: + raise ValueError( + f"User group {user_group.id} found but should be deleted" + ) + fetched_cc_ids = {cc_pair.id for cc_pair in fetched_user_group.cc_pairs} + fetched_user_ids = {user.id for user in fetched_user_group.users} + user_group_cc_ids = set(user_group.cc_pair_ids) + user_group_user_ids = set(user_group.user_ids) + if ( + fetched_cc_ids == user_group_cc_ids + and fetched_user_ids == user_group_user_ids + ): + return + if not verify_deleted: + raise ValueError(f"User group {user_group.id} not found") + + @staticmethod + def wait_for_sync( + user_groups_to_check: list[TestUserGroup] | None = None, + user_performing_action: TestUser | None = None, + ) -> None: + start = time.time() + while True: + user_groups = UserGroupManager.get_all(user_performing_action) + if user_groups_to_check: + check_ids = {user_group.id for user_group in user_groups_to_check} + user_group_ids = {user_group.id for user_group in user_groups} + if not check_ids.issubset(user_group_ids): + raise RuntimeError("Document set not found") + user_groups = [ + user_group + for user_group in user_groups + if user_group.id in check_ids + ] 
+ if all(ug.is_up_to_date for ug in user_groups): + return + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"User groups were not synced within the {MAX_DELAY} seconds" + ) + else: + print("User groups were not synced yet, waiting...") + time.sleep(2) diff --git a/backend/tests/integration/common_utils/reset.py b/backend/tests/integration/common_utils/reset.py index 3815aa9f972..a13ec184b45 100644 --- a/backend/tests/integration/common_utils/reset.py +++ b/backend/tests/integration/common_utils/reset.py @@ -20,7 +20,6 @@ from danswer.indexing.models import IndexingSetting from danswer.main import setup_postgres from danswer.main import setup_vespa -from tests.integration.common_utils.llm import seed_default_openai_provider def _run_migrations( @@ -32,6 +31,7 @@ def _run_migrations( # Create an Alembic configuration object alembic_cfg = Config("alembic.ini") alembic_cfg.set_section_option("logger_alembic", "level", "WARN") + alembic_cfg.attributes["configure_logger"] = False # Set the SQLAlchemy URL in the Alembic configuration alembic_cfg.set_main_option("sqlalchemy.url", database_url) @@ -131,11 +131,13 @@ def reset_vespa() -> None: search_settings = get_current_search_settings(db_session) index_name = search_settings.index_name - setup_vespa( + success = setup_vespa( document_index=VespaIndex(index_name=index_name, secondary_index_name=None), index_setting=IndexingSetting.from_db_model(search_settings), secondary_index_setting=None, ) + if not success: + raise RuntimeError("Could not connect to Vespa within the specified timeout.") for _ in range(5): try: @@ -167,6 +169,4 @@ def reset_all() -> None: reset_postgres() print("Resetting Vespa...") reset_vespa() - print("Seeding LLM Providers...") - seed_default_openai_provider() print("Finished resetting all.") diff --git a/backend/tests/integration/common_utils/seed_documents.py b/backend/tests/integration/common_utils/seed_documents.py deleted file mode 100644 index b6720c9aebe..00000000000 --- a/backend/tests/integration/common_utils/seed_documents.py +++ /dev/null @@ -1,72 +0,0 @@ -import uuid - -import requests -from pydantic import BaseModel - -from danswer.configs.constants import DocumentSource -from tests.integration.common_utils.connectors import ConnectorClient -from tests.integration.common_utils.constants import API_SERVER_URL - - -class SimpleTestDocument(BaseModel): - id: str - content: str - - -class SeedDocumentResponse(BaseModel): - cc_pair_id: int - documents: list[SimpleTestDocument] - - -class TestDocumentClient: - @staticmethod - def seed_documents( - num_docs: int = 5, cc_pair_id: int | None = None - ) -> SeedDocumentResponse: - if not cc_pair_id: - connector_details = ConnectorClient.create_connector() - cc_pair_id = connector_details.cc_pair_id - - # Create and ingest some documents - documents: list[dict] = [] - for _ in range(num_docs): - document_id = f"test-doc-{uuid.uuid4()}" - document = { - "document": { - "id": document_id, - "sections": [ - { - "text": f"This is test document {document_id}", - "link": f"{document_id}", - } - ], - "source": DocumentSource.NOT_APPLICABLE, - # just for testing metadata - "metadata": {"document_id": document_id}, - "semantic_identifier": f"Test Document {document_id}", - "from_ingestion_api": True, - }, - "cc_pair_id": cc_pair_id, - } - documents.append(document) - response = requests.post( - f"{API_SERVER_URL}/danswer-api/ingestion", - json=document, - ) - response.raise_for_status() - - print("Seeding completed successfully.") - return SeedDocumentResponse( - 
cc_pair_id=cc_pair_id, - documents=[ - SimpleTestDocument( - id=document["document"]["id"], - content=document["document"]["sections"][0]["text"], - ) - for document in documents - ], - ) - - -if __name__ == "__main__": - seed_documents_resp = TestDocumentClient.seed_documents() diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py new file mode 100644 index 00000000000..04db0851e3d --- /dev/null +++ b/backend/tests/integration/common_utils/test_models.py @@ -0,0 +1,120 @@ +from typing import Any +from uuid import UUID + +from pydantic import BaseModel +from pydantic import Field + +from danswer.auth.schemas import UserRole +from danswer.search.enums import RecencyBiasSetting +from danswer.server.documents.models import DocumentSource +from danswer.server.documents.models import InputType + +""" +These data models are used to represent the data on the testing side of things. +This means the flow is: +1. Make request that changes data in db +2. Make a change to the testing model +3. Retrieve data from db +4. Compare db data with testing model to verify +""" + + +class TestAPIKey(BaseModel): + api_key_id: int + api_key_display: str + api_key: str | None = None # only present on initial creation + api_key_name: str | None = None + api_key_role: UserRole + + user_id: UUID + headers: dict + + +class TestUser(BaseModel): + id: str + email: str + password: str + headers: dict + + +class TestCredential(BaseModel): + id: int + name: str + credential_json: dict[str, Any] + admin_public: bool + source: DocumentSource + curator_public: bool + groups: list[int] + + +class TestConnector(BaseModel): + id: int + name: str + source: DocumentSource + input_type: InputType + connector_specific_config: dict[str, Any] + groups: list[int] | None = None + is_public: bool | None = None + + +class SimpleTestDocument(BaseModel): + id: str + content: str + + +class TestCCPair(BaseModel): + id: int + name: str + connector_id: int + credential_id: int + is_public: bool + groups: list[int] + documents: list[SimpleTestDocument] = Field(default_factory=list) + + +class TestUserGroup(BaseModel): + id: int + name: str + user_ids: list[str] + cc_pair_ids: list[int] + + +class TestLLMProvider(BaseModel): + id: int + name: str + provider: str + api_key: str + default_model_name: str + is_public: bool + groups: list[TestUserGroup] + api_base: str | None = None + api_version: str | None = None + + +class TestDocumentSet(BaseModel): + id: int + name: str + description: str + cc_pair_ids: list[int] = Field(default_factory=list) + is_public: bool + is_up_to_date: bool + users: list[str] = Field(default_factory=list) + groups: list[int] = Field(default_factory=list) + + +class TestPersona(BaseModel): + id: int + name: str + description: str + num_chunks: float + llm_relevance_filter: bool + is_public: bool + llm_filter_extraction: bool + recency_bias: RecencyBiasSetting + prompt_ids: list[int] + document_set_ids: list[int] + tool_ids: list[int] + llm_model_provider_override: str | None + llm_model_version_override: str | None + users: list[str] + groups: list[int] diff --git a/backend/tests/integration/common_utils/user_groups.py b/backend/tests/integration/common_utils/user_groups.py deleted file mode 100644 index 0cd44066463..00000000000 --- a/backend/tests/integration/common_utils/user_groups.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import cast - -import requests - -from ee.danswer.server.user_group.models import UserGroup -from 
ee.danswer.server.user_group.models import UserGroupCreate -from tests.integration.common_utils.constants import API_SERVER_URL - - -class UserGroupClient: - @staticmethod - def create_user_group(user_group_creation_request: UserGroupCreate) -> int: - response = requests.post( - f"{API_SERVER_URL}/manage/admin/user-group", - json=user_group_creation_request.model_dump(), - ) - response.raise_for_status() - return cast(int, response.json()["id"]) - - @staticmethod - def fetch_user_groups() -> list[UserGroup]: - response = requests.get(f"{API_SERVER_URL}/manage/admin/user-group") - response.raise_for_status() - return [UserGroup(**ug) for ug in response.json()] diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 6c46e9f875e..314b78ad36f 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -1,3 +1,4 @@ +import os from collections.abc import Generator import pytest @@ -9,6 +10,25 @@ from tests.integration.common_utils.vespa import TestVespaClient +def load_env_vars(env_file: str = ".env") -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + env_path = os.path.join(current_dir, env_file) + try: + with open(env_path, "r") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + key, value = line.split("=", 1) + os.environ[key] = value.strip() + print("Successfully loaded environment variables") + except FileNotFoundError: + print(f"File {env_file} not found") + + +# Load environment variables at the module level +load_env_vars() + + @pytest.fixture def db_session() -> Generator[Session, None, None]: with get_session_context_manager() as session: diff --git a/backend/tests/integration/tests/connector/test_connector_deletion.py b/backend/tests/integration/tests/connector/test_connector_deletion.py new file mode 100644 index 00000000000..e6f1b474170 --- /dev/null +++ b/backend/tests/integration/tests/connector/test_connector_deletion.py @@ -0,0 +1,333 @@ +""" +This file contains tests for the following: +- Ensuring deletion of a connector also: + - deletes the documents in vespa for that connector + - updates the document sets and user groups to remove the connector +- Ensure that deleting a connector that is part of an overlapping document set and/or user group works as expected +""" +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.enums import IndexingStatus +from danswer.db.index_attempt import create_index_attempt_error +from danswer.db.models import IndexAttempt +from danswer.db.search_settings import get_current_search_settings +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestUser +from tests.integration.common_utils.test_models import TestUserGroup +from tests.integration.common_utils.vespa import TestVespaClient + + +def 
test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + + # create connectors + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + cc_pair_2 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # seed documents + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, + ) + cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_2, + num_docs=NUM_DOCS, + api_key=api_key, + ) + + # create document sets + doc_set_1 = DocumentSetManager.create( + name="Test Document Set 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + doc_set_2 = DocumentSetManager.create( + name="Test Document Set 2", + cc_pair_ids=[cc_pair_1.id, cc_pair_2.id], + user_performing_action=admin_user, + ) + + # wait for document sets to be synced + DocumentSetManager.wait_for_sync(user_performing_action=admin_user) + + print("Document sets created and synced") + + # create user groups + user_group_1: TestUserGroup = UserGroupManager.create( + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + user_group_2: TestUserGroup = UserGroupManager.create( + cc_pair_ids=[cc_pair_1.id, cc_pair_2.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync(user_performing_action=admin_user) + + # inject a finished index attempt and index attempt error (exercises foreign key errors) + with Session(get_sqlalchemy_engine()) as db_session: + primary_search_settings = get_current_search_settings(db_session) + new_attempt = IndexAttempt( + connector_credential_pair_id=cc_pair_1.id, + search_settings_id=primary_search_settings.id, + from_beginning=False, + status=IndexingStatus.COMPLETED_WITH_ERRORS, + ) + db_session.add(new_attempt) + db_session.commit() + + create_index_attempt_error( + index_attempt_id=new_attempt.id, + batch=1, + docs=[], + exception_msg="", + exception_traceback="", + db_session=db_session, + ) + + # delete connector 1 + CCPairManager.pause_cc_pair( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + CCPairManager.delete( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + + # Update local records to match the database for later comparison + user_group_1.cc_pair_ids = [] + user_group_2.cc_pair_ids = [cc_pair_2.id] + doc_set_1.cc_pair_ids = [] + doc_set_2.cc_pair_ids = [cc_pair_2.id] + cc_pair_1.groups = [] + cc_pair_2.groups = [user_group_2.id] + + CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user) + + # validate vespa documents + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[], + group_names=[], + doc_creating_user=admin_user, + verify_deleted=True, + ) + + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[doc_set_2.name], + group_names=[user_group_2.name], + doc_creating_user=admin_user, + verify_deleted=False, + ) + + # check that only connector 1 is deleted + CCPairManager.verify( + cc_pair=cc_pair_2, + user_performing_action=admin_user, + ) + + # validate document sets + DocumentSetManager.verify( + document_set=doc_set_1, + 
user_performing_action=admin_user, + ) + DocumentSetManager.verify( + document_set=doc_set_2, + user_performing_action=admin_user, + ) + + # validate user groups + UserGroupManager.verify( + user_group=user_group_1, + user_performing_action=admin_user, + ) + UserGroupManager.verify( + user_group=user_group_2, + user_performing_action=admin_user, + ) + + +def test_connector_deletion_for_overlapping_connectors( + reset: None, vespa_client: TestVespaClient +) -> None: + """Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping + document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors. + """ + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + + # create connectors + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + cc_pair_2 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + doc_ids = [str(uuid4())] + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + document_ids=doc_ids, + api_key=api_key, + ) + cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_2, + document_ids=doc_ids, + api_key=api_key, + ) + + # verify vespa document exists and that it is not in any document sets or groups + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[], + group_names=[], + doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[], + group_names=[], + doc_creating_user=admin_user, + ) + + # create document set + doc_set_1 = DocumentSetManager.create( + name="Test Document Set 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + DocumentSetManager.wait_for_sync( + document_sets_to_check=[doc_set_1], + user_performing_action=admin_user, + ) + + print("Document set 1 created and synced") + + # verify vespa document is in the document set + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[doc_set_1.name], + doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_creating_user=admin_user, + ) + + # create a user group and attach it to connector 1 + user_group_1: TestUserGroup = UserGroupManager.create( + name="Test User Group 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], + user_performing_action=admin_user, + ) + cc_pair_1.groups = [user_group_1.id] + + print("User group 1 created and synced") + + # create a user group and attach it to connector 2 + user_group_2: TestUserGroup = UserGroupManager.create( + name="Test User Group 2", + cc_pair_ids=[cc_pair_2.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_2], + user_performing_action=admin_user, + ) + cc_pair_2.groups = [user_group_2.id] + + print("User group 2 created and synced") + + # verify vespa document is in the user group + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + group_names=[user_group_1.name, user_group_2.name], + 
doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + group_names=[user_group_1.name, user_group_2.name], + doc_creating_user=admin_user, + ) + + # EVERYTHING BELOW HERE IS CURRENTLY BROKEN AND NEEDS TO BE FIXED SERVER SIDE + + # delete connector 1 + CCPairManager.pause_cc_pair( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + CCPairManager.delete( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + + # wait for deletion to finish + CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user) + + print("Connector 1 deleted") + + # check that only connector 1 is deleted + # TODO: check for the CC pair rather than the connector once the refactor is done + CCPairManager.verify( + cc_pair=cc_pair_1, + verify_deleted=True, + user_performing_action=admin_user, + ) + CCPairManager.verify( + cc_pair=cc_pair_2, + user_performing_action=admin_user, + ) + + # verify the document is not in any document sets + # verify the document is only in user group 2 + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[], + group_names=[user_group_2.name], + doc_creating_user=admin_user, + verify_deleted=False, + ) diff --git a/backend/tests/integration/tests/connector/test_deletion.py b/backend/tests/integration/tests/connector/test_deletion.py deleted file mode 100644 index 78ad2378af9..00000000000 --- a/backend/tests/integration/tests/connector/test_deletion.py +++ /dev/null @@ -1,190 +0,0 @@ -import time - -from danswer.db.enums import ConnectorCredentialPairStatus -from danswer.server.features.document_set.models import DocumentSetCreationRequest -from tests.integration.common_utils.connectors import ConnectorClient -from tests.integration.common_utils.constants import MAX_DELAY -from tests.integration.common_utils.document_sets import DocumentSetClient -from tests.integration.common_utils.seed_documents import TestDocumentClient -from tests.integration.common_utils.user_groups import UserGroupClient -from tests.integration.common_utils.user_groups import UserGroupCreate -from tests.integration.common_utils.vespa import TestVespaClient - - -def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None: - # create connectors - c1_details = ConnectorClient.create_connector(name_prefix="tc1") - c2_details = ConnectorClient.create_connector(name_prefix="tc2") - c1_seed_res = TestDocumentClient.seed_documents( - num_docs=5, cc_pair_id=c1_details.cc_pair_id - ) - c2_seed_res = TestDocumentClient.seed_documents( - num_docs=5, cc_pair_id=c2_details.cc_pair_id - ) - - # create document sets - doc_set_1_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 1", - description="Intially connector to be deleted, should be empty after test", - cc_pair_ids=[c1_details.cc_pair_id], - is_public=True, - users=[], - groups=[], - ) - ) - - doc_set_2_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 2", - description="Intially both connectors, should contain undeleted connector after test", - cc_pair_ids=[c1_details.cc_pair_id, c2_details.cc_pair_id], - is_public=True, - users=[], - groups=[], - ) - ) - - # wait for document sets to be synced - start = time.time() - while True: - doc_sets = DocumentSetClient.fetch_document_sets() - doc_set_1 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None - ) - doc_set_2 = next( - (doc_set for doc_set in doc_sets 
if doc_set.id == doc_set_2_id), None - ) - - if not doc_set_1 or not doc_set_2: - raise RuntimeError("Document set not found") - - if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date: - break - - if time.time() - start > MAX_DELAY: - raise TimeoutError("Document sets were not synced within the max delay") - - time.sleep(2) - - print("Document sets created and synced") - - # if so, create ACLs - user_group_1 = UserGroupClient.create_user_group( - UserGroupCreate( - name="Test User Group 1", user_ids=[], cc_pair_ids=[c1_details.cc_pair_id] - ) - ) - user_group_2 = UserGroupClient.create_user_group( - UserGroupCreate( - name="Test User Group 2", - user_ids=[], - cc_pair_ids=[c1_details.cc_pair_id, c2_details.cc_pair_id], - ) - ) - - # wait for user groups to be available - start = time.time() - while True: - user_groups = {ug.id: ug for ug in UserGroupClient.fetch_user_groups()} - - if not ( - user_group_1 in user_groups.keys() and user_group_2 in user_groups.keys() - ): - raise RuntimeError("User groups not found") - - if ( - user_groups[user_group_1].is_up_to_date - and user_groups[user_group_2].is_up_to_date - ): - break - - if time.time() - start > MAX_DELAY: - raise TimeoutError("User groups were not synced within the max delay") - - time.sleep(2) - - print("User groups created and synced") - - # delete connector 1 - ConnectorClient.update_connector_status( - cc_pair_id=c1_details.cc_pair_id, status=ConnectorCredentialPairStatus.PAUSED - ) - ConnectorClient.delete_connector( - connector_id=c1_details.connector_id, credential_id=c1_details.credential_id - ) - - start = time.time() - while True: - connectors = ConnectorClient.get_connectors() - - if c1_details.connector_id not in [c["id"] for c in connectors]: - break - - if time.time() - start > MAX_DELAY: - raise TimeoutError("Connector 1 was not deleted within the max delay") - - time.sleep(2) - - print("Connector 1 deleted") - - # validate vespa documents - c1_vespa_docs = vespa_client.get_documents_by_id( - [doc.id for doc in c1_seed_res.documents] - )["documents"] - c2_vespa_docs = vespa_client.get_documents_by_id( - [doc.id for doc in c2_seed_res.documents] - )["documents"] - - assert len(c1_vespa_docs) == 0 - assert len(c2_vespa_docs) == 5 - - for doc in c2_vespa_docs: - assert doc["fields"]["access_control_list"] == { - "PUBLIC": 1, - "group:Test User Group 2": 1, - } - assert doc["fields"]["document_sets"] == {"Test Document Set 2": 1} - - # check that only connector 1 is deleted - # TODO: check for the CC pair rather than the connector once the refactor is done - all_connectors = ConnectorClient.get_connectors() - assert len(all_connectors) == 1 - assert all_connectors[0]["id"] == c2_details.connector_id - - # validate document sets - all_doc_sets = DocumentSetClient.fetch_document_sets() - assert len(all_doc_sets) == 2 - - doc_set_1_found = False - doc_set_2_found = False - for doc_set in all_doc_sets: - if doc_set.id == doc_set_1_id: - doc_set_1_found = True - assert doc_set.cc_pair_descriptors == [] - - if doc_set.id == doc_set_2_id: - doc_set_2_found = True - assert len(doc_set.cc_pair_descriptors) == 1 - assert doc_set.cc_pair_descriptors[0].id == c2_details.cc_pair_id - - assert doc_set_1_found - assert doc_set_2_found - - # validate user groups - all_user_groups = UserGroupClient.fetch_user_groups() - assert len(all_user_groups) == 2 - - user_group_1_found = False - user_group_2_found = False - for user_group in all_user_groups: - if user_group.id == user_group_1: - user_group_1_found = True - assert 
user_group.cc_pairs == [] - if user_group.id == user_group_2: - user_group_2_found = True - assert len(user_group.cc_pairs) == 1 - assert user_group.cc_pairs[0].id == c2_details.cc_pair_id - - assert user_group_1_found - assert user_group_2_found diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index b00c2e3d1e6..981a9cbd026 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -1,34 +1,59 @@ import requests -from tests.integration.common_utils.connectors import ConnectorClient +from danswer.configs.constants import MessageType from tests.integration.common_utils.constants import API_SERVER_URL -from tests.integration.common_utils.seed_documents import TestDocumentClient +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.llm import LLMProviderManager +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestCCPair +from tests.integration.common_utils.test_models import TestUser def test_send_message_simple_with_history(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + # create connectors - c1_details = ConnectorClient.create_connector(name_prefix="tc1") - c1_seed_res = TestDocumentClient.seed_documents( - num_docs=5, cc_pair_id=c1_details.cc_pair_id + cc_pair_1: TestCCPair = CCPairManager.create_from_scratch( + user_performing_action=admin_user, + ) + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + LLMProviderManager.create(user_performing_action=admin_user) + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, ) response = requests.post( f"{API_SERVER_URL}/chat/send-message-simple-with-history", json={ - "messages": [{"message": c1_seed_res.documents[0].content, "role": "user"}], + "messages": [ + { + "message": cc_pair_1.documents[0].content, + "role": MessageType.USER.value, + } + ], "persona_id": 0, "prompt_id": 0, }, + headers=admin_user.headers, ) assert response.status_code == 200 response_json = response.json() # Check that the top document is the correct document - assert response_json["simple_search_docs"][0]["id"] == c1_seed_res.documents[0].id + assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id # assert that the metadata is correct - for doc in c1_seed_res.documents: + for doc in cc_pair_1.documents: found_doc = next( (x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None ) diff --git a/backend/tests/integration/tests/document_set/test_syncing.py b/backend/tests/integration/tests/document_set/test_syncing.py index 9a6b42ab5df..ab31b751471 100644 --- a/backend/tests/integration/tests/document_set/test_syncing.py +++ b/backend/tests/integration/tests/document_set/test_syncing.py @@ -1,78 +1,66 @@ -import time - -from danswer.server.features.document_set.models import DocumentSetCreationRequest -from 
tests.integration.common_utils.document_sets import DocumentSetClient -from tests.integration.common_utils.seed_documents import TestDocumentClient +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestUser from tests.integration.common_utils.vespa import TestVespaClient def test_multiple_document_sets_syncing_same_connnector( reset: None, vespa_client: TestVespaClient ) -> None: - # Seed documents - seed_result = TestDocumentClient.seed_documents(num_docs=5) - cc_pair_id = seed_result.cc_pair_id + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") - # Create first document set - doc_set_1_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 1", - description="First test document set", - cc_pair_ids=[cc_pair_id], - is_public=True, - users=[], - groups=[], - ) + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, ) - doc_set_2_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 2", - description="Second test document set", - cc_pair_ids=[cc_pair_id], - is_public=True, - users=[], - groups=[], - ) + # create connector + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, ) - # wait for syncing to be complete - max_delay = 45 - start = time.time() - while True: - doc_sets = DocumentSetClient.fetch_document_sets() - doc_set_1 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None - ) - doc_set_2 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None - ) - - if not doc_set_1 or not doc_set_2: - raise RuntimeError("Document set not found") - - if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date: - assert [ccp.id for ccp in doc_set_1.cc_pair_descriptors] == [ - ccp.id for ccp in doc_set_2.cc_pair_descriptors - ] - break + # seed documents + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, + ) - if time.time() - start > max_delay: - raise TimeoutError("Document sets were not synced within the max delay") + # Create document sets + doc_set_1 = DocumentSetManager.create( + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + doc_set_2 = DocumentSetManager.create( + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) - time.sleep(2) + DocumentSetManager.wait_for_sync( + user_performing_action=admin_user, + ) - # get names so we can compare to what is in vespa - doc_sets = DocumentSetClient.fetch_document_sets() - doc_set_names = {doc_set.name for doc_set in doc_sets} + DocumentSetManager.verify( + document_set=doc_set_1, + user_performing_action=admin_user, + ) + DocumentSetManager.verify( + document_set=doc_set_2, + user_performing_action=admin_user, + ) # make sure documents are as 
expected - seeded_document_ids = [doc.id for doc in seed_result.documents] - - result = vespa_client.get_documents_by_id([doc.id for doc in seed_result.documents]) - documents = result["documents"] - assert len(documents) == len(seed_result.documents) - assert all(doc["fields"]["document_id"] in seeded_document_ids for doc in documents) - assert all( - set(doc["fields"]["document_sets"].keys()) == doc_set_names for doc in documents + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[doc_set_1.name, doc_set_2.name], + doc_creating_user=admin_user, ) diff --git a/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py new file mode 100644 index 00000000000..c52c5826eae --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py @@ -0,0 +1,179 @@ +""" +This file takes the happy path to adding a curator to a user group and then tests +the permissions of the curator manipulating connector-credential pairs. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_cc_pair_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="curated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="uncurated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # Create credentials that the curator is and is not a curator of + connector_1 = ConnectorManager.create( + name="curator_owned_connector", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=False, + user_performing_action=admin_user, + ) + # currently we don't enforce permissions at the connector level + # pending cc_pair -> connector rework + # connector_2 = ConnectorManager.create( + # name="curator_visible_connector", + # source=DocumentSource.CONFLUENCE, + # groups=[user_group_2.id], + # is_public=False, + # user_performing_action=admin_user, + # ) + credential_1 = CredentialManager.create( + name="curator_owned_credential", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + curator_public=False, + user_performing_action=admin_user, + 
) + credential_2 = CredentialManager.create( + name="curator_visible_credential", + source=DocumentSource.CONFLUENCE, + groups=[user_group_2.id], + curator_public=False, + user_performing_action=admin_user, + ) + + # END OF HAPPY PATH + + """Tests for things Curators should not be able to do""" + + # Curators should not be able to create a public cc pair + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_1.id, + name="invalid_cc_pair_1", + groups=[user_group_1.id], + is_public=True, + user_performing_action=curator, + ) + + # Curators should not be able to create a cc + # pair for a user group they are not a curator of + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_1.id, + name="invalid_cc_pair_2", + groups=[user_group_1.id, user_group_2.id], + is_public=False, + user_performing_action=curator, + ) + + # Curators should not be able to create a cc + # pair without an attached user group + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_1.id, + name="invalid_cc_pair_2", + groups=[], + is_public=False, + user_performing_action=curator, + ) + + # # This test is currently disabled because permissions are + # # not enforced at the connector level + # # Curators should not be able to create a cc pair + # # for a user group that the connector does not belong to (NOT WORKING) + # with pytest.raises(HTTPError): + # CCPairManager.create( + # connector_id=connector_2.id, + # credential_id=credential_1.id, + # name="invalid_cc_pair_3", + # groups=[user_group_1.id], + # is_public=False, + # user_performing_action=curator, + # ) + + # Curators should not be able to create a cc + # pair for a user group that the credential does not belong to + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_2.id, + name="invalid_cc_pair_4", + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + + """Tests for things Curators should be able to do""" + + # Curators should be able to create a private + # cc pair for a user group they are a curator of + valid_cc_pair = CCPairManager.create( + name="valid_cc_pair", + connector_id=connector_1.id, + credential_id=credential_1.id, + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + + # Verify the created cc pair + CCPairManager.verify( + cc_pair=valid_cc_pair, + user_performing_action=curator, + ) + + # Test pausing the cc pair + CCPairManager.pause_cc_pair(valid_cc_pair, user_performing_action=curator) + + # Test deleting the cc pair + CCPairManager.delete(valid_cc_pair, user_performing_action=curator) + CCPairManager.wait_for_deletion_completion(user_performing_action=curator) + + CCPairManager.verify( + cc_pair=valid_cc_pair, + verify_deleted=True, + user_performing_action=curator, + ) diff --git a/backend/tests/integration/tests/permissions/test_connector_permissions.py b/backend/tests/integration/tests/permissions/test_connector_permissions.py new file mode 100644 index 00000000000..279c0568bfb --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_connector_permissions.py @@ -0,0 +1,136 @@ +""" +This file takes the happy path to adding a curator to a user group and then tests +the permissions of the curator manipulating connectors. 
+""" +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_connector_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="user_group_2", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # END OF HAPPY PATH + + """Tests for things Curators should not be able to do""" + + # Curators should not be able to create a public connector + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_1", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=True, + user_performing_action=curator, + ) + + # Curators should not be able to create a connector for a + # user group they are not a curator of + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_2", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id, user_group_2.id], + is_public=False, + user_performing_action=curator, + ) + + """Tests for things Curators should be able to do""" + + # Curators should be able to create a private + # connector for a user group they are a curator of + valid_connector = ConnectorManager.create( + name="valid_connector", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + assert valid_connector.id is not None + + # Verify the created connector + created_connector = ConnectorManager.get( + valid_connector.id, user_performing_action=curator + ) + assert created_connector.name == valid_connector.name + assert created_connector.source == valid_connector.source + + # Verify that the connector can be found in the list of all connectors + all_connectors = ConnectorManager.get_all(user_performing_action=curator) + assert any(conn.id == valid_connector.id for conn in all_connectors) + + # Test editing the connector + valid_connector.name = "updated_valid_connector" + ConnectorManager.edit(valid_connector, user_performing_action=curator) + + # Verify the edit + updated_connector = ConnectorManager.get( + valid_connector.id, user_performing_action=curator + ) + assert updated_connector.name == "updated_valid_connector" + + # Test deleting the connector + ConnectorManager.delete(connector=valid_connector, user_performing_action=curator) + + # Verify the deletion 
+ all_connectors_after_delete = ConnectorManager.get_all( + user_performing_action=curator + ) + assert all(conn.id != valid_connector.id for conn in all_connectors_after_delete) + + # Test that curator cannot create a connector for a group they are not a curator of + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_3", + source=DocumentSource.CONFLUENCE, + groups=[user_group_2.id], + is_public=False, + user_performing_action=curator, + ) + + # Test that curator cannot create a public connector + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_4", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=True, + user_performing_action=curator, + ) diff --git a/backend/tests/integration/tests/permissions/test_credential_permissions.py b/backend/tests/integration/tests/permissions/test_credential_permissions.py new file mode 100644 index 00000000000..1311f1a3d2d --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_credential_permissions.py @@ -0,0 +1,108 @@ +""" +This file takes the happy path to adding a curator to a user group and then tests +the permissions of the curator manipulating credentials. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_credential_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="user_group_2", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # END OF HAPPY PATH + + """Tests for things Curators should not be able to do""" + + # Curators should not be able to create a public credential + with pytest.raises(HTTPError): + CredentialManager.create( + name="invalid_credential_1", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + curator_public=True, + user_performing_action=curator, + ) + + # Curators should not be able to create a credential for a user group they are not a curator of + with pytest.raises(HTTPError): + CredentialManager.create( + name="invalid_credential_2", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id, user_group_2.id], + curator_public=False, + user_performing_action=curator, + ) + + """Tests for things Curators should be able to do""" + # Curators should be able to create a private 
credential for a user group they are a curator of + valid_credential = CredentialManager.create( + name="valid_credential", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + curator_public=False, + user_performing_action=curator, + ) + + # Verify the created credential + CredentialManager.verify( + credential=valid_credential, + user_performing_action=curator, + ) + + # Test editing the credential + valid_credential.name = "updated_valid_credential" + CredentialManager.edit(valid_credential, user_performing_action=curator) + + # Verify the edit + CredentialManager.verify( + credential=valid_credential, + user_performing_action=curator, + ) + + # Test deleting the credential + CredentialManager.delete(valid_credential, user_performing_action=curator) + + # Verify the deletion + CredentialManager.verify( + credential=valid_credential, + verify_deleted=True, + user_performing_action=curator, + ) diff --git a/backend/tests/integration/tests/permissions/test_doc_set_permissions.py b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py new file mode 100644 index 00000000000..412b5d41fad --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py @@ -0,0 +1,190 @@ +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_doc_set_permissions_setup(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a second user (curator) + curator: TestUser = UserManager.create(name="curator") + + # Creating the first user group + user_group_1 = UserGroupManager.create( + name="curated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # Setting the curator as a curator for the first user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating a second user group + user_group_2 = UserGroupManager.create( + name="uncurated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # Admin creates a cc_pair + private_cc_pair = CCPairManager.create_from_scratch( + is_public=False, + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # Admin creates a public cc_pair + public_cc_pair = CCPairManager.create_from_scratch( + is_public=True, + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # END OF HAPPY PATH + + """Tests for things Curators/Admins should not be able to do""" + + # Test that curator cannot create a document set for the group they don't curate + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 1", + groups=[user_group_2.id], + 
cc_pair_ids=[public_cc_pair.id], + user_performing_action=curator, + ) + + # Test that curator cannot create a document set attached to both groups + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 2", + is_public=False, + cc_pair_ids=[public_cc_pair.id], + groups=[user_group_1.id, user_group_2.id], + user_performing_action=curator, + ) + + # Test that curator cannot create a document set with no groups + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 3", + is_public=False, + cc_pair_ids=[public_cc_pair.id], + groups=[], + user_performing_action=curator, + ) + + # Test that curator cannot create a document set with no cc_pairs + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 4", + is_public=False, + cc_pair_ids=[], + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Test that admin cannot create a document set with no cc_pairs + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 4", + is_public=False, + cc_pair_ids=[], + groups=[user_group_1.id], + user_performing_action=admin_user, + ) + + """Tests for things Curators should be able to do""" + # Test that curator can create a document set for the group they curate + valid_doc_set = DocumentSetManager.create( + name="Valid Document Set", + is_public=False, + cc_pair_ids=[public_cc_pair.id], + groups=[user_group_1.id], + user_performing_action=curator, + ) + + DocumentSetManager.wait_for_sync( + document_sets_to_check=[valid_doc_set], user_performing_action=admin_user + ) + + # Verify that the valid document set was created + DocumentSetManager.verify( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) + + # Verify that only one document set exists + all_doc_sets = DocumentSetManager.get_all(user_performing_action=admin_user) + assert len(all_doc_sets) == 1 + + # Add the private_cc_pair to the doc set on our end for later comparison + valid_doc_set.cc_pair_ids.append(private_cc_pair.id) + + # Confirm the curator can't add the private_cc_pair to the doc set + with pytest.raises(HTTPError): + DocumentSetManager.edit( + document_set=valid_doc_set, + user_performing_action=curator, + ) + # Confirm the admin can't add the private_cc_pair to the doc set + with pytest.raises(HTTPError): + DocumentSetManager.edit( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) + + # Verify the document set has not been updated in the db + with pytest.raises(ValueError): + DocumentSetManager.verify( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) + + # Add the private_cc_pair to the user group on our end for later comparison + user_group_1.cc_pair_ids.append(private_cc_pair.id) + + # Admin adds the cc_pair to the group the curator curates + UserGroupManager.edit( + user_group=user_group_1, + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + UserGroupManager.verify( + user_group=user_group_1, + user_performing_action=admin_user, + ) + + # Confirm the curator can now add the cc_pair to the doc set + DocumentSetManager.edit( + document_set=valid_doc_set, + user_performing_action=curator, + ) + DocumentSetManager.wait_for_sync( + document_sets_to_check=[valid_doc_set], user_performing_action=admin_user + ) + # Verify the updated document set + DocumentSetManager.verify( + document_set=valid_doc_set, + 
user_performing_action=admin_user, + ) diff --git a/backend/tests/integration/tests/permissions/test_user_role_permissions.py b/backend/tests/integration/tests/permissions/test_user_role_permissions.py new file mode 100644 index 00000000000..5da91a57af8 --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_user_role_permissions.py @@ -0,0 +1,93 @@ +""" +This file tests the ability of different user types to set the role of other users. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.db.models import UserRole +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_user_role_setting_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + assert UserManager.verify_role(admin_user, UserRole.ADMIN) + + # Creating a basic user + basic_user: TestUser = UserManager.create(name="basic_user") + assert UserManager.verify_role(basic_user, UserRole.BASIC) + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + assert UserManager.verify_role(curator, UserRole.BASIC) + + # Creating a curator without adding to a group should not work + with pytest.raises(HTTPError): + UserManager.set_role( + user_to_set=curator, + target_role=UserRole.CURATOR, + user_to_perform_action=admin_user, + ) + + global_curator: TestUser = UserManager.create(name="global_curator") + assert UserManager.verify_role(global_curator, UserRole.BASIC) + + # Setting the role of a global curator should not work for a basic user + with pytest.raises(HTTPError): + UserManager.set_role( + user_to_set=global_curator, + target_role=UserRole.GLOBAL_CURATOR, + user_to_perform_action=basic_user, + ) + + # Setting the role of a global curator should work for an admin user + UserManager.set_role( + user_to_set=global_curator, + target_role=UserRole.GLOBAL_CURATOR, + user_to_perform_action=admin_user, + ) + assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR) + + # Setting the role of a global curator should not work for an invalid curator + with pytest.raises(HTTPError): + UserManager.set_role( + user_to_set=global_curator, + target_role=UserRole.BASIC, + user_to_perform_action=global_curator, + ) + assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR) + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # This should fail because the curator is not in the user group + with pytest.raises(HTTPError): + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Adding the curator to the user group + user_group_1.user_ids = [curator.id] + UserGroupManager.edit(user_group=user_group_1, user_performing_action=admin_user) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # This should work because the curator is in the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) diff 
--git a/backend/tests/integration/tests/permissions/test_whole_curator_flow.py b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py new file mode 100644 index 00000000000..878ba1e17e8 --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py @@ -0,0 +1,86 @@ +""" +This test tests the happy path for curator permissions +""" +from danswer.db.models import UserRole +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_whole_curator_flow(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + assert UserManager.verify_role(admin_user, UserRole.ADMIN) + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # Making curator a curator of user_group_1 + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + assert UserManager.verify_role(curator, UserRole.CURATOR) + + # Creating a credential as curator + test_credential = CredentialManager.create( + name="curator_test_credential", + source=DocumentSource.FILE, + curator_public=False, + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Creating a connector as curator + test_connector = ConnectorManager.create( + name="curator_test_connector", + source=DocumentSource.FILE, + is_public=False, + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Test editing the connector + test_connector.name = "updated_test_connector" + ConnectorManager.edit(connector=test_connector, user_performing_action=curator) + + # Creating a CC pair as curator + test_cc_pair = CCPairManager.create( + connector_id=test_connector.id, + credential_id=test_credential.id, + name="curator_test_cc_pair", + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + + CCPairManager.verify(cc_pair=test_cc_pair, user_performing_action=admin_user) + + # Verify that the curator can pause and unpause the CC pair + CCPairManager.pause_cc_pair(cc_pair=test_cc_pair, user_performing_action=curator) + + # Verify that the curator can delete the CC pair + CCPairManager.delete(cc_pair=test_cc_pair, user_performing_action=curator) + CCPairManager.wait_for_deletion_completion(user_performing_action=curator) + + # Verify that the CC pair has been deleted + CCPairManager.verify( + cc_pair=test_cc_pair, + verify_deleted=True, + user_performing_action=admin_user, + ) diff --git a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py index d3974fe47ab..ad6baaa365a 100644 --- a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py +++ 
b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py @@ -12,6 +12,37 @@ from danswer.search.models import InferenceChunk + +def test_passed_in_quotes() -> None: + # Test case 1: Basic quote separation + test_answer = """{ + "answer": "I can assist "James" with that", + "quotes": [ + "Danswer can just ingest PDFs as they are. How GOOD it embeds them depends on the formatting of your PDFs.", + "the ` danswer. llm ` package aims to provide a comprehensive framework." + ] + }""" + + answer, quotes = separate_answer_quotes(test_answer, is_json_prompt=True) + assert answer == 'I can assist "James" with that' + assert quotes == [ + "Danswer can just ingest PDFs as they are. How GOOD it embeds them depends on the formatting of your PDFs.", + "the ` danswer. llm ` package aims to provide a comprehensive framework.", + ] + + # Test case 2: Additional quotes + test_answer = """{ + "answer": "She said the response was "1" and I said the response was "2".", + "quotes": [ + "Danswer can efficiently ingest PDFs, with the quality of embedding depending on the PDF's formatting." + ] + }""" + answer, quotes = separate_answer_quotes(test_answer, is_json_prompt=True) + assert answer == 'She said the response was "1" and I said the response was "2".' + assert quotes == [ + "Danswer can efficiently ingest PDFs, with the quality of embedding depending on the PDF's formatting.", + ] + + def test_separate_answer_quotes() -> None: # Test case 1: Basic quote separation test_answer = textwrap.dedent( diff --git a/ct.yaml b/ct.yaml new file mode 100644 index 00000000000..764af160daf --- /dev/null +++ b/ct.yaml @@ -0,0 +1,12 @@ +# See https://github.com/helm/chart-testing#configuration + +chart-dirs: + - deployment/helm/charts + +chart-repos: + - vespa=https://unoplat.github.io/vespa-helm-charts + - postgresql=https://charts.bitnami.com/bitnami + +helm-extra-args: --timeout 900s + +validate-maintainers: false diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 636879497a2..945ec98c49b 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -1,4 +1,3 @@ -version: "3" services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +11,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server restart: always ports: @@ -35,13 +35,6 @@ services: - OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-} - TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-} # Gen AI Settings - - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-} - - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-} - - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} - - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-} - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} @@ -69,8 +62,10 @@ services: # Other services - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose # Don't change the NLP model configs unless you know what you're doing + - EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-} - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} - DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-} - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-} @@ -91,12 +86,14 @@ services: - 
LOG_ENDPOINT_LATENCY=${LOG_ENDPOINT_LATENCY:-} - LOG_POSTGRES_LATENCY=${LOG_POSTGRES_LATENCY:-} - LOG_POSTGRES_CONN_COUNTS=${LOG_POSTGRES_CONN_COUNTS:-} + + # Chat Configs + - HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-} # Enterprise Edition only - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-} # Seeding configuration - - ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-} extra_hosts: - "host.docker.internal:host-gateway" logging: @@ -114,19 +111,13 @@ services: depends_on: - relational_db - index + - cache - inference_model_server - indexing_model_server restart: always environment: - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-} # Gen AI Settings (Needed by DanswerBot) - - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-} - - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-} - - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} - - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-} - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} @@ -151,6 +142,7 @@ services: - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-} - POSTGRES_DB=${POSTGRES_DB:-} - VESPA_HOST=index + - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors # Don't change the NLP model configs unless you know what you're doing - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} @@ -162,6 +154,7 @@ services: - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} # Indexing Configs + - VESPA_SEARCHER_THREADS=${VESPA_SEARCHER_THREADS:-} - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-} - ENABLED_CONNECTOR_TYPES=${ENABLED_CONNECTOR_TYPES:-} - DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-} @@ -175,6 +168,7 @@ services: - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-} + # Danswer SlackBot Configs - DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-} - DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-} @@ -186,6 +180,7 @@ services: - NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-} - DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-} - DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-} + - CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-} # Logging # Leave this on pretty please? Nothing sensitive is collected! 
# https://docs.danswer.dev/more/telemetry @@ -234,6 +229,7 @@ services: # Enterprise Edition only - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} + - CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-} inference_model_server: image: danswer/danswer-model-server:${IMAGE_TAG:-latest} @@ -275,6 +271,7 @@ services: fi" restart: on-failure environment: + - INDEX_BATCH_SIZE=${INDEX_BATCH_SIZE:-} - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} - INDEXING_ONLY=True # Set to debug to get more fine-grained logs @@ -342,9 +339,19 @@ services: # in order to make this work on both Unix-like systems and windows command: > /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh - && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" - + && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" + + cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + volumes: + cache_volume: db_volume: vespa_volume: # Created by the container itself diff --git a/deployment/docker_compose/docker-compose.gpu-dev.yml b/deployment/docker_compose/docker-compose.gpu-dev.yml index 9079bd10dff..570e0a6ed7b 100644 --- a/deployment/docker_compose/docker-compose.gpu-dev.yml +++ b/deployment/docker_compose/docker-compose.gpu-dev.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +11,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server restart: always ports: @@ -32,13 +32,6 @@ services: - EMAIL_FROM=${EMAIL_FROM:-} - TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-} # Gen AI Settings - - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-} - - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-} - - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} - - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-} - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} @@ -65,8 +58,10 @@ services: # Other services - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose # Don't change the NLP model configs unless you know what you're doing + - EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-} - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} - DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-} - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-} @@ -85,6 +80,9 @@ services: # (time spent on finding the right docs + time spent fetching summaries from disk) - LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-} + # Chat Configs + - HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-} + # Enterprise Edition only - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-} - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} @@ -106,19 +104,13 @@ services: depends_on: - relational_db - index + - cache - inference_model_server - indexing_model_server restart: always environment: - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-} # Gen AI Settings (Needed by DanswerBot) - - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-} - - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-} - - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-} - - GEN_AI_API_KEY=${GEN_AI_API_KEY:-} - - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-} - - 
GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-} - - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-} - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-} @@ -143,6 +135,7 @@ services: - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-} - POSTGRES_DB=${POSTGRES_DB:-} - VESPA_HOST=index + - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors # Don't change the NLP model configs unless you know what you're doing - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} @@ -154,6 +147,7 @@ services: - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} # Indexing Configs + - VESPA_SEARCHER_THREADS=${VESPA_SEARCHER_THREADS:-} - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-} - ENABLED_CONNECTOR_TYPES=${ENABLED_CONNECTOR_TYPES:-} - DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-} @@ -248,6 +242,7 @@ services: fi" restart: on-failure environment: + - INDEX_BATCH_SIZE=${INDEX_BATCH_SIZE:-} - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-} # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} @@ -288,6 +283,7 @@ services: - INDEXING_ONLY=True # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} + - VESPA_SEARCHER_THREADS=${VESPA_SEARCHER_THREADS:-1} volumes: # Not necessary, this is just to reduce download time during startup - indexing_huggingface_model_cache:/root/.cache/huggingface/ @@ -355,9 +351,20 @@ services: command: > /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" - + + + cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + volumes: + cache_volume: db_volume: vespa_volume: # Created by the container itself diff --git a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml index 250012bd7f5..4364a231f87 100644 --- a/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml +++ b/deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml @@ -1,4 +1,3 @@ -version: "3" services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +11,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server restart: always env_file: @@ -20,6 +20,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} extra_hosts: - "host.docker.internal:host-gateway" @@ -38,6 +39,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server - indexing_model_server restart: always @@ -47,6 +49,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} extra_hosts: @@ -127,6 +130,7 @@ services: - INDEXING_ONLY=True # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} + - VESPA_SEARCHER_THREADS=${VESPA_SEARCHER_THREADS:-1} volumes: # Not necessary, this is just to reduce download time during startup - indexing_huggingface_model_cache:/root/.cache/huggingface/ @@ -196,7 +200,22 @@ services: env_file: - .env.nginx +<<<<<<< HEAD +======= + + 
cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + + +>>>>>>> upstream/main volumes: + cache_volume: db_volume: vespa_volume: # Created by the container itself diff --git a/deployment/docker_compose/docker-compose.prod.yml b/deployment/docker_compose/docker-compose.prod.yml index e2c2b072f93..113aa42a9f3 100644 --- a/deployment/docker_compose/docker-compose.prod.yml +++ b/deployment/docker_compose/docker-compose.prod.yml @@ -1,4 +1,7 @@ +<<<<<<< HEAD version: "3" +======= +>>>>>>> upstream/main services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +15,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server restart: always env_file: @@ -20,6 +24,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} extra_hosts: - "host.docker.internal:host-gateway" @@ -38,6 +43,7 @@ services: depends_on: - relational_db - index + - cache - inference_model_server - indexing_model_server restart: always @@ -47,6 +53,7 @@ services: - AUTH_TYPE=${AUTH_TYPE:-oidc} - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} extra_hosts: @@ -141,6 +148,7 @@ services: - INDEXING_ONLY=True # Set to debug to get more fine-grained logs - LOG_LEVEL=${LOG_LEVEL:-info} + - VESPA_SEARCHER_THREADS=${VESPA_SEARCHER_THREADS:-1} volumes: # Not necessary, this is just to reduce download time during startup - indexing_huggingface_model_cache:/root/.cache/huggingface/ @@ -213,7 +221,19 @@ services: max-file: "6" entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'" + + cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + + volumes: + cache_volume: db_volume: vespa_volume: # Created by the container itself diff --git a/deployment/docker_compose/docker-compose.search-testing.yml b/deployment/docker_compose/docker-compose.search-testing.yml index efb387eb083..e9c2aee6667 100644 --- a/deployment/docker_compose/docker-compose.search-testing.yml +++ b/deployment/docker_compose/docker-compose.search-testing.yml @@ -1,4 +1,3 @@ -version: '3' services: api_server: image: danswer/danswer-backend:${IMAGE_TAG:-latest} @@ -12,6 +11,7 @@ services: depends_on: - relational_db - index + - cache restart: always ports: - "8080" @@ -21,6 +21,7 @@ services: - AUTH_TYPE=disabled - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} - ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-} @@ -43,6 +44,7 @@ services: depends_on: - relational_db - index + - cache restart: always env_file: - .env_eval @@ -50,6 +52,7 @@ services: - AUTH_TYPE=disabled - POSTGRES_HOST=relational_db - VESPA_HOST=index + - REDIS_HOST=cache - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server} @@ -135,6 +138,7 @@ services: - INDEXING_ONLY=True - LOG_LEVEL=${LOG_LEVEL:-debug} - index_model_cache_huggingface:/root/.cache/huggingface/ + - 
VESPA_SEARCHER_THREADS=${VESPA_SEARCHER_THREADS:-1} logging: driver: json-file options: @@ -200,7 +204,18 @@ services: && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" + cache: + image: redis:7.4-alpine + restart: always + ports: + - '6379:6379' + command: redis-server + volumes: + - cache_volume:/data + + volumes: + cache_volume: db_volume: driver: local driver_opts: diff --git a/deployment/docker_compose/env.multilingual.template b/deployment/docker_compose/env.multilingual.template index e218305153f..1a66dbfbbde 100644 --- a/deployment/docker_compose/env.multilingual.template +++ b/deployment/docker_compose/env.multilingual.template @@ -1,38 +1,8 @@ -# This env template shows how to configure Danswer for multilingual use -# In this case, it is configured for French and English -# To use it, copy it to .env in the docker_compose directory. -# Feel free to combine it with the other templates to suit your needs +# This env template shows how to configure Danswer for custom multilingual use +# Note that for most use cases it will be enough to configure Danswer multilingual purely through the UI +# See "Search Settings" -> "Advanced" for UI options. +# To use it, copy it to .env in the docker_compose directory (or the equivalent environment settings file for your deployment) - -# Rephrase the user query in specified languages using LLM, use comma separated values -MULTILINGUAL_QUERY_EXPANSION="English, French" -# Change the below to suit your specific needs, can be more explicit about the language of the response -LANGUAGE_HINT="IMPORTANT: Respond in the same language as my query!" +# The following is included with the user prompt. Here's one example but feel free to customize it to your needs: +LANGUAGE_HINT="IMPORTANT: ALWAYS RESPOND IN FRENCH! Even if the documents and the user query are in English, your response must be in French." LANGUAGE_CHAT_NAMING_HINT="The name of the conversation must be in the same language as the user query." - -# A recent MIT license multilingual model: https://huggingface.co/intfloat/multilingual-e5-small -DOCUMENT_ENCODER_MODEL="intfloat/multilingual-e5-small" - -# The model above is trained with the following prefix for queries and passages to improve retrieval -# by letting the model know which of the two type is currently being embedded -ASYM_QUERY_PREFIX="query: " -ASYM_PASSAGE_PREFIX="passage: " - -# Depends model by model, the one shown above is tuned with this as True -NORMALIZE_EMBEDDINGS="True" - -# Use LLM to determine if chunks are relevant to the query -# May not work well for languages that do not have much training data in the LLM training set -# If using a common language like Spanish, French, Chinese, etc. 
this can be kept turned on -DISABLE_LLM_DOC_RELEVANCE="True" - -# Enables fine-grained embeddings for better retrieval -# At the cost of indexing speed (~5x slower), query time is same speed -# Since reranking is turned off and multilingual retrieval is generally harder -# it is advised to turn this one on -ENABLE_MULTIPASS_INDEXING="True" - -# Using a stronger LLM will help with multilingual tasks -# Since documents may be in multiple languages, and there are additional instructions to respond -# in the user query's language, it is advised to use the best model possible -GEN_AI_MODEL_VERSION="gpt-4" diff --git a/deployment/docker_compose/env.prod.template b/deployment/docker_compose/env.prod.template index 818bd1ed1bf..890939deb49 100644 --- a/deployment/docker_compose/env.prod.template +++ b/deployment/docker_compose/env.prod.template @@ -7,16 +7,7 @@ WEB_DOMAIN=http://localhost:3000 -# Generative AI settings, uncomment as needed, will work with defaults -GEN_AI_MODEL_PROVIDER=openai -GEN_AI_MODEL_VERSION=gpt-4 -# Provide this as a global default/backup, this can also be set via the UI -#GEN_AI_API_KEY= -# Set to use Azure OpenAI or other services, such as https://danswer.openai.azure.com/ -#GEN_AI_API_ENDPOINT= -# Set up to use a specific API version, such as 2023-09-15-preview (example taken from Azure) -#GEN_AI_API_VERSION= - +# NOTE: Generative AI configurations are done via the UI now # If you want to setup a slack bot to answer questions automatically in Slack # channels it is added to, you must specify the two below. diff --git a/deployment/docker_compose/init-letsencrypt.sh b/deployment/docker_compose/init-letsencrypt.sh index 9eec409fada..66161e4dfbe 100755 --- a/deployment/docker_compose/init-letsencrypt.sh +++ b/deployment/docker_compose/init-letsencrypt.sh @@ -112,5 +112,14 @@ $COMPOSE_CMD -f docker-compose.prod.yml run --name danswer-stack --rm --entrypoi --force-renewal" certbot echo +echo "### Renaming certificate directory if needed ..." +$COMPOSE_CMD -f docker-compose.prod.yml run --name danswer-stack --rm --entrypoint "\ + sh -c 'for domain in $domains; do \ + numbered_dir=\$(find /etc/letsencrypt/live -maxdepth 1 -type d -name \"\$domain-00*\" | sort -r | head -n1); \ + if [ -n \"\$numbered_dir\" ]; then \ + mv \"\$numbered_dir\" /etc/letsencrypt/live/\$domain; \ + fi; \ + done'" certbot + echo "### Reloading nginx ..." 
$COMPOSE_CMD -f docker-compose.prod.yml -p danswer-stack up --force-recreate -d diff --git a/deployment/helm/.gitignore b/deployment/helm/charts/danswer/.gitignore similarity index 100% rename from deployment/helm/.gitignore rename to deployment/helm/charts/danswer/.gitignore diff --git a/deployment/helm/.helmignore b/deployment/helm/charts/danswer/.helmignore similarity index 100% rename from deployment/helm/.helmignore rename to deployment/helm/charts/danswer/.helmignore diff --git a/deployment/helm/Chart.lock b/deployment/helm/charts/danswer/Chart.lock similarity index 100% rename from deployment/helm/Chart.lock rename to deployment/helm/charts/danswer/Chart.lock diff --git a/deployment/helm/Chart.yaml b/deployment/helm/charts/danswer/Chart.yaml similarity index 93% rename from deployment/helm/Chart.yaml rename to deployment/helm/charts/danswer/Chart.yaml index 7763f33bec5..96336911ed8 100644 --- a/deployment/helm/Chart.yaml +++ b/deployment/helm/charts/danswer/Chart.yaml @@ -22,14 +22,11 @@ dependencies: version: 14.3.1 repository: https://charts.bitnami.com/bitnami condition: postgresql.enabled - - name: vespa + - name: vespa version: 0.2.3 repository: https://unoplat.github.io/vespa-helm-charts condition: vespa.enabled - name: nginx version: 15.14.0 repository: oci://registry-1.docker.io/bitnamicharts - condition: nginx.enabled - - - \ No newline at end of file + condition: nginx.enabled diff --git a/deployment/helm/templates/_helpers.tpl b/deployment/helm/charts/danswer/templates/_helpers.tpl similarity index 100% rename from deployment/helm/templates/_helpers.tpl rename to deployment/helm/charts/danswer/templates/_helpers.tpl diff --git a/deployment/helm/templates/api-deployment.yaml b/deployment/helm/charts/danswer/templates/api-deployment.yaml similarity index 100% rename from deployment/helm/templates/api-deployment.yaml rename to deployment/helm/charts/danswer/templates/api-deployment.yaml diff --git a/deployment/helm/templates/api-hpa.yaml b/deployment/helm/charts/danswer/templates/api-hpa.yaml similarity index 100% rename from deployment/helm/templates/api-hpa.yaml rename to deployment/helm/charts/danswer/templates/api-hpa.yaml diff --git a/deployment/helm/templates/api-service.yaml b/deployment/helm/charts/danswer/templates/api-service.yaml similarity index 100% rename from deployment/helm/templates/api-service.yaml rename to deployment/helm/charts/danswer/templates/api-service.yaml diff --git a/deployment/helm/templates/background-deployment.yaml b/deployment/helm/charts/danswer/templates/background-deployment.yaml similarity index 100% rename from deployment/helm/templates/background-deployment.yaml rename to deployment/helm/charts/danswer/templates/background-deployment.yaml diff --git a/deployment/helm/templates/background-hpa.yaml b/deployment/helm/charts/danswer/templates/background-hpa.yaml similarity index 100% rename from deployment/helm/templates/background-hpa.yaml rename to deployment/helm/charts/danswer/templates/background-hpa.yaml diff --git a/deployment/helm/templates/configmap.yaml b/deployment/helm/charts/danswer/templates/configmap.yaml similarity index 100% rename from deployment/helm/templates/configmap.yaml rename to deployment/helm/charts/danswer/templates/configmap.yaml diff --git a/deployment/helm/templates/danswer-secret.yaml b/deployment/helm/charts/danswer/templates/danswer-secret.yaml similarity index 100% rename from deployment/helm/templates/danswer-secret.yaml rename to deployment/helm/charts/danswer/templates/danswer-secret.yaml diff --git 
a/deployment/helm/templates/indexing-model-deployment.yaml b/deployment/helm/charts/danswer/templates/indexing-model-deployment.yaml similarity index 100% rename from deployment/helm/templates/indexing-model-deployment.yaml rename to deployment/helm/charts/danswer/templates/indexing-model-deployment.yaml diff --git a/deployment/helm/templates/indexing-model-pvc.yaml b/deployment/helm/charts/danswer/templates/indexing-model-pvc.yaml similarity index 100% rename from deployment/helm/templates/indexing-model-pvc.yaml rename to deployment/helm/charts/danswer/templates/indexing-model-pvc.yaml diff --git a/deployment/helm/templates/indexing-model-service.yaml b/deployment/helm/charts/danswer/templates/indexing-model-service.yaml similarity index 100% rename from deployment/helm/templates/indexing-model-service.yaml rename to deployment/helm/charts/danswer/templates/indexing-model-service.yaml diff --git a/deployment/helm/templates/inference-model-deployment.yaml b/deployment/helm/charts/danswer/templates/inference-model-deployment.yaml similarity index 100% rename from deployment/helm/templates/inference-model-deployment.yaml rename to deployment/helm/charts/danswer/templates/inference-model-deployment.yaml diff --git a/deployment/helm/templates/inference-model-pvc.yaml b/deployment/helm/charts/danswer/templates/inference-model-pvc.yaml similarity index 100% rename from deployment/helm/templates/inference-model-pvc.yaml rename to deployment/helm/charts/danswer/templates/inference-model-pvc.yaml diff --git a/deployment/helm/templates/inference-model-service.yaml b/deployment/helm/charts/danswer/templates/inference-model-service.yaml similarity index 100% rename from deployment/helm/templates/inference-model-service.yaml rename to deployment/helm/charts/danswer/templates/inference-model-service.yaml diff --git a/deployment/helm/templates/nginx-conf.yaml b/deployment/helm/charts/danswer/templates/nginx-conf.yaml similarity index 100% rename from deployment/helm/templates/nginx-conf.yaml rename to deployment/helm/charts/danswer/templates/nginx-conf.yaml diff --git a/deployment/helm/templates/serviceaccount.yaml b/deployment/helm/charts/danswer/templates/serviceaccount.yaml similarity index 100% rename from deployment/helm/templates/serviceaccount.yaml rename to deployment/helm/charts/danswer/templates/serviceaccount.yaml diff --git a/deployment/helm/templates/tests/test-connection.yaml b/deployment/helm/charts/danswer/templates/tests/test-connection.yaml similarity index 100% rename from deployment/helm/templates/tests/test-connection.yaml rename to deployment/helm/charts/danswer/templates/tests/test-connection.yaml diff --git a/deployment/helm/templates/webserver-deployment.yaml b/deployment/helm/charts/danswer/templates/webserver-deployment.yaml similarity index 100% rename from deployment/helm/templates/webserver-deployment.yaml rename to deployment/helm/charts/danswer/templates/webserver-deployment.yaml diff --git a/deployment/helm/templates/webserver-hpa.yaml b/deployment/helm/charts/danswer/templates/webserver-hpa.yaml similarity index 100% rename from deployment/helm/templates/webserver-hpa.yaml rename to deployment/helm/charts/danswer/templates/webserver-hpa.yaml diff --git a/deployment/helm/templates/webserver-service.yaml b/deployment/helm/charts/danswer/templates/webserver-service.yaml similarity index 100% rename from deployment/helm/templates/webserver-service.yaml rename to deployment/helm/charts/danswer/templates/webserver-service.yaml diff --git a/deployment/helm/values.yaml 
b/deployment/helm/charts/danswer/values.yaml similarity index 94% rename from deployment/helm/values.yaml rename to deployment/helm/charts/danswer/values.yaml index 2167b70438b..4318beef82a 100644 --- a/deployment/helm/values.yaml +++ b/deployment/helm/charts/danswer/values.yaml @@ -84,7 +84,7 @@ postgresql: auth: existingSecret: danswer-secrets secretKeys: - adminPasswordKey: postgres_password #overwriting as postgres typically expects 'postgres-password' + adminPasswordKey: postgres_password # overwriting as postgres typically expects 'postgres-password' nginx: containerPorts: @@ -330,7 +330,7 @@ vespa: affinity: {} -#ingress: +# ingress: # enabled: false # className: "" # annotations: {} @@ -358,47 +358,43 @@ persistence: auth: # for storing smtp, oauth, slack, and other secrets # keys are lowercased version of env vars (e.g. SMTP_USER -> smtp_user) - existingSecret: "" # danswer-secrets + existingSecret: "" # danswer-secrets # optionally override the secret keys to reference in the secret + # this is used to populate the env vars in individual deployments + # the values here reference the keys in secrets below secretKeys: postgres_password: "postgres_password" smtp_pass: "" oauth_client_id: "" oauth_client_secret: "" oauth_cookie_secret: "" - gen_ai_api_key: "" danswer_bot_slack_app_token: "" danswer_bot_slack_bot_token: "" + redis_password: "redis_password" # will be overridden by the existingSecret if set secretName: "danswer-secrets" # set values as strings, they will be base64 encoded + # this is used to populate the secrets yaml secrets: postgres_password: "postgres" smtp_pass: "" oauth_client_id: "" oauth_client_secret: "" oauth_cookie_secret: "" - gen_ai_api_key: "" danswer_bot_slack_app_token: "" danswer_bot_slack_bot_token: "" + redis_password: "password" configMap: AUTH_TYPE: "disabled" # Change this for production uses unless Danswer is only accessible behind VPN SESSION_EXPIRE_TIME_SECONDS: "86400" # 1 Day Default VALID_EMAIL_DOMAINS: "" # Can be something like danswer.ai, as an extra double-check - SMTP_SERVER: "" # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com' - SMTP_PORT: "" # For sending verification emails, if unspecified then defaults to '587' + SMTP_SERVER: "" # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com' + SMTP_PORT: "" # For sending verification emails, if unspecified then defaults to '587' SMTP_USER: "" # 'your-email@company.com' # SMTP_PASS: "" # 'your-gmail-password' EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead # Gen AI Settings - GEN_AI_MODEL_PROVIDER: "" - GEN_AI_MODEL_VERSION: "" - FAST_GEN_AI_MODEL_VERSION: "" - # GEN_AI_API_KEY: "" - GEN_AI_API_ENDPOINT: "" - GEN_AI_API_VERSION: "" - GEN_AI_LLM_PROVIDER_TYPE: "" GEN_AI_MAX_TOKENS: "" QA_TIMEOUT: "60" MAX_CHUNKS_FED_TO_CHAT: "" @@ -416,6 +412,7 @@ configMap: # Internet Search Tool BING_API_KEY: "" # Don't change the NLP models unless you know what you're doing + EMBEDDING_BATCH_SIZE: "" DOCUMENT_ENCODER_MODEL: "" NORMALIZE_EMBEDDINGS: "" ASYM_QUERY_PREFIX: "" @@ -424,6 +421,7 @@ configMap: MODEL_SERVER_PORT: "" MIN_THREADS_ML_MODELS: "" # Indexing Configs + VESPA_SEARCHER_THREADS: "" NUM_INDEXING_WORKERS: "" DISABLE_INDEX_UPDATE_ON_SWAP: "" DASK_JOB_CLIENT_ENABLED: "" @@ -452,3 +450,5 @@ configMap: # Shared or Non-backend Related WEB_DOMAIN: "http://localhost:3000" # for web server and api server DOMAIN: "localhost" # for nginx + # Chat Configs + HARD_DELETE_CHATS: "" diff --git 
a/deployment/kubernetes/api_server-service-deployment.yaml b/deployment/kubernetes/api_server-service-deployment.yaml index eeac5fecc96..ccbbc906d61 100644 --- a/deployment/kubernetes/api_server-service-deployment.yaml +++ b/deployment/kubernetes/api_server-service-deployment.yaml @@ -52,6 +52,11 @@ spec: secretKeyRef: name: danswer-secrets key: google_oauth_client_secret + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: danswer-secrets + key: redis_password envFrom: - configMapRef: name: env-configmap diff --git a/deployment/kubernetes/background-deployment.yaml b/deployment/kubernetes/background-deployment.yaml index 18521b0f5ad..1a6ef61c104 100644 --- a/deployment/kubernetes/background-deployment.yaml +++ b/deployment/kubernetes/background-deployment.yaml @@ -19,6 +19,12 @@ spec: command: ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"] # There are some extra values since this is shared between services # There are no conflicts though, extra env variables are simply ignored + env: + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: danswer-secrets + key: redis_password envFrom: - configMapRef: name: env-configmap diff --git a/deployment/kubernetes/env-configmap.yaml b/deployment/kubernetes/env-configmap.yaml index 907fae1c836..b5db45434e8 100644 --- a/deployment/kubernetes/env-configmap.yaml +++ b/deployment/kubernetes/env-configmap.yaml @@ -14,13 +14,6 @@ data: SMTP_PASS: "" # 'your-gmail-password' EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead # Gen AI Settings - GEN_AI_MODEL_PROVIDER: "" - GEN_AI_MODEL_VERSION: "" - FAST_GEN_AI_MODEL_VERSION: "" - GEN_AI_API_KEY: "" - GEN_AI_API_ENDPOINT: "" - GEN_AI_API_VERSION: "" - GEN_AI_LLM_PROVIDER_TYPE: "" GEN_AI_MAX_TOKENS: "" QA_TIMEOUT: "60" MAX_CHUNKS_FED_TO_CHAT: "" @@ -38,9 +31,11 @@ data: # Other Services POSTGRES_HOST: "relational-db-service" VESPA_HOST: "document-index-service" + REDIS_HOST: "redis-service" # Internet Search Tool BING_API_KEY: "" # Don't change the NLP models unless you know what you're doing + EMBEDDING_BATCH_SIZE: "" DOCUMENT_ENCODER_MODEL: "" NORMALIZE_EMBEDDINGS: "" ASYM_QUERY_PREFIX: "" @@ -51,6 +46,7 @@ data: INDEXING_MODEL_SERVER_HOST: "indexing-model-server-service" MIN_THREADS_ML_MODELS: "" # Indexing Configs + VESPA_SEARCHER_THREADS: "" NUM_INDEXING_WORKERS: "" ENABLED_CONNECTOR_TYPES: "" DISABLE_INDEX_UPDATE_ON_SWAP: "" @@ -82,3 +78,5 @@ data: INTERNAL_URL: "http://api-server-service:80" # for web server WEB_DOMAIN: "http://localhost:3000" # for web server and api server DOMAIN: "localhost" # for nginx + # Chat Configs + HARD_DELETE_CHATS: "" diff --git a/deployment/kubernetes/redis-service-deployment.yaml b/deployment/kubernetes/redis-service-deployment.yaml new file mode 100644 index 00000000000..ab5113e5f49 --- /dev/null +++ b/deployment/kubernetes/redis-service-deployment.yaml @@ -0,0 +1,41 @@ +apiVersion: v1 +kind: Service +metadata: + name: redis-service +spec: + selector: + app: redis + ports: + - name: redis + protocol: TCP + port: 6379 + targetPort: 6379 + type: ClusterIP +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: redis-deployment +spec: + replicas: 1 + selector: + matchLabels: + app: redis + template: + metadata: + labels: + app: redis + spec: + containers: + - name: redis + image: redis:7.4-alpine + ports: + - containerPort: 6379 + env: + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: danswer-secrets + key: redis_password + command: ["redis-server"] + args: ["--requirepass", 
"$(REDIS_PASSWORD)"] diff --git a/deployment/kubernetes/secrets.yaml b/deployment/kubernetes/secrets.yaml index c135a29f676..d4cc9e2a739 100644 --- a/deployment/kubernetes/secrets.yaml +++ b/deployment/kubernetes/secrets.yaml @@ -8,4 +8,6 @@ data: postgres_user: cG9zdGdyZXM= # "postgres" base64 encoded postgres_password: cGFzc3dvcmQ= # "password" base64 encoded google_oauth_client_id: ZXhhbXBsZS1jbGllbnQtaWQ= # "example-client-id" base64 encoded. You will need to provide this, use echo -n "your-client-id" | base64 - google_oauth_client_secret: example_google_oauth_secret # "example-client-secret" base64 encoded. You will need to provide this, use echo -n "your-client-id" | base64 + google_oauth_client_secret: ZXhhbXBsZV9nb29nbGVfb2F1dGhfc2VjcmV0 # "example-client-secret" base64 encoded. You will need to provide this, use echo -n "your-client-id" | base64 + redis_password: cGFzc3dvcmQ= # "password" base64 encoded + \ No newline at end of file diff --git a/deployment/kubernetes/web_server-service-deployment.yaml b/deployment/kubernetes/web_server-service-deployment.yaml index b19b8e37986..b54c1b7f3d0 100644 --- a/deployment/kubernetes/web_server-service-deployment.yaml +++ b/deployment/kubernetes/web_server-service-deployment.yaml @@ -33,6 +33,12 @@ spec: - containerPort: 3000 # There are some extra values since this is shared between services # There are no conflicts though, extra env variables are simply ignored + env: + - name: REDIS_PASSWORD + valueFrom: + secretKeyRef: + name: danswer-secrets + key: redis_password envFrom: - configMapRef: name: env-configmap diff --git a/web/Dockerfile b/web/Dockerfile index 4ffced0da49..710cf653f25 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -58,6 +58,7 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T ARG NEXT_PUBLIC_DISABLE_LOGOUT ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT} + RUN npx next build # Step 2. Production image, copy all the files and run next diff --git a/web/next.config.js b/web/next.config.js index 1586af8d178..92812c513b7 100644 --- a/web/next.config.js +++ b/web/next.config.js @@ -8,47 +8,6 @@ const version = env_version || package_version; const nextConfig = { output: "standalone", swcMinify: true, - rewrites: async () => { - // In production, something else (nginx in the one box setup) should take - // care of this rewrite. TODO (chris): better support setups where - // web_server and api_server are on different machines. - if (process.env.NODE_ENV === "production") return []; - - return [ - { - source: "/api/:path*", - destination: "http://127.0.0.1:8080/:path*", // Proxy to Backend - }, - ]; - }, - redirects: async () => { - // In production, something else (nginx in the one box setup) should take - // care of this redirect. TODO (chris): better support setups where - // web_server and api_server are on different machines. 
- const defaultRedirects = []; - - if (process.env.NODE_ENV === "production") return defaultRedirects; - - return defaultRedirects.concat([ - { - source: "/api/chat/send-message:params*", - destination: "http://127.0.0.1:8080/chat/send-message:params*", // Proxy to Backend - permanent: true, - }, - { - source: "/api/query/stream-answer-with-quote:params*", - destination: - "http://127.0.0.1:8080/query/stream-answer-with-quote:params*", // Proxy to Backend - permanent: true, - }, - { - source: "/api/query/stream-query-validation:params*", - destination: - "http://127.0.0.1:8080/query/stream-query-validation:params*", // Proxy to Backend - permanent: true, - }, - ]); - }, publicRuntimeConfig: { version, }, diff --git a/web/package-lock.json b/web/package-lock.json index 48ac21d6477..338cf0a9f0f 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -2555,11 +2555,11 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -4061,9 +4061,9 @@ } }, "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dependencies": { "to-regex-range": "^5.0.1" }, diff --git a/web/public/LiteLLM.jpg b/web/public/LiteLLM.jpg new file mode 100644 index 00000000000..d6a77b2d105 Binary files /dev/null and b/web/public/LiteLLM.jpg differ diff --git a/web/src/app/admin/add-connector/page.tsx b/web/src/app/admin/add-connector/page.tsx index bf7032b5f90..8d73131e69a 100644 --- a/web/src/app/admin/add-connector/page.tsx +++ b/web/src/app/admin/add-connector/page.tsx @@ -112,7 +112,7 @@ export default function Page() { value={searchTerm} onChange={(e) => setSearchTerm(e.target.value)} onKeyDown={handleKeyPress} - className="flex mt-2 max-w-sm h-9 w-full rounded-md border-2 border border-input bg-transparent px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" + className="ml-1 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" /> {Object.entries(categorizedSources) diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index d478922e516..1c32beae14a 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -928,9 +928,9 @@ export function AssistantEditor({ { const value = e.target.value; if ( @@ -1184,20 +1184,16 @@ export function AssistantEditor({ /> - {isPaidEnterpriseFeaturesEnabled && - userGroups && - userGroups.length > 0 && ( - - )} + )} diff 
--git a/web/src/app/admin/bot/page.tsx b/web/src/app/admin/bot/page.tsx index 14f270ee9bc..c3ef70ccbf6 100644 --- a/web/src/app/admin/bot/page.tsx +++ b/web/src/app/admin/bot/page.tsx @@ -66,7 +66,7 @@ const SlackBotConfigsTable = ({ Channels - Persona + Assistant Document Sets Delete diff --git a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx index 80ff1f456b9..5252e6aede5 100644 --- a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx @@ -28,6 +28,7 @@ import { PopupSpec } from "@/components/admin/connectors/Popup"; import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; import * as Yup from "yup"; import isEqual from "lodash/isEqual"; +import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; function customConfigProcessing(customConfigsList: [string, string][]) { const customConfig: { [key: string]: string } = {}; @@ -209,9 +210,9 @@ export function CustomLLMProviderUpdateForm({ setSubmitting(false); }} > - {({ values, setFieldValue }) => { + {(formikProps) => { return ( -
+ ) => (
- {values.custom_config_list.map((_, index) => { + {formikProps.values.custom_config_list.map((_, index) => { return (
List the individual models that you want to make available as @@ -419,64 +420,12 @@ export function CustomLLMProviderUpdateForm({ /> {showAdvancedOptions && ( - <> - {isPaidEnterpriseFeaturesEnabled && userGroups && ( - <> - - - {userGroups && - userGroups.length > 0 && - !values.is_public && ( -
- - Select which User Groups should have access to this - LLM Provider. - -
- {userGroups.map((userGroup) => { - const isSelected = values.groups.includes( - userGroup.id - ); - return ( - { - if (isSelected) { - setFieldValue( - "groups", - values.groups.filter( - (id) => id !== userGroup.id - ) - ); - } else { - setFieldValue("groups", [ - ...values.groups, - userGroup.id, - ]); - } - }} - > -
- -
{userGroup.name}
-
-
- ); - })} -
-
- )} - - )} - + )}
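The change above collapses the hand-rolled user-group picker in CustomLLMProviderUpdateForm into the shared IsPublicGroupSelector and passes the whole formikProps object down to it (the LLMProviderUpdateForm diff below does the same). The selector's implementation is not part of this patch, so the sketch below only illustrates the assumed contract: a component that takes formikProps plus an objectName label and updates the is_public and groups fields itself. The component name and markup here are illustrative, not the real implementation.

    import { FormikProps } from "formik";

    // Assumed value shape: both provider forms keep `is_public` and `groups`
    // in their Formik state, so the selector only needs these two fields.
    interface VisibilityValues {
      is_public: boolean;
      groups: number[];
    }

    // Hypothetical stand-in for IsPublicGroupSelector; the real component also
    // renders the per-user-group bubbles that were removed above.
    export function IsPublicGroupSelectorSketch<T extends VisibilityValues>({
      formikProps,
      objectName,
    }: {
      formikProps: FormikProps<T>;
      objectName: string;
    }) {
      const { values, setFieldValue } = formikProps;
      return (
        <label className="flex items-center gap-2">
          <input
            type="checkbox"
            checked={values.is_public}
            onChange={() => {
              // Toggling from private to public clears any group restrictions.
              if (!values.is_public) {
                setFieldValue("groups", []);
              }
              setFieldValue("is_public", !values.is_public);
            }}
          />
          Make this {objectName} available to all users
        </label>
      );
    }

Passing formikProps wholesale keeps the visibility logic in one place instead of re-wiring setFieldValue in every provider form.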
diff --git a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx index 49d95d096f5..f461ffbe889 100644 --- a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx @@ -24,6 +24,7 @@ import { PopupSpec } from "@/components/admin/connectors/Popup"; import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; import * as Yup from "yup"; import isEqual from "lodash/isEqual"; +import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; export function LLMProviderUpdateForm({ llmProviderDescriptor, @@ -31,11 +32,13 @@ export function LLMProviderUpdateForm({ existingLlmProvider, shouldMarkAsDefault, setPopup, + hideAdvanced, }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor; onClose: () => void; existingLlmProvider?: FullLLMProvider; shouldMarkAsDefault?: boolean; + hideAdvanced?: boolean; setPopup?: (popup: PopupSpec) => void; }) { const { mutate } = useSWRConfig(); @@ -52,7 +55,7 @@ export function LLMProviderUpdateForm({ // Define the initial values based on the provider's requirements const initialValues = { - name: existingLlmProvider?.name ?? "", + name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""), api_key: existingLlmProvider?.api_key ?? "", api_base: existingLlmProvider?.api_base ?? "", api_version: existingLlmProvider?.api_version ?? "", @@ -217,18 +220,21 @@ export function LLMProviderUpdateForm({ setSubmitting(false); }} > - {({ values, setFieldValue }) => ( - - + {(formikProps) => ( + + {!hideAdvanced && ( + + )} {llmProviderDescriptor.api_key_required && ( (
))} - - - {llmProviderDescriptor.llm_names.length > 0 ? ( - ({ - name: getDisplayNameForModel(name), - value: name, - }))} - maxHeight="max-h-56" - /> - ) : ( - - )} + {!hideAdvanced && ( + <> + - {llmProviderDescriptor.llm_names.length > 0 ? ( - 0 ? ( + ({ + name: getDisplayNameForModel(name), + value: name, + }))} + maxHeight="max-h-56" + /> + ) : ( + + )} + + {llmProviderDescriptor.llm_names.length > 0 ? ( + ({ - name: getDisplayNameForModel(name), - value: name, - }))} - includeDefault - maxHeight="max-h-56" - /> - ) : ( - ({ + name: getDisplayNameForModel(name), + value: name, + }))} + includeDefault + maxHeight="max-h-56" + /> + ) : ( + - )} - - + label="[Optional] Fast Model" + placeholder="E.g. gpt-4" + /> + )} - {llmProviderDescriptor.name != "azure" && ( - - )} + - {showAdvancedOptions && ( - <> - {llmProviderDescriptor.llm_names.length > 0 && ( -
- ({ - value: name, - label: getDisplayNameForModel(name), - }))} - onChange={(selected) => - setFieldValue("display_model_names", selected) - } - /> -
+ {llmProviderDescriptor.name != "azure" && ( + )} - {isPaidEnterpriseFeaturesEnabled && userGroups && ( + {showAdvancedOptions && ( <> - - - {userGroups && userGroups.length > 0 && !values.is_public && ( -
- - Select which User Groups should have access to this LLM - Provider. - -
- {userGroups.map((userGroup) => { - const isSelected = values.groups.includes( - userGroup.id - ); - return ( - { - if (isSelected) { - setFieldValue( - "groups", - values.groups.filter( - (id) => id !== userGroup.id - ) - ); - } else { - setFieldValue("groups", [ - ...values.groups, - userGroup.id, - ]); - } - }} - > -
- -
{userGroup.name}
-
-
- ); - })} -
+ {llmProviderDescriptor.llm_names.length > 0 && ( +
+ ({ + value: name, + label: getDisplayNameForModel(name), + }) + )} + onChange={(selected) => + formikProps.setFieldValue( + "display_model_names", + selected + ) + } + />
)} + + )} @@ -432,6 +404,27 @@ export function LLMProviderUpdateForm({ return; } + // If the deleted provider was the default, set the first remaining provider as default + const remainingProvidersResponse = await fetch( + LLM_PROVIDERS_ADMIN_URL + ); + if (remainingProvidersResponse.ok) { + const remainingProviders = + await remainingProvidersResponse.json(); + + if (remainingProviders.length > 0) { + const setDefaultResponse = await fetch( + `${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`, + { + method: "POST", + } + ); + if (!setDefaultResponse.ok) { + console.error("Failed to set new default provider"); + } + } + } + mutate(LLM_PROVIDERS_ADMIN_URL); onClose(); }} diff --git a/web/src/app/admin/configuration/llm/constants.ts b/web/src/app/admin/configuration/llm/constants.ts index a265f4a2b2d..d7e3449b34d 100644 --- a/web/src/app/admin/configuration/llm/constants.ts +++ b/web/src/app/admin/configuration/llm/constants.ts @@ -2,3 +2,5 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider"; export const EMBEDDING_PROVIDERS_ADMIN_URL = "/api/admin/embedding/embedding-provider"; + +export const EMBEDDING_MODELS_ADMIN_URL = "/api/admin/embedding"; diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts index 2d0d49196b4..33fa94d7f15 100644 --- a/web/src/app/admin/configuration/llm/interfaces.ts +++ b/web/src/app/admin/configuration/llm/interfaces.ts @@ -1,3 +1,13 @@ +import { + AnthropicIcon, + AWSIcon, + AzureIcon, + CPUIcon, + OpenAIIcon, + OpenSourceIcon, +} from "@/components/icons/icons"; +import { FaRobot } from "react-icons/fa"; + export interface CustomConfigKey { name: string; description: string | null; @@ -53,3 +63,18 @@ export interface LLMProviderDescriptor { groups: number[]; display_model_names: string[] | null; } + +export const getProviderIcon = (providerName: string) => { + switch (providerName) { + case "openai": + return OpenAIIcon; + case "anthropic": + return AnthropicIcon; + case "bedrock": + return AWSIcon; + case "azure": + return AzureIcon; + default: + return CPUIcon; + } +}; diff --git a/web/src/app/admin/configuration/search/UpgradingPage.tsx b/web/src/app/admin/configuration/search/UpgradingPage.tsx index da379656336..6e41f4cf42d 100644 --- a/web/src/app/admin/configuration/search/UpgradingPage.tsx +++ b/web/src/app/admin/configuration/search/UpgradingPage.tsx @@ -1,9 +1,13 @@ import { ThreeDotsLoader } from "@/components/Loading"; import { Modal } from "@/components/Modal"; import { errorHandlingFetcher } from "@/lib/fetcher"; -import { ConnectorIndexingStatus } from "@/lib/types"; +import { + ConnectorIndexingStatus, + FailedConnectorIndexingStatus, + ValidStatuses, +} from "@/lib/types"; import { Button, Text, Title } from "@tremor/react"; -import { useState } from "react"; +import { useMemo, useState } from "react"; import useSWR, { mutate } from "swr"; import { ReindexingProgressTable } from "../../../../components/embedding/ReindexingProgressTable"; import { ErrorCallout } from "@/components/ErrorCallout"; @@ -12,6 +16,8 @@ import { HostedEmbeddingModel, } from "../../../../components/embedding/interfaces"; import { Connector } from "@/lib/connectors/connectors"; +import { FailedReIndexAttempts } from "@/components/embedding/FailedReIndexAttempts"; +import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; export default function UpgradingPage({ futureEmbeddingModel, @@ -20,6 +26,7 @@ export default function UpgradingPage({ }) { const [isCancelling, 
setIsCancelling] = useState(false); + const { setPopup, popup } = usePopup(); const { data: connectors } = useSWR[]>( "/api/manage/connector", errorHandlingFetcher, @@ -35,6 +42,14 @@ export default function UpgradingPage({ { refreshInterval: 5000 } // 5 seconds ); + const { data: failedIndexingStatus } = useSWR< + FailedConnectorIndexingStatus[] + >( + "/api/manage/admin/connector/failed-indexing-status?secondary_index=true", + errorHandlingFetcher, + { refreshInterval: 5000 } // 5 seconds + ); + const onCancel = async () => { const response = await fetch("/api/search-settings/cancel-new-embedding", { method: "POST", @@ -48,9 +63,37 @@ export default function UpgradingPage({ } setIsCancelling(false); }; + const statusOrder: Record = { + failed: 0, + completed_with_errors: 1, + not_started: 2, + in_progress: 3, + success: 4, + }; + + const sortedReindexingProgress = useMemo(() => { + return [...(ongoingReIndexingStatus || [])].sort((a, b) => { + const statusComparison = + statusOrder[a.latest_index_attempt?.status || "not_started"] - + statusOrder[b.latest_index_attempt?.status || "not_started"]; + + if (statusComparison !== 0) { + return statusComparison; + } + + return ( + (a.latest_index_attempt?.id || 0) - (b.latest_index_attempt?.id || 0) + ); + }); + }, [ongoingReIndexingStatus]); + + if (!failedIndexingStatus) { + return
No failed index attempts
; + } return ( <> + {popup} {isCancelling && ( setIsCancelling(false)} @@ -90,6 +133,12 @@ export default function UpgradingPage({ > Cancel + {failedIndexingStatus.length > 0 && ( + + )} The table below shows the re-indexing progress of all existing @@ -101,9 +150,9 @@ export default function UpgradingPage({ {isLoadingOngoingReIndexingStatus ? ( - ) : ongoingReIndexingStatus ? ( + ) : sortedReindexingProgress ? ( ) : ( diff --git a/web/src/app/admin/connector/[ccPairId]/DeletionErrorStatus.tsx b/web/src/app/admin/connector/[ccPairId]/DeletionErrorStatus.tsx new file mode 100644 index 00000000000..dbeb28cf631 --- /dev/null +++ b/web/src/app/admin/connector/[ccPairId]/DeletionErrorStatus.tsx @@ -0,0 +1,25 @@ +import { FiInfo } from "react-icons/fi"; + +export default function DeletionErrorStatus({ + deletion_failure_message, +}: { + deletion_failure_message: string; +}) { + return ( +
+
+

Deletion Error

+
+ +
+ This error occurred while attempting to delete the connector. You + may re-attempt a deletion by clicking the "Delete" button. +
+
+
+
+

{deletion_failure_message}

+
+
+ ); +} diff --git a/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx b/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx index b9861a29759..e8d8822bbcc 100644 --- a/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx +++ b/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx @@ -1,5 +1,6 @@ "use client"; +import { useEffect, useRef } from "react"; import { Table, TableHead, @@ -8,31 +9,173 @@ import { TableBody, TableCell, Text, - Button, - Divider, } from "@tremor/react"; -import { IndexAttemptStatus } from "@/components/Status"; import { CCPairFullInfo } from "./types"; +import { IndexAttemptStatus } from "@/components/Status"; import { useState } from "react"; import { PageSelector } from "@/components/PageSelector"; +import { ThreeDotsLoader } from "@/components/Loading"; +import { buildCCPairInfoUrl } from "./lib"; import { localizeAndPrettify } from "@/lib/time"; import { getDocsProcessedPerMinute } from "@/lib/indexAttempt"; -import { Modal } from "@/components/Modal"; -import { CheckmarkIcon, CopyIcon, SearchIcon } from "@/components/icons/icons"; +import { ErrorCallout } from "@/components/ErrorCallout"; +import { InfoIcon, SearchIcon } from "@/components/icons/icons"; import Link from "next/link"; import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; +import { PaginatedIndexAttempts } from "./types"; +import { useRouter } from "next/navigation"; +import { Tooltip } from "@/components/tooltip/Tooltip"; +// This is the number of index attempts to display per page const NUM_IN_PAGE = 8; +// This is the number of pages to fetch at a time +const BATCH_SIZE = 8; export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) { - const [page, setPage] = useState(1); const [indexAttemptTracePopupId, setIndexAttemptTracePopupId] = useState< number | null >(null); - const indexAttemptToDisplayTraceFor = ccPair.index_attempts.find( + + const totalPages = Math.ceil(ccPair.number_of_index_attempts / NUM_IN_PAGE); + + const router = useRouter(); + const [page, setPage] = useState(() => { + if (typeof window !== "undefined") { + const urlParams = new URLSearchParams(window.location.search); + return parseInt(urlParams.get("page") || "1", 10); + } + return 1; + }); + + const [currentPageData, setCurrentPageData] = + useState(null); + const [currentPageError, setCurrentPageError] = useState(null); + const [isCurrentPageLoading, setIsCurrentPageLoading] = useState(false); + + // This is a cache of the data for each "batch" which is a set of pages + const [cachedBatches, setCachedBatches] = useState<{ + [key: number]: PaginatedIndexAttempts[]; + }>({}); + + // This is a set of the batches that are currently being fetched + // we use it to avoid duplicate requests + const ongoingRequestsRef = useRef>(new Set()); + + const batchRetrievalUrlBuilder = (batchNum: number) => + `${buildCCPairInfoUrl(ccPair.id)}/index-attempts?page=${batchNum}&page_size=${BATCH_SIZE * NUM_IN_PAGE}`; + + // This fetches and caches the data for a given batch number + const fetchBatchData = async (batchNum: number) => { + if (ongoingRequestsRef.current.has(batchNum)) return; + ongoingRequestsRef.current.add(batchNum); + + try { + const response = await fetch(batchRetrievalUrlBuilder(batchNum + 1)); + if (!response.ok) { + throw new Error("Failed to fetch data"); + } + const data = await response.json(); + + const newBatchData: PaginatedIndexAttempts[] = []; + for (let i = 0; i < BATCH_SIZE; i++) { + const startIndex = i * NUM_IN_PAGE; 
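+        // slice out the NUM_IN_PAGE attempts that belong to page i of this batch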
+ const endIndex = startIndex + NUM_IN_PAGE; + const pageIndexAttempts = data.index_attempts.slice( + startIndex, + endIndex + ); + newBatchData.push({ + ...data, + index_attempts: pageIndexAttempts, + }); + } + + setCachedBatches((prev) => ({ + ...prev, + [batchNum]: newBatchData, + })); + } catch (error) { + setCurrentPageError( + error instanceof Error ? error : new Error("An error occurred") + ); + } finally { + ongoingRequestsRef.current.delete(batchNum); + } + }; + + // This fetches and caches the data for the current batch and the next and previous batches + useEffect(() => { + const batchNum = Math.floor((page - 1) / BATCH_SIZE); + + if (!cachedBatches[batchNum]) { + setIsCurrentPageLoading(true); + fetchBatchData(batchNum); + } else { + setIsCurrentPageLoading(false); + } + + const nextBatchNum = Math.min( + batchNum + 1, + Math.ceil(totalPages / BATCH_SIZE) - 1 + ); + if (!cachedBatches[nextBatchNum]) { + fetchBatchData(nextBatchNum); + } + + const prevBatchNum = Math.max(batchNum - 1, 0); + if (!cachedBatches[prevBatchNum]) { + fetchBatchData(prevBatchNum); + } + + // Always fetch the first batch if it's not cached + if (!cachedBatches[0]) { + fetchBatchData(0); + } + }, [ccPair.id, page, cachedBatches, totalPages]); + + // This updates the data on the current page + useEffect(() => { + const batchNum = Math.floor((page - 1) / BATCH_SIZE); + const batchPageNum = (page - 1) % BATCH_SIZE; + + if (cachedBatches[batchNum] && cachedBatches[batchNum][batchPageNum]) { + setCurrentPageData(cachedBatches[batchNum][batchPageNum]); + setIsCurrentPageLoading(false); + } else { + setIsCurrentPageLoading(true); + } + }, [page, cachedBatches]); + + // This updates the page number and manages the URL + const updatePage = (newPage: number) => { + setPage(newPage); + router.push(`/admin/connector/${ccPair.id}?page=${newPage}`, { + scroll: false, + }); + window.scrollTo({ + top: 0, + left: 0, + behavior: "smooth", + }); + }; + + if (isCurrentPageLoading || !currentPageData) { + return ; + } + + if (currentPageError) { + return ( + + ); + } + + // This is the index attempt that the user wants to view the trace for + const indexAttemptToDisplayTraceFor = currentPageData?.index_attempts?.find( (indexAttempt) => indexAttempt.id === indexAttemptTracePopupId ); - const [copyClicked, setCopyClicked] = useState(false); return ( <> @@ -50,106 +193,109 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) { Time Started Status New Doc Cnt - Total Doc Cnt + +
+ + + Total Doc Cnt + + + +
+
Error Message - {ccPair.index_attempts - .slice(NUM_IN_PAGE * (page - 1), NUM_IN_PAGE * page) - .map((indexAttempt) => { - const docsPerMinute = - getDocsProcessedPerMinute(indexAttempt)?.toFixed(2); - return ( - - - {indexAttempt.time_started - ? localizeAndPrettify(indexAttempt.time_started) - : "-"} - - - - {docsPerMinute && ( -
- {docsPerMinute} docs / min -
- )} -
- -
-
-
{indexAttempt.new_docs_indexed}
- {indexAttempt.docs_removed_from_index > 0 && ( -
- (also removed {indexAttempt.docs_removed_from_index}{" "} - docs that were detected as deleted in the source) -
- )} -
+ {currentPageData.index_attempts.map((indexAttempt) => { + const docsPerMinute = + getDocsProcessedPerMinute(indexAttempt)?.toFixed(2); + return ( + + + {indexAttempt.time_started + ? localizeAndPrettify(indexAttempt.time_started) + : "-"} + + + + {docsPerMinute && ( +
+ {docsPerMinute} docs / min
-
- {indexAttempt.total_docs_indexed} - -
- {indexAttempt.error_count > 0 && ( - - - -  View Errors - - + )} + + +
+
+
{indexAttempt.new_docs_indexed}
+ {indexAttempt.docs_removed_from_index > 0 && ( +
+ (also removed {indexAttempt.docs_removed_from_index}{" "} + docs that were detected as deleted in the source) +
)} +
+
+
+ {indexAttempt.total_docs_indexed} + +
+ {indexAttempt.error_count > 0 && ( + + + +  View Errors + + + )} - {indexAttempt.status === "success" && ( + {indexAttempt.status === "success" && ( + + {"-"} + + )} + + {indexAttempt.status === "failed" && + indexAttempt.error_msg && ( - {"-"} + {indexAttempt.error_msg} )} - {indexAttempt.status === "failed" && - indexAttempt.error_msg && ( - - {indexAttempt.error_msg} - - )} - - {indexAttempt.full_exception_trace && ( -
{ - setIndexAttemptTracePopupId(indexAttempt.id); - }} - className="mt-2 text-link cursor-pointer select-none" - > - View Full Trace -
- )} -
-
- - ); - })} + {indexAttempt.full_exception_trace && ( +
{ + setIndexAttemptTracePopupId(indexAttempt.id); + }} + className="mt-2 text-link cursor-pointer select-none" + > + View Full Trace +
+ )} +
+
+
+ ); + })} - {ccPair.index_attempts.length > NUM_IN_PAGE && ( + {totalPages > 1 && (
{ - setPage(newPage); - window.scrollTo({ - top: 0, - left: 0, - behavior: "smooth", - }); - }} + onPageChange={updatePage} />
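The rewritten IndexingAttemptsTable above pulls index attempts in batches: each request fetches BATCH_SIZE pages' worth of attempts (BATCH_SIZE * NUM_IN_PAGE rows, with the backend's page query parameter being the 1-indexed batch number), caches the batch, and slices the current UI page out of it. The arithmetic can be spot-checked in isolation; locatePage is only an illustrative helper name, not part of the patch.

    // Constants mirror NUM_IN_PAGE and BATCH_SIZE in IndexingAttemptsTable.tsx above.
    const NUM_IN_PAGE = 8;
    const BATCH_SIZE = 8;

    function locatePage(page: number) {
      const batchNum = Math.floor((page - 1) / BATCH_SIZE); // which batch request holds this page
      const batchPageNum = (page - 1) % BATCH_SIZE; // the page's position inside that batch
      const startIndex = batchPageNum * NUM_IN_PAGE; // slice bounds within the batch payload
      const endIndex = startIndex + NUM_IN_PAGE;
      const apiPageParam = batchNum + 1; // the endpoint's page parameter is 1-indexed
      return { batchNum, batchPageNum, startIndex, endIndex, apiPageParam };
    }

    // Page 10 lives in the second batch (batchNum 1) as that batch's second page,
    // i.e. rows 8..15 of the 64-row payload:
    console.log(locatePage(10));
    // { batchNum: 1, batchPageNum: 1, startIndex: 8, endIndex: 16, apiPageParam: 2 }

Prefetching the neighboring batches (and always batch 0) keeps page changes responsive while bounding the number of in-flight requests, and the ongoingRequestsRef guard prevents duplicate fetches for the same batch.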
diff --git a/web/src/app/admin/connector/[ccPairId]/page.tsx b/web/src/app/admin/connector/[ccPairId]/page.tsx index f5da225a867..b18ab24f103 100644 --- a/web/src/app/admin/connector/[ccPairId]/page.tsx +++ b/web/src/app/admin/connector/[ccPairId]/page.tsx @@ -1,7 +1,6 @@ "use client"; import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types"; -import { HealthCheckBanner } from "@/components/health/healthcheck"; import { CCPairStatus } from "@/components/Status"; import { BackButton } from "@/components/BackButton"; import { Button, Divider, Title } from "@tremor/react"; @@ -11,7 +10,6 @@ import { ModifyStatusButtonCluster } from "./ModifyStatusButtonCluster"; import { DeletionButton } from "./DeletionButton"; import { ErrorCallout } from "@/components/ErrorCallout"; import { ReIndexButton } from "./ReIndexButton"; -import { isCurrentlyDeleting } from "@/lib/documentDeletion"; import { ValidSources } from "@/lib/types"; import useSWR, { mutate } from "swr"; import { errorHandlingFetcher } from "@/lib/fetcher"; @@ -24,6 +22,7 @@ import { useEffect, useRef, useState } from "react"; import { CheckmarkIcon, EditIcon, XIcon } from "@/components/icons/icons"; import { usePopup } from "@/components/admin/connectors/Popup"; import { updateConnectorCredentialPairName } from "@/lib/connector"; +import DeletionErrorStatus from "./DeletionErrorStatus"; // since the uploaded files are cleaned up after some period of time // re-indexing will not work for the file connector. Also, it would not @@ -86,24 +85,13 @@ function Main({ ccPairId }: { ccPairId: number }) { return ( ); } - const lastIndexAttempt = ccPair.index_attempts[0]; const isDeleting = ccPair.status === ConnectorCredentialPairStatus.DELETING; - // figure out if we need to artificially deflate the number of docs indexed. - // This is required since the total number of docs indexed by a CC Pair is - // updated before the new docs for an indexing attempt. If we don't do this, - // there is a mismatch between these two numbers which may confuse users. - const totalDocsIndexed = - lastIndexAttempt?.status === "in_progress" && - ccPair.index_attempts.length === 1 - ? lastIndexAttempt.total_docs_indexed - : ccPair.num_docs_indexed; - const refresh = () => { mutate(buildCCPairInfoUrl(ccPairId)); }; @@ -182,13 +170,13 @@ function Main({ ccPairId }: { ccPairId: number }) { )}
Total Documents Indexed:{" "} - {totalDocsIndexed} + {ccPair.num_docs_indexed}
{!ccPair.is_editable_for_current_user && (
@@ -197,6 +185,17 @@ function Main({ ccPairId }: { ccPairId: number }) { : "This connector belongs to groups where you don't have curator permissions, so it's not editable."}
)} + + {ccPair.deletion_failure_message && + ccPair.status === ConnectorCredentialPairStatus.DELETING && ( + <> +
+ + + )} + {credentialTemplates[ccPair.connector.source] && ccPair.is_editable_for_current_user && ( <> diff --git a/web/src/app/admin/connector/[ccPairId]/types.ts b/web/src/app/admin/connector/[ccPairId]/types.ts index 1cc43311e21..55bbe955730 100644 --- a/web/src/app/admin/connector/[ccPairId]/types.ts +++ b/web/src/app/admin/connector/[ccPairId]/types.ts @@ -1,6 +1,10 @@ import { Connector } from "@/lib/connectors/connectors"; import { Credential } from "@/lib/connectors/credentials"; -import { DeletionAttemptSnapshot, IndexAttemptSnapshot } from "@/lib/types"; +import { + DeletionAttemptSnapshot, + IndexAttemptSnapshot, + ValidStatuses, +} from "@/lib/types"; export enum ConnectorCredentialPairStatus { ACTIVE = "ACTIVE", @@ -15,8 +19,16 @@ export interface CCPairFullInfo { num_docs_indexed: number; connector: Connector; credential: Credential; - index_attempts: IndexAttemptSnapshot[]; + number_of_index_attempts: number; + last_index_attempt_status: ValidStatuses | null; latest_deletion_attempt: DeletionAttemptSnapshot | null; is_public: boolean; is_editable_for_current_user: boolean; + deletion_failure_message: string | null; +} + +export interface PaginatedIndexAttempts { + index_attempts: IndexAttemptSnapshot[]; + page: number; + total_pages: number; } diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index dd8d19ca720..af30479aaa0 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -1,19 +1,17 @@ "use client"; -import * as Yup from "yup"; -import { TrashIcon } from "@/components/icons/icons"; import { errorHandlingFetcher } from "@/lib/fetcher"; import useSWR, { mutate } from "swr"; import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { Card, Divider, Title } from "@tremor/react"; +import { Card, Title } from "@tremor/react"; import { AdminPageTitle } from "@/components/admin/Title"; import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib"; import { usePopup } from "@/components/admin/connectors/Popup"; import { useFormContext } from "@/components/context/FormContext"; import { getSourceDisplayName } from "@/lib/sources"; import { SourceIcon } from "@/components/SourceIcon"; -import { useRef, useState, useEffect } from "react"; +import { useState } from "react"; import { submitConnector } from "@/components/admin/connectors/ConnectorForm"; import { deleteCredential, linkCredential } from "@/lib/credential"; import { submitFiles } from "./pages/utils/files"; @@ -22,44 +20,43 @@ import AdvancedFormPage from "./pages/Advanced"; import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm"; import CreateCredential from "@/components/credentials/actions/CreateCredential"; import ModifyCredential from "@/components/credentials/actions/ModifyCredential"; -import { ValidSources } from "@/lib/types"; +import { ConfigurableSources, ValidSources } from "@/lib/types"; import { Credential, credentialTemplates } from "@/lib/connectors/credentials"; import { ConnectionConfiguration, connectorConfigs, + createConnectorInitialValues, + createConnectorValidationSchema, } from "@/lib/connectors/connectors"; import { Modal } from "@/components/Modal"; -import { ArrowRight } from "@phosphor-icons/react"; -import { ArrowLeft } from "@phosphor-icons/react/dist/ssr"; -import { FiPlus } from "react-icons/fi"; import GDriveMain from "./pages/gdrive/GoogleDrivePage"; 
import { GmailMain } from "./pages/gmail/GmailPage"; import { useGmailCredentials, useGoogleDriveCredentials, } from "./pages/utils/hooks"; -import { Formik, FormikProps } from "formik"; -import { - IsPublicGroupSelector, - IsPublicGroupSelectorFormType, -} from "@/components/IsPublicGroupSelector"; -import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; -import { AdminBooleanFormField } from "@/components/credentials/CredentialFields"; - -export type AdvancedConfigFinal = { - pruneFreq: number | null; - refreshFreq: number | null; - indexingStart: Date | null; -}; +import { Formik } from "formik"; +import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; +import NavigationRow from "./NavigationRow"; + +export interface AdvancedConfig { + refreshFreq: number; + pruneFreq: number; + indexingStart: string; +} export default function AddConnector({ connector, }: { - connector: ValidSources; + connector: ConfigurableSources; }) { + // State for managing credentials and files const [currentCredential, setCurrentCredential] = useState | null>(null); + const [selectedFiles, setSelectedFiles] = useState([]); + const [createConnectorToggle, setCreateConnectorToggle] = useState(false); + // Fetch credentials data const { data: credentials } = useSWR[]>( buildSimilarCredentialInfoURL(connector), errorHandlingFetcher, @@ -71,74 +68,27 @@ export default function AddConnector({ errorHandlingFetcher, { refreshInterval: 5000 } ); - const [selectedFiles, setSelectedFiles] = useState([]); + // Get credential template and configuration const credentialTemplate = credentialTemplates[connector]; - - const { - setFormStep, - setAllowAdvanced, - setAlowCreate, - formStep, - nextFormStep, - prevFormStep, - } = useFormContext(); - - const { popup, setPopup } = usePopup(); - const configuration: ConnectionConfiguration = connectorConfigs[connector]; - const [formValues, setFormValues] = useState< - Record & IsPublicGroupSelectorFormType - >({ - name: "", - groups: [], - is_public: false, - ...configuration.values.reduce( - (acc, field) => { - if (field.type === "list") { - acc[field.name] = field.default || []; - } else if (field.type === "checkbox") { - acc[field.name] = field.default || false; - } else if (field.default !== undefined) { - acc[field.name] = field.default; - } - return acc; - }, - {} as { [record: string]: any } - ), - }); - - const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled(); - - // Default to 10 minutes unless otherwise specified - const defaultAdvancedSettings = { - refreshFreq: formValues.overrideDefaultFreq || 10, - pruneFreq: 30, - indexingStart: null as string | null, - }; - const [advancedSettings, setAdvancedSettings] = useState( - defaultAdvancedSettings - ); - - const [createConnectorToggle, setCreateConnectorToggle] = useState(false); - const formRef = useRef>(null); - - const [isFormValid, setIsFormValid] = useState(false); - - const handleFormStatusChange = (isValid: boolean) => { - setIsFormValid(isValid || connector == "file"); - }; + // Form context and popup management + const { setFormStep, setAlowCreate, formStep, nextFormStep, prevFormStep } = + useFormContext(); + const { popup, setPopup } = usePopup(); + // Hooks for Google Drive and Gmail credentials const { liveGDriveCredential } = useGoogleDriveCredentials(); - const { liveGmailCredential } = useGmailCredentials(); + // Check if credential is activated const credentialActivated = (connector === "google_drive" && 
liveGDriveCredential) || (connector === "gmail" && liveGmailCredential) || currentCredential; + // Check if there are no credentials const noCredentials = credentialTemplate == null; if (noCredentials && 1 != formStep) { @@ -149,164 +99,20 @@ export default function AddConnector({ setFormStep(Math.min(formStep, 0)); } - const resetAdvancedConfigs = (formikProps: FormikProps) => { - formikProps.resetForm({ values: defaultAdvancedSettings }); - setAdvancedSettings(defaultAdvancedSettings); - }; - const convertStringToDateTime = (indexingStart: string | null) => { return indexingStart ? new Date(indexingStart) : null; }; - const createConnector = async () => { - const { - name, - groups, - is_public: isPublic, - ...connector_specific_config - } = formValues; - const { pruneFreq, indexingStart, refreshFreq } = advancedSettings; - - // Apply transforms from connectors.ts configuration - const transformedConnectorSpecificConfig = Object.entries( - connector_specific_config - ).reduce( - (acc, [key, value]) => { - const matchingConfigValue = configuration.values.find( - (configValue) => configValue.name === key - ); - if ( - matchingConfigValue && - "transform" in matchingConfigValue && - matchingConfigValue.transform - ) { - acc[key] = matchingConfigValue.transform(value as string[]); - } else { - acc[key] = value; - } - return acc; - }, - {} as Record - ); - - const AdvancedConfig: AdvancedConfigFinal = { - pruneFreq: advancedSettings.pruneFreq * 60 * 60 * 24, - indexingStart: convertStringToDateTime(indexingStart), - refreshFreq: advancedSettings.refreshFreq * 60, - }; - - // google sites-specific handling - if (connector == "google_site") { - const response = await submitGoogleSite( - selectedFiles, - formValues?.base_url, - setPopup, - AdvancedConfig, - name - ); - if (response) { - setTimeout(() => { - window.open("/admin/indexing/status", "_self"); - }, 1000); - } - return; - } - - // file-specific handling - if (connector == "file" && selectedFiles.length > 0) { - const response = await submitFiles( - selectedFiles, - setPopup, - setSelectedFiles, - name, - AdvancedConfig, - isPublic, - groups - ); - if (response) { - setTimeout(() => { - window.open("/admin/indexing/status", "_self"); - }, 1000); - } - return; - } - - const { message, isSuccess, response } = await submitConnector( - { - connector_specific_config: transformedConnectorSpecificConfig, - input_type: connector == "web" ? "load_state" : "poll", // single case - name: name, - source: connector, - refresh_freq: refreshFreq * 60 || null, - prune_freq: pruneFreq * 60 * 60 * 24 || null, - indexing_start: convertStringToDateTime(indexingStart), - is_public: isPublic, - groups: groups, - }, - undefined, - credentialActivated ? false : true, - isPublic - ); - // If no credential - if (!credentialActivated) { - if (isSuccess) { - setPopup({ - message: "Connector created! Redirecting to connector home page", - type: "success", - }); - setTimeout(() => { - window.open("/admin/indexing/status", "_self"); - }, 1000); - } else { - setPopup({ message: message, type: "error" }); - } - } - - // Without credential - if (credentialActivated && isSuccess && response) { - const credential = - currentCredential || liveGDriveCredential || liveGmailCredential; - const linkCredentialResponse = await linkCredential( - response.id, - credential?.id!, - name, - isPublic, - groups - ); - if (linkCredentialResponse.ok) { - setPopup({ - message: "Connector created! 
Redirecting to connector home page", - type: "success", - }); - setTimeout(() => { - window.open("/admin/indexing/status", "_self"); - }, 1000); - } else { - const errorData = await linkCredentialResponse.json(); - setPopup({ - message: errorData.message, - type: "error", - }); - } - } else if (isSuccess) { - setPopup({ - message: - "Credential created succsfully! Redirecting to connector home page", - type: "success", - }); - } else { - setPopup({ message: message, type: "error" }); - } - }; - const displayName = getSourceDisplayName(connector) || connector; if (!credentials || !editableCredentials) { return <>; } + // Credential handler functions const refresh = () => { mutate(buildSimilarCredentialInfoURL(connector)); }; + const onDeleteCredential = async (credential: Credential) => { const response = await deleteCredential(credential.id, true); if (response.ok) { @@ -333,285 +139,256 @@ export default function AddConnector({ refresh(); }; - const validationSchema = Yup.object().shape({ - name: Yup.string().required("Connector Name is required"), - ...configuration.values.reduce( - (acc, field) => { - let schema: any = - field.type === "list" - ? Yup.array().of(Yup.string()) - : field.type === "checkbox" - ? Yup.boolean() - : Yup.string(); - - if (!field.optional) { - schema = schema.required(`${field.label} is required`); - } - acc[field.name] = schema; - return acc; - }, - {} as Record - ), - }); - - const advancedValidationSchema = Yup.object().shape({ - indexingStart: Yup.string().nullable(), - pruneFreq: Yup.number().min(0, "Prune frequency must be non-negative"), - refreshFreq: Yup.number().min(0, "Refresh frequency must be non-negative"), - }); - - const isFormSubmittable = (values: any) => { - return ( - values.name.trim() !== "" && - Object.keys(values).every((key) => { - const field = configuration.values.find((f) => f.name === key); - return field?.optional || values[key] !== ""; - }) - ); + const onSuccess = () => { + setPopup({ + message: "Connector created! Redirecting to connector home page", + type: "success", + }); + setTimeout(() => { + window.open("/admin/indexing/status", "_self"); + }, 1000); }; return ( -
- {popup} -
- -
- - } - title={displayName} - /> - - {formStep == 0 && - (connector == "google_drive" ? ( - <> - - Select a credential - - -
- -
- - ) : connector == "gmail" ? ( - <> - - Select a credential - - -
- + { + console.log(" Iam submiing the connector"); + const { + name, + groups, + is_public: isPublic, + pruneFreq, + indexingStart, + refreshFreq, + ...connector_specific_config + } = values; + + // Apply transforms from connectors.ts configuration + const transformedConnectorSpecificConfig = Object.entries( + connector_specific_config + ).reduce( + (acc, [key, value]) => { + const matchingConfigValue = configuration.values.find( + (configValue) => configValue.name === key + ); + if ( + matchingConfigValue && + "transform" in matchingConfigValue && + matchingConfigValue.transform + ) { + acc[key] = matchingConfigValue.transform(value as string[]); + } else { + acc[key] = value; + } + return acc; + }, + {} as Record + ); + + // Apply advanced configuration-specific transforms. + const advancedConfiguration: any = { + pruneFreq: pruneFreq * 60 * 60 * 24, + indexingStart: convertStringToDateTime(indexingStart), + refreshFreq: refreshFreq * 60, + }; + + // Google sites-specific handling + if (connector == "google_sites") { + const response = await submitGoogleSite( + selectedFiles, + values?.base_url, + setPopup, + advancedConfiguration.refreshFreq, + advancedConfiguration.pruneFreq, + advancedConfiguration.indexingStart, + name + ); + if (response) { + onSuccess(); + } + return; + } + + // File-specific handling + if (connector == "file" && selectedFiles.length > 0) { + const response = await submitFiles( + selectedFiles, + setPopup, + setSelectedFiles, + name, + isPublic, + groups + ); + if (response) { + onSuccess(); + } + return; + } + + const { message, isSuccess, response } = await submitConnector( + { + connector_specific_config: transformedConnectorSpecificConfig, + input_type: connector == "web" ? "load_state" : "poll", // single case + name: name, + source: connector, + refresh_freq: advancedConfiguration.refreshFreq || null, + prune_freq: advancedConfiguration.pruneFreq || null, + indexing_start: advancedConfiguration.indexingStart || null, + is_public: isPublic, + groups: groups, + }, + undefined, + credentialActivated ? false : true, + isPublic + ); + // If no credential + if (!credentialActivated) { + if (isSuccess) { + onSuccess(); + } else { + setPopup({ message: message, type: "error" }); + } + } + + // Without credential + if (credentialActivated && isSuccess && response) { + const credential = + currentCredential || liveGDriveCredential || liveGmailCredential; + const linkCredentialResponse = await linkCredential( + response.id, + credential?.id!, + name, + isPublic, + groups + ); + if (linkCredentialResponse.ok) { + onSuccess(); + } else { + const errorData = await linkCredentialResponse.json(); + setPopup({ + message: errorData.message, + type: "error", + }); + } + } else if (isSuccess) { + onSuccess(); + } else { + setPopup({ message: message, type: "error" }); + } + return; + }} + > + {(formikProps) => { + return ( +
+ {popup} + +
+
- - ) : ( - <> - - Select a credential - - {!createConnectorToggle && ( - - )} - - {!(connector == "google_drive") && createConnectorToggle && ( - setCreateConnectorToggle(false)} - > + + } + title={displayName} + /> + + {formStep == 0 && ( + + Select a credential + + {connector == "google_drive" ? ( + + ) : connector == "gmail" ? ( + + ) : ( <> - - Create a {getSourceDisplayName(connector)} credential - - setCreateConnectorToggle(false)} /> - - - )} - -
- -
- - ))} - - {formStep == 1 && ( - <> - - { - // Can be utilized for logging purposes - }} - > - {(formikProps) => { - setFormValues(formikProps.values); - handleFormStatusChange( - formikProps.isValid && isFormSubmittable(formikProps.values) - ); - setAllowAdvanced( - formikProps.isValid && isFormSubmittable(formikProps.values) - ); - - return ( -
- - {isPaidEnterpriseFeaturesEnabled && ( - <> - - - )} -
- ); - }} -
-
-
- {!noCredentials ? ( - - ) : ( -
- )} - - - {!(connector == "file") && ( -
- -
- )} -
- - )} - - {formStep === 2 && ( - <> - - {}} - > - {(formikProps) => { - setAdvancedSettings(formikProps.values); - - return ( - <> - -
+ {!createConnectorToggle && ( -
+ )} + + {/* NOTE: connector will never be google_drive, since the ternary above will + prevent that, but still keeping this here for safety in case the above changes. */} + {(connector as ValidSources) !== "google_drive" && + createConnectorToggle && ( + setCreateConnectorToggle(false)} + > + <> + + Create a {getSourceDisplayName(connector)}{" "} + credential + + setCreateConnectorToggle(false)} + /> + + + )} - ); - }} -
-
-
- - + )} + + )} + + {formStep == 1 && ( + + + + + + )} + + {formStep === 2 && ( + + + + )} + +
- - )} -
+ ); + }} + ); } diff --git a/web/src/app/admin/connectors/[connector]/ConnectorWrapper.tsx b/web/src/app/admin/connectors/[connector]/ConnectorWrapper.tsx index 345ace085bc..c038cdbb4b2 100644 --- a/web/src/app/admin/connectors/[connector]/ConnectorWrapper.tsx +++ b/web/src/app/admin/connectors/[connector]/ConnectorWrapper.tsx @@ -1,6 +1,6 @@ "use client"; -import { ValidSources } from "@/lib/types"; +import { ConfigurableSources, ValidSources } from "@/lib/types"; import AddConnector from "./AddConnectorPage"; import { FormProvider } from "@/components/context/FormContext"; import Sidebar from "./Sidebar"; @@ -8,7 +8,11 @@ import { HeaderTitle } from "@/components/header/HeaderTitle"; import { Button } from "@tremor/react"; import { isValidSource } from "@/lib/sources"; -export default function ConnectorWrapper({ connector }: { connector: string }) { +export default function ConnectorWrapper({ + connector, +}: { + connector: ConfigurableSources; +}) { return (
@@ -28,7 +32,7 @@ export default function ConnectorWrapper({ connector }: { connector: string }) {
) : ( - + )}
diff --git a/web/src/app/admin/connectors/[connector]/NavigationRow.tsx b/web/src/app/admin/connectors/[connector]/NavigationRow.tsx new file mode 100644 index 00000000000..933e4c9d06f --- /dev/null +++ b/web/src/app/admin/connectors/[connector]/NavigationRow.tsx @@ -0,0 +1,91 @@ +import { useFormContext } from "@/components/context/FormContext"; +import { ArrowLeft, ArrowRight } from "@phosphor-icons/react"; +import { FiPlus } from "react-icons/fi"; + +const NavigationRow = ({ + noAdvanced, + noCredentials, + activatedCredential, + onSubmit, + isValid, +}: { + isValid: boolean; + onSubmit: () => void; + noAdvanced: boolean; + noCredentials: boolean; + activatedCredential: boolean; +}) => { + const { formStep, prevFormStep, nextFormStep } = useFormContext(); + const SquareNavigationButton = ({ + onClick, + disabled, + className, + children, + }: { + onClick: () => void; + disabled?: boolean; + className: string; + children: React.ReactNode; + }) => ( + + ); + + return ( +
+
+ {formStep > 0 && !noCredentials && ( + + + Previous + + )} +
+ +
+ {(formStep > 0 || noCredentials) && ( + + Create Connector + + + )} +
+ +
+ {formStep === 0 && ( + + Continue + + + )} + {noAdvanced && formStep === 1 && ( + + Advanced + + + )} +
+
+ ); +}; +export default NavigationRow; diff --git a/web/src/app/admin/connectors/[connector]/Sidebar.tsx b/web/src/app/admin/connectors/[connector]/Sidebar.tsx index 97275843e0c..288916f99d0 100644 --- a/web/src/app/admin/connectors/[connector]/Sidebar.tsx +++ b/web/src/app/admin/connectors/[connector]/Sidebar.tsx @@ -25,9 +25,10 @@ export default function Sidebar() { ]; return ( -
+
; + return ( + + ); } diff --git a/web/src/app/admin/connectors/[connector]/pages/Advanced.tsx b/web/src/app/admin/connectors/[connector]/pages/Advanced.tsx index 470ab8d2a77..0f50a7043b8 100644 --- a/web/src/app/admin/connectors/[connector]/pages/Advanced.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/Advanced.tsx @@ -1,68 +1,47 @@ -import React, { Dispatch, forwardRef, SetStateAction } from "react"; -import { Formik, Form, FormikProps } from "formik"; -import * as Yup from "yup"; +import React from "react"; import NumberInput from "./ConnectorInput/NumberInput"; import { TextFormField } from "@/components/admin/connectors/Field"; +import { TrashIcon } from "@/components/icons/icons"; -interface AdvancedFormPageProps { - formikProps: FormikProps<{ - indexingStart: string | null; - pruneFreq: number; - refreshFreq: number; - }>; -} +const AdvancedFormPage = () => { + return ( +
+

+ Advanced Configuration +

-const AdvancedFormPage = forwardRef, AdvancedFormPageProps>( - ({ formikProps }, ref) => { - const { indexingStart, refreshFreq, pruneFreq } = formikProps.values; + - return ( -
-

- Advanced Configuration -

+ - -
- -
- -
- -
- -
- -
- + +
+
- ); - } -); +
+ ); +}; -AdvancedFormPage.displayName = "AdvancedFormPage"; export default AdvancedFormPage; diff --git a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx index 5a9f5041b5d..b7fcb49cf1e 100644 --- a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx @@ -1,15 +1,13 @@ import { SubLabel } from "@/components/admin/connectors/Field"; -import { Field } from "formik"; +import { Field, useFormikContext } from "formik"; export default function NumberInput({ label, - value, optional, description, name, showNeverIfZero, }: { - value?: number; label: string; name: string; optional?: boolean; @@ -28,7 +26,6 @@ export default function NumberInput({ type="number" name={name} min="-1" - value={value === 0 && showNeverIfZero ? "Never" : value} className={`mt-2 block w-full px-3 py-2 bg-white border border-gray-300 rounded-md text-sm shadow-sm placeholder-gray-400 diff --git a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/SelectInput.tsx b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/SelectInput.tsx index e01a02dc323..a7c14deec20 100644 --- a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/SelectInput.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/SelectInput.tsx @@ -1,40 +1,42 @@ import CredentialSubText from "@/components/credentials/CredentialFields"; -import { ListOption, SelectOption } from "@/lib/connectors/connectors"; +import { + ListOption, + SelectOption, + StringWithDescription, +} from "@/lib/connectors/connectors"; import { Field } from "formik"; export default function SelectInput({ - field, - value, - onChange, + name, + optional, + description, + options, + label, }: { - field: SelectOption; - value: any; - onChange?: (e: Event) => void; + name: string; + optional?: boolean; + description?: string; + options: StringWithDescription[]; + label?: string; }) { return ( <> - {field.description && ( - {field.description} - )} + {description && {description}} - {field.options?.map((option: any) => ( + {options?.map((option: any) => ( diff --git a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx index 507b976f9a8..cd7e90f6167 100644 --- a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx @@ -1,28 +1,9 @@ -import React, { - ChangeEvent, - Dispatch, - FC, - SetStateAction, - useEffect, - useState, -} from "react"; -import { Formik, Form, Field, FieldArray, FormikProps } from "formik"; -import * as Yup from "yup"; -import { FaPlus } from "react-icons/fa"; -import { useUserGroups } from "@/lib/hooks"; -import { UserGroup, User, UserRole } from "@/lib/types"; -import { Divider } from "@tremor/react"; +import React, { Dispatch, FC, SetStateAction } from "react"; import CredentialSubText, { AdminBooleanFormField, } from "@/components/credentials/CredentialFields"; -import { TrashIcon } from "@/components/icons/icons"; import { FileUpload } from "@/components/admin/connectors/FileUpload"; import { ConnectionConfiguration } from "@/lib/connectors/connectors"; -import { useFormContext } from "@/components/context/FormContext"; -import { usePaidEnterpriseFeaturesEnabled } from 
"@/components/settings/usePaidEnterpriseFeaturesEnabled"; -import { Text } from "@tremor/react"; -import { getCurrentUser } from "@/lib/user"; -import { FiUsers } from "react-icons/fi"; import SelectInput from "./ConnectorInput/SelectInput"; import NumberInput from "./ConnectorInput/NumberInput"; import { TextFormField } from "@/components/admin/connectors/Field"; @@ -63,6 +44,7 @@ const DynamicConnectionForm: FC = ({
{field.type == "file" ? ( @@ -78,15 +60,19 @@ const DynamicConnectionForm: FC = ({ ) : field.type === "list" ? ( ) : field.type === "select" ? ( - + ) : field.type === "number" ? ( ) : field.type === "checkbox" ? ( -

+

When using a Google Drive Service Account, you can either have Danswer act as the service account itself OR you can specify an account for the service account to impersonate. @@ -356,70 +358,59 @@ export const DriveOAuthSection = ({ the documents you want to index with the service account.

- - { - formikHelpers.setSubmitting(true); - - const response = await fetch( - "/api/manage/admin/connector/google-drive/service-account-credential", - { - method: "PUT", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - google_drive_delegated_user: - values.google_drive_delegated_user, - }), - } - ); - - if (response.ok) { - setPopup({ - message: "Successfully created service account credential", - type: "success", - }); - } else { - const errorMsg = await response.text(); - setPopup({ - message: `Failed to create service account credential - ${errorMsg}`, - type: "error", - }); + { + formikHelpers.setSubmitting(true); + const response = await fetch( + "/api/manage/admin/connector/google-drive/service-account-credential", + { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + google_drive_delegated_user: + values.google_drive_delegated_user, + }), } - refreshCredentials(); - }} - > - {({ isSubmitting }) => ( -
- -
- -
- - )} -
-
+ ); + + if (response.ok) { + setPopup({ + message: "Successfully created service account credential", + type: "success", + }); + } else { + const errorMsg = await response.text(); + setPopup({ + message: `Failed to create service account credential - ${errorMsg}`, + type: "error", + }); + } + refreshCredentials(); + }} + > + {({ isSubmitting }) => ( +
+ +
+ + Create Credential + +
+ + )} +
); } diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx index 4494e4b22ee..247b64e61b4 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx @@ -8,8 +8,6 @@ import { ErrorCallout } from "@/components/ErrorCallout"; import { LoadingAnimation } from "@/components/Loading"; import { usePopup } from "@/components/admin/connectors/Popup"; import { ConnectorIndexingStatus } from "@/lib/types"; -import { getCurrentUser } from "@/lib/user"; -import { User, UserRole } from "@/lib/types"; import { usePublicCredentials } from "@/lib/hooks"; import { Title } from "@tremor/react"; import { DriveJsonUploadSection, DriveOAuthSection } from "./Credential"; @@ -109,6 +107,7 @@ const GDriveMain = ({}: {}) => { | undefined = credentialsData.find( (credential) => credential.credential_json?.google_drive_service_account_key ); + const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus< GoogleDriveConfig, GoogleDriveCredentialJson diff --git a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts index d847efe89d1..7535eec35bb 100644 --- a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts +++ b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts @@ -2,14 +2,12 @@ import { PopupSpec } from "@/components/admin/connectors/Popup"; import { createConnector, runConnector } from "@/lib/connector"; import { createCredential, linkCredential } from "@/lib/credential"; import { FileConfig } from "@/lib/connectors/connectors"; -import { AdvancedConfigFinal } from "../../AddConnectorPage"; export const submitFiles = async ( selectedFiles: File[], setPopup: (popup: PopupSpec) => void, setSelectedFiles: (files: File[]) => void, name: string, - advancedConfig: AdvancedConfigFinal, isPublic: boolean, groups?: number[] ) => { diff --git a/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts b/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts index f1689e8fcdf..11d7f46ecea 100644 --- a/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts +++ b/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts @@ -2,13 +2,14 @@ import { PopupSpec } from "@/components/admin/connectors/Popup"; import { createConnector, runConnector } from "@/lib/connector"; import { linkCredential } from "@/lib/credential"; import { GoogleSitesConfig } from "@/lib/connectors/connectors"; -import { AdvancedConfigFinal } from "../../AddConnectorPage"; export const submitGoogleSite = async ( selectedFiles: File[], base_url: any, setPopup: (popup: PopupSpec) => void, - advancedConfig: AdvancedConfigFinal, + refreshFreq: number, + pruneFreq: number, + indexingStart: Date, name?: string ) => { const uploadCreateAndTriggerConnector = async () => { @@ -41,9 +42,9 @@ export const submitGoogleSite = async ( base_url: base_url, zip_path: filePaths[0], }, - refresh_freq: advancedConfig.refreshFreq, - prune_freq: advancedConfig.pruneFreq, - indexing_start: advancedConfig.indexingStart, + refresh_freq: refreshFreq, + prune_freq: pruneFreq, + indexing_start: indexingStart, }); if (connectorErrorMsg || !connector) { setPopup({ diff --git a/web/src/app/admin/documents/explorer/Explorer.tsx b/web/src/app/admin/documents/explorer/Explorer.tsx index a773c222484..c1722b01edf 100644 --- 
a/web/src/app/admin/documents/explorer/Explorer.tsx +++ b/web/src/app/admin/documents/explorer/Explorer.tsx @@ -211,7 +211,7 @@ export function Explorer({ )} {!query && (
- Search for a document above to modify it's boost or hide it from + Search for a document above to modify its boost or hide it from searches.
)} diff --git a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx index 814af4e2863..2778103e345 100644 --- a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx +++ b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx @@ -125,14 +125,12 @@ export const DocumentSetCreationForm = ({ placeholder="Describe what the document set represents" autoCompleteDisabled={true} /> - {isPaidEnterpriseFeaturesEnabled && - userGroups && - userGroups.length > 0 && ( - - )} + {isPaidEnterpriseFeaturesEnabled && ( + + )} diff --git a/web/src/app/admin/documents/sets/page.tsx b/web/src/app/admin/documents/sets/page.tsx index 718b81ab0b6..41104f9c343 100644 --- a/web/src/app/admin/documents/sets/page.tsx +++ b/web/src/app/admin/documents/sets/page.tsx @@ -67,10 +67,11 @@ const EditRow = ({
)}
{ if (documentSet.is_up_to_date) { router.push(`/admin/documents/sets/${documentSet.id}`); @@ -87,8 +88,8 @@ const EditRow = ({ } }} > - - {documentSet.name} + + {documentSet.name}
); diff --git a/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx b/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx index 1b9fffda428..2b4394c56b5 100644 --- a/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx +++ b/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx @@ -24,10 +24,14 @@ import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal"; import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal"; import { AlreadyPickedModal } from "./modals/AlreadyPickedModal"; import { ModelOption } from "../../../components/embedding/ModelSelector"; -import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants"; +import { + EMBEDDING_MODELS_ADMIN_URL, + EMBEDDING_PROVIDERS_ADMIN_URL, +} from "../configuration/llm/constants"; export interface EmbeddingDetails { - api_key: string; + api_key?: string; + api_url?: string; custom_config: any; provider_type: EmbeddingProvider; } @@ -77,12 +81,20 @@ export function EmbeddingModelSelection({ const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] = useState(false); + const [showAddConnectorPopup, setShowAddConnectorPopup] = useState(false); + const { data: embeddingModelDetails } = useSWR( + EMBEDDING_MODELS_ADMIN_URL, + errorHandlingFetcher, + { refreshInterval: 5000 } // 5 seconds + ); + const { data: embeddingProviderDetails } = useSWR( EMBEDDING_PROVIDERS_ADMIN_URL, - errorHandlingFetcher + errorHandlingFetcher, + { refreshInterval: 5000 } // 5 seconds ); const { data: connectors } = useSWR[]>( @@ -175,6 +187,7 @@ export function EmbeddingModelSelection({ {showTentativeProvider && ( { setShowTentativeProvider(showUnconfiguredProvider); @@ -189,8 +202,10 @@ export function EmbeddingModelSelection({ }} /> )} + {changeCredentialsProvider && ( { clientsideRemoveProvider(changeCredentialsProvider); @@ -277,9 +292,10 @@ export function EmbeddingModelSelection({ {modelTab == "cloud" && ( { const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false); + const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] = + useState(false); return ( -
-

- Post-processing -

-
- {originalRerankingDetails.rerank_model_name && ( - - )} -
- -
+ () + .nullable() + .oneOf(Object.values(RerankerProvider)) + .optional(), + api_key: Yup.string().nullable(), + num_rerank: Yup.number().min(1, "Must be at least 1"), + rerank_api_url: Yup.string() + .url("Must be a valid URL") + .matches(/^https?:\/\//, "URL must start with http:// or https://") + .nullable(), + })} + onSubmit={async (_, { setSubmitting }) => { + setSubmitting(false); + }} + enableReinitialize={true} + > + {({ values, setFieldValue, resetForm }) => { + const resetRerankingValues = () => { + setRerankingDetails(originalRerankingDetails); + resetForm(); + }; -
- -
-
+ return ( +
+

+ Post-processing +

+
+ {originalRerankingDetails.rerank_model_name && ( + + )} +
+ +
- () - .nullable() - .oneOf(Object.values(RerankerProvider)) - .optional(), - rerank_api_key: Yup.string().nullable(), - num_rerank: Yup.number().min(1, "Must be at least 1"), - })} - onSubmit={async (_, { setSubmitting }) => { - setSubmitting(false); - }} - enableReinitialize={true} - > - {({ values, setFieldValue }) => ( -
-
- {(modelTab - ? rerankingModels.filter( - (model) => model.cloud == (modelTab == "cloud") - ) - : rerankingModels.filter( - (modelCard) => - modelCard.modelName == - originalRerankingDetails.rerank_model_name - ) - ).map((card) => { - const isSelected = - values.rerank_provider_type === card.rerank_provider_type && - values.rerank_model_name === card.modelName; - return ( -
{ - if (card.rerank_provider_type) { - setIsApiKeyModalOpen(true); - } - setRerankingDetails({ - ...values, - rerank_provider_type: card.rerank_provider_type!, - rerank_model_name: card.modelName, - }); - setFieldValue( - "rerank_provider_type", - card.rerank_provider_type - ); - setFieldValue("rerank_model_name", card.modelName); - }} +
+ +
+ {values.rerank_model_name && ( +
+ +
+ )} +
+ + +
+ {(modelTab + ? rerankingModels.filter( + (model) => model.cloud == (modelTab == "cloud") + ) + : rerankingModels.filter( + (modelCard) => + (modelCard.modelName == + originalRerankingDetails.rerank_model_name && + modelCard.rerank_provider_type == + originalRerankingDetails.rerank_provider_type) || + (modelCard.rerank_provider_type == + RerankerProvider.LITELLM && + originalRerankingDetails.rerank_provider_type == + RerankerProvider.LITELLM) + ) + ).map((card) => { + const isSelected = + values.rerank_provider_type === + card.rerank_provider_type && + (card.modelName == null || + values.rerank_model_name === card.modelName); + + return ( +
{ + if ( + card.rerank_provider_type == RerankerProvider.COHERE + ) { + setIsApiKeyModalOpen(true); + } else if ( + card.rerank_provider_type == + RerankerProvider.LITELLM + ) { + setShowLiteLLMConfigurationModal(true); + } + + if (!isSelected) { + setRerankingDetails({ + ...values, + rerank_provider_type: card.rerank_provider_type!, + rerank_model_name: card.modelName || null, + rerank_api_key: null, + rerank_api_url: null, + }); + setFieldValue( + "rerank_provider_type", + card.rerank_provider_type + ); + setFieldValue("rerank_model_name", card.modelName); + } + }} + > +
+
+ {card.rerank_provider_type === + RerankerProvider.LITELLM ? ( + + ) : card.rerank_provider_type === + RerankerProvider.COHERE ? ( + + ) : ( + + )} +

+ {card.displayName} +

+
+ {card.link && ( + e.stopPropagation()} + className="text-blue-500 hover:text-blue-700 transition-colors duration-200" + > + + )} -

- {card.displayName} -

- {card.link && ( - e.stopPropagation()} - className="text-blue-500 hover:text-blue-700 transition-colors duration-200" - > - - - )} +

+ {card.description} +

+
+ {card.cloud ? "Cloud-based" : "Self-hosted"} +
-

- {card.description} -

-
- {card.cloud ? "Cloud-based" : "Self-hosted"} + ); + })} +
+ + {showLiteLLMConfigurationModal && ( + { + resetForm(); + setShowLiteLLMConfigurationModal(false); + }} + width="w-[800px]" + title="API Key Configuration" + > +
+ ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_api_url: value, + }); + setFieldValue("rerank_api_url", value); + }} + type="text" + label="LiteLLM Proxy URL" + name="rerank_api_url" + /> + + ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_api_key: value, + }); + setFieldValue("rerank_api_key", value); + }} + type="password" + label="LiteLLM Proxy Key" + name="rerank_api_key" + optional + /> + + ) => { + const value = e.target.value; + setRerankingDetails({ + ...values, + rerank_model_name: value, + }); + setFieldValue("rerank_model_name", value); + }} + label="LiteLLM Model Name" + name="rerank_model_name" + optional + /> + +
+
- ); - })} -
+ + )} - {isApiKeyModalOpen && ( - { - Object.keys(originalRerankingDetails).forEach((key) => { - setFieldValue( - key, - originalRerankingDetails[key as keyof RerankingDetails] - ); - }); - - setIsApiKeyModalOpen(false); - }} - width="w-[800px]" - title="API Key Configuration" - > -
- ) => { - const value = e.target.value; - setRerankingDetails({ ...values, rerank_api_key: value }); - setFieldValue("rerank_api_key", value); - }} - type="password" - label="Cohere API Key" - name="rerank_api_key" - /> -
- - + type="password" + label="Cohere API Key" + name="rerank_api_key" + /> +
+ +
-
-
- )} - - )} - -
+ + )} + +
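Once the LiteLLM card is selected and the configuration modal above is filled in, the form state reduces to a RerankingDetails value along these lines — a sketch with placeholder values; only the field names and the RerankerProvider.LITELLM member come from this PR:

import { RerankerProvider, RerankingDetails } from "@/app/admin/embeddings/interfaces";

// Placeholder values for illustration; the proxy URL, key, and model name are whatever
// your LiteLLM deployment actually exposes.
const litellmReranking: RerankingDetails = {
  rerank_provider_type: RerankerProvider.LITELLM,
  rerank_model_name: "cohere/rerank-english-v3.0", // assumed model behind the proxy
  rerank_api_url: "http://localhost:4000",         // assumed LiteLLM proxy URL
  rerank_api_key: null,                            // optional proxy key
};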
+ ); + }} + ); } ); diff --git a/web/src/app/admin/embeddings/interfaces.ts b/web/src/app/admin/embeddings/interfaces.ts index c3dec13e6cc..3b92fdd759f 100644 --- a/web/src/app/admin/embeddings/interfaces.ts +++ b/web/src/app/admin/embeddings/interfaces.ts @@ -1,16 +1,19 @@ import { EmbeddingProvider } from "@/components/embedding/interfaces"; -import { NonNullChain } from "typescript"; +// This is a slightly differnte interface than used in the backend +// but is always used in conjunction with `AdvancedSearchConfiguration` export interface RerankingDetails { rerank_model_name: string | null; rerank_provider_type: RerankerProvider | null; rerank_api_key: string | null; - num_rerank: number; + rerank_api_url: string | null; } export enum RerankerProvider { COHERE = "cohere", + LITELLM = "litellm", } + export interface AdvancedSearchConfiguration { model_name: string; model_dim: number; @@ -21,24 +24,19 @@ export interface AdvancedSearchConfiguration { multipass_indexing: boolean; multilingual_expansion: string[]; disable_rerank_for_streaming: boolean; + api_url: string | null; + num_rerank: number; } -export interface SavedSearchSettings extends RerankingDetails { - model_name: string; - model_dim: number; - normalize: boolean; - query_prefix: string; - passage_prefix: string; - index_name: string | null; - multipass_indexing: boolean; - multilingual_expansion: string[]; - disable_rerank_for_streaming: boolean; +export interface SavedSearchSettings + extends RerankingDetails, + AdvancedSearchConfiguration { provider_type: EmbeddingProvider | null; } export interface RerankingModel { rerank_provider_type: RerankerProvider | null; - modelName: string; + modelName?: string; displayName: string; description: string; link: string; @@ -46,6 +44,13 @@ export interface RerankingModel { } export const rerankingModels: RerankingModel[] = [ + { + rerank_provider_type: RerankerProvider.LITELLM, + cloud: true, + displayName: "LiteLLM", + description: "Host your own reranker or router with LiteLLM proxy", + link: "https://docs.litellm.ai/docs/proxy", + }, { rerank_provider_type: null, cloud: false, diff --git a/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx b/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx index c2f3923e5cd..636aa562474 100644 --- a/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx +++ b/web/src/app/admin/embeddings/modals/ChangeCredentialsModal.tsx @@ -15,14 +15,19 @@ export function ChangeCredentialsModal({ onCancel, onDeleted, useFileUpload, + isProxy = false, }: { provider: CloudEmbeddingProvider; onConfirm: () => void; onCancel: () => void; onDeleted: () => void; useFileUpload: boolean; + isProxy?: boolean; }) { const [apiKey, setApiKey] = useState(""); + const [apiUrl, setApiUrl] = useState(""); + const [modelName, setModelName] = useState(""); + const [testError, setTestError] = useState(""); const [fileName, setFileName] = useState(""); const fileInputRef = useRef(null); @@ -74,7 +79,7 @@ export function ChangeCredentialsModal({ try { const response = await fetch( - `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`, + `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type.toLowerCase()}`, { method: "DELETE", } @@ -99,13 +104,18 @@ export function ChangeCredentialsModal({ const handleSubmit = async () => { setTestError(""); + const normalizedProviderType = provider.provider_type + .toLowerCase() + .split(" ")[0]; try { const testResponse = await fetch("/api/admin/embedding/test-embedding", { method: "POST", headers: { 
"Content-Type": "application/json" }, body: JSON.stringify({ - provider_type: provider.provider_type.toLowerCase().split(" ")[0], + provider_type: normalizedProviderType, api_key: apiKey, + api_url: apiUrl, + model_name: modelName, }), }); @@ -118,8 +128,9 @@ export function ChangeCredentialsModal({ method: "PUT", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ - provider_type: provider.provider_type.toLowerCase().split(" ")[0], + provider_type: normalizedProviderType, api_key: apiKey, + api_url: apiUrl, is_default_provider: false, is_configured: true, }), @@ -128,7 +139,8 @@ export function ChangeCredentialsModal({ if (!updateResponse.ok) { const errorData = await updateResponse.json(); throw new Error( - errorData.detail || "Failed to update provider- check your API key" + errorData.detail || + `Failed to update provider- check your ${isProxy ? "API URL" : "API key"}` ); } @@ -144,26 +156,20 @@ export function ChangeCredentialsModal({ -
- - Want to swap out your key? - - - Visit API - - -
+ <> +

+ You can modify your configuration by providing a new API key + {isProxy ? " or API URL." : "."} +

+ +
+ {useFileUpload ? ( <> - + )} -
- {testError && ( - - {testError} - - )} + {isProxy && ( + <> + + + setApiUrl(e.target.value)} + placeholder="Paste your API URL here" + /> + + {deletionError && ( + + {deletionError} + + )} + +
+ +

+ Since you are using a LiteLLM proxy, we'll need a model + name to test the connection with. +

+
+ setModelName(e.target.value)} + placeholder="Paste your model name here" + /> + + {deletionError && ( + + {deletionError} + + )} + + )} + + {testError && ( + + {testError} + + )} -
+ + + + + You can also delete your configuration. + + + This is only possible if you have already switched to a different + embedding type! + + + + {deletionError && ( + + {deletionError} + + )}
- - - - You can also delete your key. - - - This is only possible if you have already switched to a different - embedding type! - - - - {deletionError && ( - - {deletionError} - - )} -
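The handleSubmit flow above validates proxy credentials before saving them. Stripped of the component state, the test call it issues for a LiteLLM setup looks roughly like this — the URL and model name are placeholders; the endpoint and body fields are the ones used in this PR:

// Standalone sketch of the connection test; all values are placeholders.
async function testLiteLLMEmbedding(): Promise<void> {
  const res = await fetch("/api/admin/embedding/test-embedding", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      provider_type: "litellm",
      api_key: null, // optional when the proxy is unauthenticated
      api_url: "http://localhost:4000", // assumed LiteLLM proxy URL
      model_name: "text-embedding-3-small", // assumed model served by the proxy
    }),
  });
  if (!res.ok) {
    const errorData = await res.json();
    throw new Error(errorData.detail || "Embedding test failed");
  }
}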
+ ); } diff --git a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx index 4b2ad9c51fc..1b54b2f123b 100644 --- a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx +++ b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx @@ -1,6 +1,6 @@ import React, { useRef, useState } from "react"; import { Text, Button, Callout } from "@tremor/react"; -import { Formik, Form, Field } from "formik"; +import { Formik, Form } from "formik"; import * as Yup from "yup"; import { Label, TextFormField } from "@/components/admin/connectors/Field"; import { LoadingAnimation } from "@/components/Loading"; @@ -13,11 +13,13 @@ export function ProviderCreationModal({ onConfirm, onCancel, existingProvider, + isProxy, }: { selectedProvider: CloudEmbeddingProvider; onConfirm: () => void; onCancel: () => void; existingProvider?: CloudEmbeddingProvider; + isProxy?: boolean; }) { const useFileUpload = selectedProvider.provider_type == "Google"; @@ -29,17 +31,27 @@ export function ProviderCreationModal({ provider_type: existingProvider?.provider_type || selectedProvider.provider_type, api_key: existingProvider?.api_key || "", + api_url: existingProvider?.api_url || "", custom_config: existingProvider?.custom_config ? Object.entries(existingProvider.custom_config) : [], model_id: 0, + model_name: null, }; const validationSchema = Yup.object({ provider_type: Yup.string().required("Provider type is required"), - api_key: useFileUpload + api_key: isProxy ? Yup.string() - : Yup.string().required("API Key is required"), + : useFileUpload + ? Yup.string() + : Yup.string().required("API Key is required"), + model_name: isProxy + ? Yup.string().required("Model name is required") + : Yup.string().nullable(), + api_url: isProxy + ? Yup.string().required("API URL is required") + : Yup.string(), custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)), }); @@ -87,6 +99,8 @@ export function ProviderCreationModal({ body: JSON.stringify({ provider_type: values.provider_type.toLowerCase().split(" ")[0], api_key: values.api_key, + api_url: values.api_url, + model_name: values.model_name, }), } ); @@ -144,14 +158,7 @@ export function ProviderCreationModal({ validationSchema={validationSchema} onSubmit={handleSubmit} > - {({ - values, - errors, - touched, - isSubmitting, - handleSubmit, - setFieldValue, - }) => ( + {({ isSubmitting, handleSubmit, setFieldValue }) => (
You are setting the credentials for this provider. To access @@ -169,11 +176,28 @@ export function ProviderCreationModal({ target="_blank" href={selectedProvider.apiLink} > - API KEY + {isProxy ? "API URL" : "API KEY"} -
+
+ {isProxy && ( + <> + + + + )} + {useFileUpload ? ( <> @@ -189,7 +213,7 @@ export function ProviderCreationModal({ ) : ( diff --git a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx index 89d885a1368..c965bdfabf8 100644 --- a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx @@ -4,7 +4,7 @@ import * as Yup from "yup"; import CredentialSubText from "@/components/credentials/CredentialFields"; import { TrashIcon } from "@/components/icons/icons"; import { FaPlus } from "react-icons/fa"; -import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces"; +import { AdvancedSearchConfiguration } from "../interfaces"; import { BooleanFormField } from "@/components/admin/connectors/Field"; import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput"; @@ -14,159 +14,102 @@ interface AdvancedEmbeddingFormPageProps { value: any ) => void; advancedEmbeddingDetails: AdvancedSearchConfiguration; - numRerank: number; } const AdvancedEmbeddingFormPage = forwardRef< FormikProps, AdvancedEmbeddingFormPageProps ->( - ( - { updateAdvancedEmbeddingDetails, advancedEmbeddingDetails, numRerank }, - ref - ) => { - return ( -
-

- Advanced Configuration -

- { - setSubmitting(false); - }} - enableReinitialize={true} - > - {({ values, setFieldValue }) => ( - - - {({ push, remove }) => ( -
- - - List of languages for multilingual expansion. Leave empty - for no additional expansion. - - {values.multilingual_expansion.map( - (_: any, index: number) => ( -
- (({ updateAdvancedEmbeddingDetails, advancedEmbeddingDetails }, ref) => { + return ( +
+

+ Advanced Configuration +

+ { + setSubmitting(false); + }} + validate={(values) => { + // Call updateAdvancedEmbeddingDetails for each changed field + Object.entries(values).forEach(([key, value]) => { + updateAdvancedEmbeddingDetails( + key as keyof AdvancedSearchConfiguration, + value + ); + }); + }} + enableReinitialize={true} + > + {({ values }) => ( + + + {({ push, remove }) => ( +
+ {values.multilingual_expansion.map( + (_: any, index: number) => ( +
+ - ) => { - const newValue = [ - ...values.multilingual_expansion, - ]; - newValue[index] = e.target.value; - setFieldValue("multilingual_expansion", newValue); - updateAdvancedEmbeddingDetails( - "multilingual_expansion", - newValue - ); - }} - value={values.multilingual_expansion[index]} - /> - - -
- ) - )} - - +
+ ) + )} + -
- )} - + > + + Add Language + +
+ )} + - ) => { - const checked = e.target.checked; - updateAdvancedEmbeddingDetails("multipass_indexing", checked); - setFieldValue("multipass_indexing", checked); - }} - label="Multipass Indexing" - name="multipassIndexing" - /> - ) => { - const checked = e.target.checked; - updateAdvancedEmbeddingDetails( - "disable_rerank_for_streaming", - checked - ); - setFieldValue("disable_rerank_for_streaming", checked); - }} - label="Disable Rerank for Streaming" - name="disableRerankForStreaming" - /> - - - )} - -
- ); - } -); + + + + + )} +
+
+ ); +}); +export default AdvancedEmbeddingFormPage; AdvancedEmbeddingFormPage.displayName = "AdvancedEmbeddingFormPage"; -export default AdvancedEmbeddingFormPage; diff --git a/web/src/app/admin/embeddings/pages/CloudEmbeddingPage.tsx b/web/src/app/admin/embeddings/pages/CloudEmbeddingPage.tsx index a7a7a1553a5..a6c71530f24 100644 --- a/web/src/app/admin/embeddings/pages/CloudEmbeddingPage.tsx +++ b/web/src/app/admin/embeddings/pages/CloudEmbeddingPage.tsx @@ -1,6 +1,6 @@ "use client"; -import { Text, Title } from "@tremor/react"; +import { Button, Card, Text, Title } from "@tremor/react"; import { CloudEmbeddingProvider, @@ -8,15 +8,22 @@ import { AVAILABLE_CLOUD_PROVIDERS, CloudEmbeddingProviderFull, EmbeddingModelDescriptor, + EmbeddingProvider, + LITELLM_CLOUD_PROVIDER, } from "../../../../components/embedding/interfaces"; import { EmbeddingDetails } from "../EmbeddingModelSelectionForm"; -import { FiExternalLink, FiInfo } from "react-icons/fi"; +import { FiExternalLink, FiInfo, FiTrash } from "react-icons/fi"; import { HoverPopup } from "@/components/HoverPopup"; -import { Dispatch, SetStateAction } from "react"; +import { Dispatch, SetStateAction, useEffect, useState } from "react"; +import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm"; +import { deleteSearchSettings } from "./utils"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal"; export default function CloudEmbeddingPage({ currentModel, embeddingProviderDetails, + embeddingModelDetails, newEnabledProviders, newUnenabledProviders, setShowTentativeProvider, @@ -30,6 +37,7 @@ export default function CloudEmbeddingPage({ currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel; setAlreadySelectedModel: Dispatch>; newUnenabledProviders: string[]; + embeddingModelDetails?: CloudEmbeddingModel[]; embeddingProviderDetails?: EmbeddingDetails[]; newEnabledProviders: string[]; setShowTentativeProvider: React.Dispatch< @@ -61,6 +69,17 @@ export default function CloudEmbeddingPage({ ))!), }) ); + const [liteLLMProvider, setLiteLLMProvider] = useState< + EmbeddingDetails | undefined + >(undefined); + + useEffect(() => { + const foundProvider = embeddingProviderDetails?.find( + (provider) => + provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase() + ); + setLiteLLMProvider(foundProvider); + }, [embeddingProviderDetails]); return (
@@ -122,6 +141,127 @@ export default function CloudEmbeddingPage({
))} + + + Alternatively, you can use a self-hosted model using the LiteLLM + proxy. This allows you to leverage various LLM providers through a + unified interface that you control.{" "} + + Learn more about LiteLLM + + + +
+
+ {LITELLM_CLOUD_PROVIDER.icon({ size: 40 })} +

+ {LITELLM_CLOUD_PROVIDER.provider_type}{" "} + {LITELLM_CLOUD_PROVIDER.provider_type == "Cohere" && + "(recommended)"} +

+ + } + popupContent={ +
+
+ {LITELLM_CLOUD_PROVIDER.description} +
+
+ } + style="dark" + /> +
+
+ {!liteLLMProvider ? ( + + ) : ( + + )} + + {!liteLLMProvider && ( + +
+ + API URL Required + + + Before you can add models, you need to provide an API URL + for your LiteLLM proxy. Click the "Provide API + URL" button above to set up your LiteLLM configuration. + +
+ + + Once configured, you'll be able to add and manage + your LiteLLM models here. + +
+
+
+ )} + {liteLLMProvider && ( + <> +
+ {embeddingModelDetails + ?.filter( + (model) => + model.provider_type === + EmbeddingProvider.LITELLM.toLowerCase() + ) + .map((model) => ( + + ))} +
+ + + + + + )} +
+
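For reference, the LiteLLM block above only renders entries from embeddingModelDetails whose provider_type is EmbeddingProvider.LITELLM. A hypothetical entry might look like the following — every value is a placeholder, and only the field names are ones this PR actually reads (id, provider_type, model_name, description):

// Illustrative only; not a complete CloudEmbeddingModel.
const exampleLiteLLMHostedModel = {
  id: 42,
  provider_type: "litellm",
  model_name: "text-embedding-3-small", // whatever model the proxy serves
  description: "Embedding model routed through a self-hosted LiteLLM proxy",
};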
); @@ -146,7 +286,32 @@ export function CloudModelCard({ React.SetStateAction >; }) { - const enabled = model.model_name === currentModel.model_name; + const { popup, setPopup } = usePopup(); + const [showDeleteModel, setShowDeleteModel] = useState(false); + const enabled = + model.model_name === currentModel.model_name && + model.provider_type?.toLowerCase() == + currentModel.provider_type?.toLowerCase(); + + const deleteModel = async () => { + if (!model.id) { + setPopup({ message: "Model cannot be deleted", type: "error" }); + return; + } + + const response = await deleteSearchSettings(model.id); + + if (response.ok) { + setPopup({ message: "Model deleted successfully", type: "success" }); + setShowDeleteModel(false); + } else { + setPopup({ + message: + "Failed to delete model. Ensure you are not attempting to delete a curently active model.", + type: "error", + }); + } + }; return (
+ {popup} + {showDeleteModel && ( + deleteModel()} + onClose={() => setShowDeleteModel(false)} + /> + )} +

{model.model_name}

- e.stopPropagation()} - className="text-blue-500 hover:text-blue-700 transition-colors duration-200" - > - - +
+ {model.provider_type == EmbeddingProvider.LITELLM.toLowerCase() && ( + + )} + e.stopPropagation()} + className="text-blue-500 hover:text-blue-700 transition-colors duration-200" + > + + +

{model.description}

-
- ${model.pricePerMillion}/M tokens -
+ {model?.provider_type?.toLowerCase() != + EmbeddingProvider.LITELLM.toLowerCase() && ( +
+ ${model.pricePerMillion}/M tokens +
+ )}
+
-
- setSearchTerm(e.target.value)} - className="ml-2 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" - /> - - -
{sortedSources - .filter((source) => source != "not_applicable") + .filter( + (source) => + source != "not_applicable" && source != "ingestion_api" + ) .map((source, ind) => { const sourceMatches = source .toLowerCase() @@ -479,7 +467,7 @@ export function CCPairIndexingStatusTable({ if (sourceMatches || matchingConnectors.length > 0) { return ( -
+
- - Name - - - Last Indexed - - - Activity - + Name + Last Indexed + Activity {isPaidEnterpriseFeaturesEnabled && ( - - Permissions - + Permissions )} - - Total Docs - - - Last Status - - + Total Docs + Last Status + {(sourceMatches ? groupedStatuses[source] diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts index 247bfd09d83..2c315f320ec 100644 --- a/web/src/app/admin/settings/interfaces.ts +++ b/web/src/app/admin/settings/interfaces.ts @@ -15,16 +15,27 @@ export interface Notification { first_shown: string; } +export interface NavigationItem { + link: string; + icon: string; + title: string; +} + export interface EnterpriseSettings { application_name: string | null; use_custom_logo: boolean; use_custom_logotype: boolean; + // custom navigation + custom_nav_items: NavigationItem[]; + // custom Chat components custom_lower_disclaimer_content: string | null; custom_header_content: string | null; + two_lines_for_chat_header: boolean | null; custom_popup_header: string | null; custom_popup_content: string | null; + enable_consent_screen: boolean | null; } export interface CombinedSettings { diff --git a/web/src/app/api/[...path]/route.ts b/web/src/app/api/[...path]/route.ts new file mode 100644 index 00000000000..550ebaf6d1f --- /dev/null +++ b/web/src/app/api/[...path]/route.ts @@ -0,0 +1,116 @@ +import { INTERNAL_URL } from "@/lib/constants"; +import { NextRequest, NextResponse } from "next/server"; + +/* NextJS is annoying and makes use use a separate function for +each request type >:( */ + +export async function GET( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function POST( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function PUT( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function PATCH( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function DELETE( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function HEAD( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +export async function OPTIONS( + request: NextRequest, + { params }: { params: { path: string[] } } +) { + return handleRequest(request, params.path); +} + +async function handleRequest(request: NextRequest, path: string[]) { + if (process.env.NODE_ENV !== "development") { + return NextResponse.json( + { + message: + "This API is only available in development mode. In production, something else (e.g. 
nginx) should handle this.", + }, + { status: 404 } + ); + } + + try { + const backendUrl = new URL(`${INTERNAL_URL}/${path.join("/")}`); + + // Get the URL parameters from the request + const urlParams = new URLSearchParams(request.url.split("?")[1]); + + // Append the URL parameters to the backend URL + urlParams.forEach((value, key) => { + backendUrl.searchParams.append(key, value); + }); + + const response = await fetch(backendUrl, { + method: request.method, + headers: request.headers, + body: request.body, + // @ts-ignore + duplex: "half", + }); + + // Check if the response is a stream + if ( + response.headers.get("Transfer-Encoding") === "chunked" || + response.headers.get("Content-Type")?.includes("stream") + ) { + // If it's a stream, create a TransformStream to pass the data through + const { readable, writable } = new TransformStream(); + response.body?.pipeTo(writable); + + return new NextResponse(readable, { + status: response.status, + headers: response.headers, + }); + } else { + return new NextResponse(response.body, { + status: response.status, + headers: response.headers, + }); + } + } catch (error: unknown) { + console.error("Proxy error:", error); + return NextResponse.json( + { + message: "Proxy error", + error: + error instanceof Error ? error.message : "An unknown error occurred", + }, + { status: 500 } + ); + } +} diff --git a/web/src/app/assistants/ToolsDisplay.tsx b/web/src/app/assistants/ToolsDisplay.tsx index 10c25b640c9..2be7670c0ee 100644 --- a/web/src/app/assistants/ToolsDisplay.tsx +++ b/web/src/app/assistants/ToolsDisplay.tsx @@ -71,7 +71,7 @@ export function AssistantTools({ w-fit flex items-center - ${hovered ? "bg-background-300" : list ? "bg-background-125" : "bg-background-100"}`} + ${list ? "bg-background-125" : "bg-background-100"}`} >
@@ -91,7 +91,7 @@ export function AssistantTools({ border-border w-fit flex - ${hovered ? "bg-background-300" : list ? "bg-background-125" : "bg-background-100"}`} + ${list ? "bg-background-125" : "bg-background-100"}`} >
( - ( - - ), - p: ({ node, ...props }) => ( -

- ), - }} - remarkPlugins={[remarkGfm]} - > - {settings.enterpriseSettings?.custom_header_content} - - ); return (

@@ -90,7 +66,7 @@ export function ChatBanner() { className="absolute top-0 left-0 invisible w-full" >
diff --git a/web/src/app/chat/ChatIntro.tsx b/web/src/app/chat/ChatIntro.tsx index 27353aa340f..3703655d7f9 100644 --- a/web/src/app/chat/ChatIntro.tsx +++ b/web/src/app/chat/ChatIntro.tsx @@ -1,30 +1,9 @@ -import { getSourceMetadataForSources, listSourceMetadata } from "@/lib/sources"; +import { getSourceMetadataForSources } from "@/lib/sources"; import { ValidSources } from "@/lib/types"; -import Image from "next/image"; import { Persona } from "../admin/assistants/interfaces"; import { Divider } from "@tremor/react"; -import { FiBookmark, FiCpu, FiInfo, FiX, FiZoomIn } from "react-icons/fi"; +import { FiBookmark, FiInfo } from "react-icons/fi"; import { HoverPopup } from "@/components/HoverPopup"; -import { Modal } from "@/components/Modal"; -import { useState } from "react"; -import { Logo } from "@/components/Logo"; - -const MAX_PERSONAS_TO_DISPLAY = 4; - -function HelperItemDisplay({ - title, - description, -}: { - title: string; - description: string; -}) { - return ( -
-
{title}
-
{description}
-
- ); -} export function ChatIntro({ availableSources, diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 3ee22d4d74f..40e4368a5af 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -4,6 +4,7 @@ import { useRouter, useSearchParams } from "next/navigation"; import { BackendChatSession, BackendMessage, + BUFFER_COUNT, ChatFileType, ChatSession, ChatSessionSharedStatus, @@ -48,6 +49,7 @@ import { SetStateAction, useContext, useEffect, + useLayoutEffect, useRef, useState, } from "react"; @@ -65,7 +67,12 @@ import { FiArrowDown } from "react-icons/fi"; import { ChatIntro } from "./ChatIntro"; import { AIMessage, HumanMessage } from "./message/Messages"; import { StarterMessage } from "./StarterMessage"; -import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces"; +import { + AnswerPiecePacket, + DanswerDocument, + StreamStopInfo, + StreamStopReason, +} from "@/lib/search/interfaces"; import { buildFilters } from "@/lib/search/utils"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import Dropzone from "react-dropzone"; @@ -86,7 +93,6 @@ import FunctionalHeader from "@/components/chat_search/Header"; import { useSidebarVisibility } from "@/components/chat_search/hooks"; import { SIDEBAR_TOGGLED_COOKIE_NAME } from "@/components/resizable/constants"; import FixedLogo from "./shared_chat_search/FixedLogo"; -import { getSecondsUntilExpiration } from "@/lib/time"; import { SetDefaultModelModal } from "./modal/SetDefaultModelModal"; import { DeleteEntityModal } from "../../components/modals/DeleteEntityModal"; import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown"; @@ -94,6 +100,7 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; import { SEARCH_TOOL_NAME } from "./tools/constants"; import { useUser } from "@/components/user/UserProvider"; +import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; const TEMP_USER_MESSAGE_ID = -1; const TEMP_ASSISTANT_MESSAGE_ID = -2; @@ -102,12 +109,10 @@ const SYSTEM_MESSAGE_ID = -3; export function ChatPage({ toggle, documentSidebarInitialWidth, - defaultSelectedAssistantId, toggledSidebar, }: { toggle: (toggled?: boolean) => void; documentSidebarInitialWidth?: number; - defaultSelectedAssistantId?: number; toggledSidebar: boolean; }) { const router = useRouter(); @@ -122,9 +127,14 @@ export function ChatPage({ folders, openedFolders, userInputPrompts, + defaultAssistantId, + shouldShowWelcomeModal, + refreshChatSessions, } = useChatContext(); - const { user, refreshUser } = useUser(); + const [showApiKeyModal, setShowApiKeyModal] = useState(true); + + const { user, refreshUser, isLoadingUser } = useUser(); // chat session const existingChatIdRaw = searchParams.get("chatId"); @@ -133,6 +143,7 @@ export function ChatPage({ const existingChatSessionId = existingChatIdRaw ? parseInt(existingChatIdRaw) : null; + const selectedChatSession = chatSessions.find( (chatSession) => chatSession.id === existingChatSessionId ); @@ -157,9 +168,9 @@ export function ChatPage({ ? availableAssistants.find( (assistant) => assistant.id === existingChatSessionAssistantId ) - : defaultSelectedAssistantId !== undefined + : defaultAssistantId !== undefined ? 
availableAssistants.find( - (assistant) => assistant.id === defaultSelectedAssistantId + (assistant) => assistant.id === defaultAssistantId ) : undefined ); @@ -201,6 +212,7 @@ export function ChatPage({ selectedAssistant || filteredAssistants[0] || availableAssistants[0]; + useEffect(() => { if (!loadedIdSessionRef.current && !currentPersonaId) { return; @@ -249,6 +261,7 @@ export function ChatPage({ updateChatState("input", currentSession); }; + // this is for "@"ing assistants // this is used to track which assistant is being used to generate the current message @@ -273,6 +286,7 @@ export function ChatPage({ ); const [isReady, setIsReady] = useState(false); + useEffect(() => { Prism.highlightAll(); setIsReady(true); @@ -319,8 +333,8 @@ export function ChatPage({ async function initialSessionFetch() { if (existingChatSessionId === null) { setIsFetchingChatMessages(false); - if (defaultSelectedAssistantId !== undefined) { - setSelectedAssistantFromId(defaultSelectedAssistantId); + if (defaultAssistantId !== undefined) { + setSelectedAssistantFromId(defaultAssistantId); } else { setSelectedAssistant(undefined); } @@ -337,6 +351,10 @@ export function ChatPage({ } return; } + const shouldScrollToBottom = + visibleRange.get(existingChatSessionId) === undefined || + visibleRange.get(existingChatSessionId)?.end == 0; + clearSelectedDocuments(); setIsFetchingChatMessages(true); const response = await fetch( @@ -371,10 +389,16 @@ export function ChatPage({ // go to bottom. If initial load, then do a scroll, // otherwise just appear at the bottom - if (!hasPerformedInitialScroll) { - clientScrollToBottom(); - } else if (isChatSessionSwitch) { - clientScrollToBottom(true); + if (shouldScrollToBottom) { + scrollInitialized.current = false; + } + + if (shouldScrollToBottom) { + if (!hasPerformedInitialScroll) { + clientScrollToBottom(); + } else if (isChatSessionSwitch) { + clientScrollToBottom(true); + } } setIsFetchingChatMessages(false); @@ -393,7 +417,7 @@ export function ChatPage({ // force re-name if the chat session doesn't have one if (!chatSession.description) { await nameChatSession(existingChatSessionId, seededMessage); - router.refresh(); // need to refresh to update name on sidebar + refreshChatSessions(); } } } @@ -521,17 +545,6 @@ export function ChatPage({ new Map([[chatSessionIdRef.current, "input"]]) ); - const [scrollHeight, setScrollHeight] = useState>( - new Map([[chatSessionIdRef.current, 0]]) - ); - const currentScrollHeight = () => { - return scrollHeight.get(currentSessionId()); - }; - - const retrieveCurrentScrollHeight = (): number | null => { - return scrollHeight.get(currentSessionId()) || null; - }; - const [regenerationState, setRegenerationState] = useState< Map >(new Map([[null, null]])); @@ -623,6 +636,24 @@ export function ChatPage({ const currentRegenerationState = (): RegenerationState | null => { return regenerationState.get(currentSessionId()) || null; }; + const [canContinue, setCanContinue] = useState>( + new Map([[null, false]]) + ); + + const updateCanContinue = (newState: boolean, sessionId?: number | null) => { + setCanContinue((prevState) => { + const newCanContinueState = new Map(prevState); + newCanContinueState.set( + sessionId !== undefined ? 
sessionId : currentSessionId(), + newState + ); + return newCanContinueState; + }); + }; + + const currentCanContinue = (): boolean => { + return canContinue.get(currentSessionId()) || false; + }; const currentSessionChatState = currentChatState(); const currentSessionRegenerationState = currentRegenerationState(); @@ -649,12 +680,10 @@ export function ChatPage({ useEffect(() => { if (messageHistory.length === 0 && chatSessionIdRef.current === null) { setSelectedAssistant( - filteredAssistants.find( - (persona) => persona.id === defaultSelectedAssistantId - ) + filteredAssistants.find((persona) => persona.id === defaultAssistantId) ); } - }, [defaultSelectedAssistantId]); + }, [defaultAssistantId]); const [ selectedDocuments, @@ -751,17 +780,47 @@ export function ChatPage({ const clientScrollToBottom = (fast?: boolean) => { setTimeout(() => { - if (fast) { - endDivRef.current?.scrollIntoView(); + if (!endDivRef.current || !scrollableDivRef.current) { + return; + } + + const rect = endDivRef.current.getBoundingClientRect(); + const isVisible = rect.top >= 0 && rect.bottom <= window.innerHeight; + + if (isVisible) return; + + // Check if all messages are currently rendered + if (currentVisibleRange.end < messageHistory.length) { + // Update visible range to include the last messages + updateCurrentVisibleRange({ + start: Math.max( + 0, + messageHistory.length - + (currentVisibleRange.end - currentVisibleRange.start) + ), + end: messageHistory.length, + mostVisibleMessageId: currentVisibleRange.mostVisibleMessageId, + }); + + // Wait for the state update and re-render before scrolling + setTimeout(() => { + endDivRef.current?.scrollIntoView({ + behavior: fast ? "auto" : "smooth", + }); + setHasPerformedInitialScroll(true); + }, 0); } else { - endDivRef.current?.scrollIntoView({ behavior: "smooth" }); + // If all messages are already rendered, scroll immediately + endDivRef.current.scrollIntoView({ + behavior: fast ? "auto" : "smooth", + }); + setHasPerformedInitialScroll(true); } - setHasPerformedInitialScroll(true); }, 50); }; const distance = 500; // distance that should "engage" the scroll - const debounce = 100; // time for debouncing + const debounceNumber = 100; // time for debouncing const [hasPerformedInitialScroll, setHasPerformedInitialScroll] = useState( existingChatSessionId === null @@ -863,6 +922,13 @@ export function ChatPage({ } }; + const continueGenerating = () => { + onSubmit({ + messageOverride: + "Continue Generating (pick up exactly where you left off)", + }); + }; + const onSubmit = async ({ messageIdToResend, messageOverride, @@ -883,6 +949,7 @@ export function ChatPage({ regenerationRequest?: RegenerationRequest | null; } = {}) => { let frozenSessionId = currentSessionId(); + updateCanContinue(false, frozenSessionId); if (currentChatState() != "input") { setPopup({ @@ -892,13 +959,6 @@ export function ChatPage({ return; } - updateRegenerationState( - regenerationRequest - ? { regenerating: true, finalMessageIndex: messageIdToResend || 0 } - : null - ); - - updateChatState("loading"); setAlternativeGeneratingAssistant(alternativeAssistantOverride); clientScrollToBottom(); @@ -929,6 +989,11 @@ export function ChatPage({ (message) => message.messageId === messageIdToResend ); + updateRegenerationState( + regenerationRequest + ? 
{ regenerating: true, finalMessageIndex: messageIdToResend || 0 } + : null + ); const messageMap = currentMessageMap(completeMessageDetail); const messageToResendParent = messageToResend?.parentMessageId !== null && @@ -955,6 +1020,9 @@ export function ChatPage({ } setSubmittedMessage(currMessage); + + updateChatState("loading"); + const currMessageHistory = messageToResendIndex !== null ? messageHistory.slice(0, messageToResendIndex) @@ -977,6 +1045,8 @@ export function ChatPage({ let messageUpdates: Message[] | null = null; let answer = ""; + + let stopReason: StreamStopReason | null = null; let query: string | null = null; let retrievalType: RetrievalType = selectedDocuments.length > 0 @@ -1067,6 +1137,12 @@ export function ChatPage({ console.error( "First packet should contain message response info " ); + if (Object.hasOwn(packet, "error")) { + const error = (packet as StreamingError).error; + setLoadingError(error); + updateChatState("input"); + return; + } continue; } @@ -1173,6 +1249,11 @@ export function ChatPage({ stackTrace = (packet as StreamingError).stack_trace; } else if (Object.hasOwn(packet, "message_id")) { finalMessage = packet as BackendMessage; + } else if (Object.hasOwn(packet, "stop_reason")) { + const stop_reason = (packet as StreamStopInfo).stop_reason; + if (stop_reason === StreamStopReason.CONTEXT_LENGTH) { + updateCanContinue(true, frozenSessionId); + } } // on initial message send, we insert a dummy system message @@ -1236,6 +1317,7 @@ export function ChatPage({ alternateAssistantID: alternativeAssistant?.id, stackTrace: stackTrace, overridden_model: finalMessage?.overridden_model, + stopReason: stopReason, }, ]); } @@ -1280,6 +1362,7 @@ export function ChatPage({ if (!searchParamBasedChatSessionName) { await new Promise((resolve) => setTimeout(resolve, 200)); await nameChatSession(currChatSessionId, currMessage); + refreshChatSessions(); } // NOTE: don't switch pages if the user has navigated away from the chat @@ -1415,6 +1498,7 @@ export function ChatPage({ // Used to maintain a "time out" for history sidebar so our existing refs can have time to process change const [untoggled, setUntoggled] = useState(false); + const [loadingError, setLoadingError] = useState(null); const explicitlyUntoggle = () => { setShowDocSidebar(false); @@ -1457,9 +1541,129 @@ export function ChatPage({ scrollDist, endDivRef, distance, - debounce, + debounceNumber, + }); + + // Virtualization + Scrolling related effects and functions + const scrollInitialized = useRef(false); + interface VisibleRange { + start: number; + end: number; + mostVisibleMessageId: number | null; + } + + const [visibleRange, setVisibleRange] = useState< + Map + >(() => { + const initialRange: VisibleRange = { + start: 0, + end: BUFFER_COUNT, + mostVisibleMessageId: null, + }; + return new Map([[chatSessionIdRef.current, initialRange]]); }); + // Function used to update current visible range. Only method for updating `visibleRange` state. + const updateCurrentVisibleRange = ( + newRange: VisibleRange, + forceUpdate?: boolean + ) => { + if ( + scrollInitialized.current && + visibleRange.get(loadedIdSessionRef.current) == undefined && + !forceUpdate + ) { + return; + } + + setVisibleRange((prevState) => { + const newState = new Map(prevState); + newState.set(loadedIdSessionRef.current, newRange); + return newState; + }); + }; + + // Set first value for visibleRange state on page load / refresh. 
+ const initializeVisibleRange = () => { + const upToDatemessageHistory = buildLatestMessageChain( + currentMessageMap(completeMessageDetail) + ); + + if (!scrollInitialized.current && upToDatemessageHistory.length > 0) { + const newEnd = Math.max(upToDatemessageHistory.length, BUFFER_COUNT); + const newStart = Math.max(0, newEnd - BUFFER_COUNT); + const newMostVisibleMessageId = + upToDatemessageHistory[newEnd - 1]?.messageId; + + updateCurrentVisibleRange( + { + start: newStart, + end: newEnd, + mostVisibleMessageId: newMostVisibleMessageId, + }, + true + ); + scrollInitialized.current = true; + } + }; + + const updateVisibleRangeBasedOnScroll = () => { + if (!scrollInitialized.current) return; + const scrollableDiv = scrollableDivRef.current; + if (!scrollableDiv) return; + + const viewportHeight = scrollableDiv.clientHeight; + let mostVisibleMessageIndex = -1; + + messageHistory.forEach((message, index) => { + const messageElement = document.getElementById( + `message-${message.messageId}` + ); + if (messageElement) { + const rect = messageElement.getBoundingClientRect(); + const isVisible = rect.bottom <= viewportHeight && rect.bottom > 0; + if (isVisible && index > mostVisibleMessageIndex) { + mostVisibleMessageIndex = index; + } + } + }); + + if (mostVisibleMessageIndex !== -1) { + const startIndex = Math.max(0, mostVisibleMessageIndex - BUFFER_COUNT); + const endIndex = Math.min( + messageHistory.length, + mostVisibleMessageIndex + BUFFER_COUNT + 1 + ); + + updateCurrentVisibleRange({ + start: startIndex, + end: endIndex, + mostVisibleMessageId: messageHistory[mostVisibleMessageIndex].messageId, + }); + } + }; + + useEffect(() => { + initializeVisibleRange(); + }, [router, messageHistory, chatSessionIdRef.current]); + + useLayoutEffect(() => { + const handleScroll = () => { + updateVisibleRangeBasedOnScroll(); + }; + scrollableDivRef.current?.addEventListener("scroll", handleScroll); + + return () => { + scrollableDivRef.current?.removeEventListener("scroll", handleScroll); + }; + }, [messageHistory]); + + const currentVisibleRange = visibleRange.get(currentSessionId()) || { + start: 0, + end: 0, + mostVisibleMessageId: null, + }; + useEffect(() => { const includes = checkAnyAssistantHasSearch( messageHistory, @@ -1518,7 +1722,6 @@ export function ChatPage({ setDocumentSelection((documentSelection) => !documentSelection); setShowDocSidebar(false); }; - const secondsUntilExpiration = getSecondsUntilExpiration(user); interface RegenerationRequest { messageId: number; @@ -1538,7 +1741,12 @@ export function ChatPage({ return ( <> - + + + {showApiKeyModal && !shouldShowWelcomeModal && ( + setShowApiKeyModal(false)} /> + )} + {/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit. Only used in the EE version of the app. */} {popup} @@ -1580,10 +1788,14 @@ export function ChatPage({ if (response.ok) { setDeletingChatSession(null); // go back to the main page - router.push("/chat"); + if (deletingChatSession.id === chatSessionIdRef.current) { + router.push("/chat"); + } } else { - alert("Failed to delete chat session"); + const responseJson = await response.json(); + setPopup({ message: responseJson.detail, type: "error" }); } + router.refresh(); }} /> )} @@ -1680,7 +1892,9 @@ export function ChatPage({ /> )} - {documentSidebarInitialWidth !== undefined && isReady ? ( + {documentSidebarInitialWidth !== undefined && + isReady && + !isLoadingUser ? ( {({ getRootProps }) => (
@@ -1705,7 +1919,6 @@ export function ChatPage({ className={`h-full w-full relative flex-auto transition-margin duration-300 overflow-x-auto mobile:pb-12 desktop:pb-[100px]`} {...getRootProps()} > - {/* */}
- {messageHistory.map((message, i) => { + {(messageHistory.length < BUFFER_COUNT + ? messageHistory + : messageHistory.slice( + currentVisibleRange.start, + currentVisibleRange.end + ) + ).map((message, fauxIndex) => { + const i = + messageHistory.length < BUFFER_COUNT + ? fauxIndex + : fauxIndex + currentVisibleRange.start; + const messageMap = currentMessageMap( completeMessageDetail ); @@ -1736,17 +1965,21 @@ export function ChatPage({ const parentMessage = message.parentMessageId ? messageMap.get(message.parentMessageId) : null; - if ( - currentSessionRegenerationState?.regenerating && - message.messageId >= - currentSessionRegenerationState?.finalMessageIndex! - ) { - return <>; - } - if (message.type === "user") { + if ( + (currentSessionChatState == "loading" && + i == messageHistory.length - 1) || + (currentSessionRegenerationState?.regenerating && + message.messageId >= + currentSessionRegenerationState?.finalMessageIndex!) + ) { + return <>; + } return ( -
+
messageHistory.length - 1) || + (currentSessionRegenerationState?.regenerating && + message.messageId > + currentSessionRegenerationState?.finalMessageIndex!) ) { return <>; } return (
- )} + {(currentSessionChatState == "loading" || + (loadingError && + !currentSessionRegenerationState?.regenerating && + messageHistory[messageHistory.length - 1] + ?.type != "user")) && ( + + )} {currentSessionChatState == "loading" && (
)} + {loadingError && ( +
+ + {loadingError} +

+ } + /> +
+ )} {currentPersona && currentPersona.starter_messages && currentPersona.starter_messages.length > 0 && @@ -2090,6 +2347,7 @@ export function ChatPage({ )}
)} + {/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
@@ -2111,6 +2369,9 @@ export function ChatPage({
)} + setShowApiKeyModal(true) + } chatState={currentSessionChatState} stopGenerating={stopGenerating} openModelSettings={() => setSettingsToggled(true)} @@ -2141,7 +2402,6 @@ export function ChatPage({
{ setCompletedFlow( @@ -20,16 +20,26 @@ export function ChatPopup() { }); const settings = useContext(SettingsContext); - if (!settings?.enterpriseSettings?.custom_popup_content || completedFlow) { + const enterpriseSettings = settings?.enterpriseSettings; + const isConsentScreen = enterpriseSettings?.enable_consent_screen; + if ( + (!enterpriseSettings?.custom_popup_content && !isConsentScreen) || + completedFlow + ) { return null; } - let popupTitle = settings.enterpriseSettings.custom_popup_header; - if (!popupTitle) { - popupTitle = `Welcome to ${ - settings.enterpriseSettings.application_name || "Danswer" - }!`; - } + const popupTitle = + enterpriseSettings?.custom_popup_header || + (isConsentScreen + ? "Terms of Use" + : `Welcome to ${enterpriseSettings?.application_name || "Danswer"}!`); + + const popupContent = + enterpriseSettings?.custom_popup_content || + (isConsentScreen + ? "By clicking 'I Agree', you acknowledge that you agree to the terms of use of this application and consent to proceed." + : ""); return ( @@ -49,12 +59,26 @@ export function ChatPopup() { }} remarkPlugins={[remarkGfm]} > - {settings.enterpriseSettings.custom_popup_content} + {popupContent} -
+ {showConsentError && ( +

+ You need to agree to the terms to access the application. +

+ )} + +
+ {isConsentScreen && ( + + )}
diff --git a/web/src/app/chat/WrappedChat.tsx b/web/src/app/chat/WrappedChat.tsx index cdb8508dfb0..6b48e442175 100644 --- a/web/src/app/chat/WrappedChat.tsx +++ b/web/src/app/chat/WrappedChat.tsx @@ -3,21 +3,15 @@ import { ChatPage } from "./ChatPage"; import FunctionalWrapper from "./shared_chat_search/FunctionalWrapper"; export default function WrappedChat({ - defaultAssistantId, initiallyToggled, }: { - defaultAssistantId?: number; initiallyToggled: boolean; }) { return ( ( - + )} /> ); diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index b579abefeed..6ea2ce868a5 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -33,12 +33,14 @@ import { Tooltip } from "@/components/tooltip/Tooltip"; import { Hoverable } from "@/components/Hoverable"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import { ChatState } from "../types"; +import UnconfiguredProviderText from "@/components/chat_search/UnconfiguredProviderText"; const MAX_INPUT_HEIGHT = 200; export function ChatInputBar({ openModelSettings, showDocs, + showConfigureAPIKey, selectedDocuments, message, setMessage, @@ -62,6 +64,7 @@ export function ChatInputBar({ chatSessionId, inputPrompts, }: { + showConfigureAPIKey: () => void; openModelSettings: () => void; chatState: ChatState; stopGenerating: () => void; @@ -111,6 +114,7 @@ export function ChatInputBar({ } } }; + const settings = useContext(SettingsContext); const { llmProviders } = useChatContext(); @@ -338,10 +342,10 @@ export function ChatInputBar({ updateInputPrompt(currentPrompt); }} > -

{currentPrompt.prompt}

-

+

{currentPrompt.prompt}:

+

{currentPrompt.id == selectedAssistant.id && "(default) "} - {currentPrompt.content} + {currentPrompt.content?.trim()}

))} @@ -364,6 +368,9 @@ export function ChatInputBar({
+ + +
; scrollDist: MutableRefObject; endDivRef: RefObject; distance: number; - debounce: number; + debounceNumber: number; mobile?: boolean; }) { const preventScrollInterference = useRef(false); @@ -709,7 +711,7 @@ export async function useScrollonStream({ setTimeout(() => { blockActionRef.current = false; - }, debounce); + }, debounceNumber); } } } diff --git a/web/src/app/chat/message/CodeBlock.tsx b/web/src/app/chat/message/CodeBlock.tsx index 7da83195b43..29f141bcc3f 100644 --- a/web/src/app/chat/message/CodeBlock.tsx +++ b/web/src/app/chat/message/CodeBlock.tsx @@ -50,6 +50,12 @@ export function CodeBlock({ ); codeText = codeText.trim(); + // Find the last occurrence of closing backticks + const lastBackticksIndex = codeText.lastIndexOf("```"); + if (lastBackticksIndex !== -1) { + codeText = codeText.slice(0, lastBackticksIndex + 3); + } + // Remove the language declaration and trailing backticks const codeLines = codeText.split("\n"); if ( diff --git a/web/src/app/chat/message/ContinueMessage.tsx b/web/src/app/chat/message/ContinueMessage.tsx new file mode 100644 index 00000000000..097b3e57e33 --- /dev/null +++ b/web/src/app/chat/message/ContinueMessage.tsx @@ -0,0 +1,37 @@ +import { EmphasizedClickable } from "@/components/BasicClickable"; +import { useEffect, useState } from "react"; +import { FiBook, FiPlayCircle } from "react-icons/fi"; + +export function ContinueGenerating({ + handleContinueGenerating, +}: { + handleContinueGenerating: () => void; +}) { + const [showExplanation, setShowExplanation] = useState(false); + + useEffect(() => { + const timer = setTimeout(() => { + setShowExplanation(true); + }, 1000); + + return () => clearTimeout(timer); + }, []); + + return ( +
+
+ + <> + + Continue Generation + + + {showExplanation && ( +
+ The LLM reached its token limit. Click to continue. +
+ )} +
+
+ ); +} diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 09cacd1b9f1..03042304b5f 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -64,7 +64,7 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; import GeneratingImageDisplay from "../tools/GeneratingImageDisplay"; import RegenerateOption from "../RegenerateOption"; import { LlmOverride } from "@/lib/hooks"; -import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; +import { ContinueGenerating } from "./ContinueMessage"; const TOOLS_WITH_CUSTOM_HANDLING = [ SEARCH_TOOL_NAME, @@ -123,6 +123,7 @@ function FileDisplay({ export const AIMessage = ({ regenerate, overriddenModel, + continueGenerating, shared, isActive, toggleDocumentSelection, @@ -150,6 +151,7 @@ export const AIMessage = ({ }: { shared?: boolean; isActive?: boolean; + continueGenerating?: () => void; otherMessagesCanSwitchTo?: number[]; onMessageSelection?: (messageId: number) => void; selectedDocuments?: DanswerDocument[] | null; @@ -283,11 +285,12 @@ export const AIMessage = ({ size="small" assistant={alternativeAssistant || currentPersona} /> +
- {(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && ( + {!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME ? ( <> {query !== undefined && handleShowRetrieved !== undefined && @@ -315,7 +318,8 @@ export const AIMessage = ({
)} - )} + ) : null} + {toolCall && !TOOLS_WITH_CUSTOM_HANDLING.includes( toolCall.tool_name @@ -358,7 +362,7 @@ export const AIMessage = ({ {typeof content === "string" ? ( -
+
+ {(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && + !query && + continueGenerating && ( + + )}
); @@ -706,6 +715,7 @@ export const HumanMessage = ({ // Move the cursor to the end of the text textareaRef.current.selectionStart = textareaRef.current.value.length; textareaRef.current.selectionEnd = textareaRef.current.value.length; + textareaRef.current.style.height = `${textareaRef.current.scrollHeight}px`; } }, [isEditing]); @@ -731,6 +741,7 @@ export const HumanMessage = ({
+
{isEditing ? ( @@ -777,6 +788,7 @@ export const HumanMessage = ({ style={{ scrollbarWidth: "thin" }} onChange={(e) => { setEditedContent(e.target.value); + textareaRef.current!.style.height = "auto"; e.target.style.height = `${e.target.scrollHeight}px`; }} onKeyDown={(e) => { @@ -871,7 +883,7 @@ export const HumanMessage = ({ !isEditing && (!files || files.length === 0) ) && "ml-auto" - } relative flex-none max-w-[70%] mb-auto whitespace-break-spaces rounded-3xl bg-user px-5 py-2.5`} + } relative flex-none max-w-[70%] mb-auto whitespace-break-spaces rounded-3xl bg-user px-5 py-2.5`} > {content}
diff --git a/web/src/app/chat/message/SkippedSearch.tsx b/web/src/app/chat/message/SkippedSearch.tsx index 62c47b7d96f..b339ac784ab 100644 --- a/web/src/app/chat/message/SkippedSearch.tsx +++ b/web/src/app/chat/message/SkippedSearch.tsx @@ -27,7 +27,7 @@ export function SkippedSearch({ handleForceSearch: () => void; }) { return ( -
+
diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index e391b79dae7..870ad963fca 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -2,10 +2,10 @@ import { redirect } from "next/navigation"; import { unstable_noStore as noStore } from "next/cache"; import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; import { WelcomeModal } from "@/components/initialSetup/welcome/WelcomeModalWrapper"; -import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; import { ChatProvider } from "@/components/context/ChatContext"; import { fetchChatData } from "@/lib/chat/fetchChatData"; import WrappedChat from "./WrappedChat"; +import { ProviderContextProvider } from "@/components/chat_search/ProviderContext"; export default async function Page({ searchParams, @@ -23,7 +23,6 @@ export default async function Page({ const { user, chatSessions, - ccPairs, availableSources, documentSets, assistants, @@ -33,9 +32,7 @@ export default async function Page({ toggleSidebar, openedFolders, defaultAssistantId, - finalDocumentSidebarInitialWidth, shouldShowWelcomeModal, - shouldDisplaySourcesIncompleteModal, userInputPrompts, } = data; @@ -43,9 +40,7 @@ export default async function Page({ <> {shouldShowWelcomeModal && } - {!shouldShowWelcomeModal && !shouldDisplaySourcesIncompleteModal && ( - - )} + - + + + ); diff --git a/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx b/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx index df7ddee957f..35256ada98a 100644 --- a/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx +++ b/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx @@ -46,6 +46,7 @@ export function ChatSessionDisplay({ showDeleteModal?: (chatSession: ChatSession) => void; }) { const router = useRouter(); + const [isHovering, setIsHovering] = useState(false); const [isRenamingChat, setIsRenamingChat] = useState(false); const [isMoreOptionsDropdownOpen, setIsMoreOptionsDropdownOpen] = useState(false); @@ -97,6 +98,11 @@ export function ChatSessionDisplay({ setIsHovering(true)} + onMouseLeave={() => { + setIsMoreOptionsDropdownOpen(false); + setIsHovering(false); + }} onClick={() => { if (settings?.isMobile && closeSidebar) { closeSidebar(); @@ -145,7 +151,7 @@ export function ChatSessionDisplay({

)} - {isSelected && + {isHovering && (isRenamingChat ? (
-
{ - setIsMoreOptionsDropdownOpen( - !isMoreOptionsDropdownOpen - ); - }} - className={"-my-1"} - > - - setIsMoreOptionsDropdownOpen(open) - } - content={ -
- -
- } - popover={ -
- {showShareModal && ( - showShareModal(chatSession)} - /> - )} - setIsRenamingChat(true)} - /> -
- } - requiresContentPadding - sideOffset={6} - triggerMaxWidth - /> -
+ {search ? ( + showDeleteModal && ( +
{ + e.preventDefault(); + showDeleteModal(chatSession); + }} + className={`p-1 -m-1 rounded ml-1`} + > + +
+ ) + ) : ( +
{ + e.preventDefault(); + // e.stopPropagation(); + setIsMoreOptionsDropdownOpen( + !isMoreOptionsDropdownOpen + ); + }} + className="-my-1" + > + + setIsMoreOptionsDropdownOpen(open) + } + content={ +
+ +
+ } + popover={ +
+ {showShareModal && ( + showShareModal(chatSession)} + /> + )} + {!search && ( + setIsRenamingChat(true)} + /> + )} + {showDeleteModal && ( + + showDeleteModal(chatSession) + } + /> + )} +
+ } + requiresContentPadding + sideOffset={6} + triggerMaxWidth + /> +
+ )}
- {showDeleteModal && ( -
showDeleteModal(chatSession)} - className={`hover:bg-black/10 p-1 -m-1 rounded ml-1`} - > - -
- )}
))}
diff --git a/web/src/app/chat/shared_chat_search/FixedLogo.tsx b/web/src/app/chat/shared_chat_search/FixedLogo.tsx index e652d819861..ac5b9afcc3d 100644 --- a/web/src/app/chat/shared_chat_search/FixedLogo.tsx +++ b/web/src/app/chat/shared_chat_search/FixedLogo.tsx @@ -21,11 +21,11 @@ export default function FixedLogo() { } className="fixed cursor-pointer flex z-40 left-2.5 top-2" > -
+
-
+
{enterpriseSettings && enterpriseSettings.application_name ? (
{enterpriseSettings.application_name} diff --git a/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx b/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx index 4ef22ef4e60..4f8d31d39ee 100644 --- a/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx +++ b/web/src/app/chat/shared_chat_search/FunctionalWrapper.tsx @@ -53,10 +53,15 @@ const ToggleSwitch = () => { onClick={() => handleTabChange("search")} > -

+

Search - {commandSymbol}S -

+
+ + {commandSymbol} + + S +
+
); @@ -122,6 +132,8 @@ export default function FunctionalWrapper({ const settings = combinedSettings?.settings; const chatBannerPresent = combinedSettings?.enterpriseSettings?.custom_header_content; + const twoLines = + combinedSettings?.enterpriseSettings?.two_lines_for_chat_header; const [toggledSidebar, setToggledSidebar] = useState(initiallyToggled); @@ -136,7 +148,7 @@ export default function FunctionalWrapper({ {(!settings || (settings.search_page_enabled && settings.chat_page_enabled)) && (
+

{isUpdate ? "Update a User Group" : "Create a new User Group"} diff --git a/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx b/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx index 905107dc39b..954d0ab8e49 100644 --- a/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx +++ b/web/src/app/ee/admin/whitelabeling/WhitelabelingForm.tsx @@ -7,6 +7,7 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; import { Form, Formik } from "formik"; import * as Yup from "yup"; import { + BooleanFormField, Label, SubLabel, TextFormField, @@ -55,22 +56,28 @@ export function WhitelabelingForm() { application_name: enterpriseSettings?.application_name || null, use_custom_logo: enterpriseSettings?.use_custom_logo || false, use_custom_logotype: enterpriseSettings?.use_custom_logotype || false, - + two_lines_for_chat_header: + enterpriseSettings?.two_lines_for_chat_header || false, custom_header_content: enterpriseSettings?.custom_header_content || "", custom_popup_header: enterpriseSettings?.custom_popup_header || "", custom_popup_content: enterpriseSettings?.custom_popup_content || "", custom_lower_disclaimer_content: enterpriseSettings?.custom_lower_disclaimer_content || "", + custom_nav_items: enterpriseSettings?.custom_nav_items || [], + enable_consent_screen: + enterpriseSettings?.enable_consent_screen || false, }} validationSchema={Yup.object().shape({ application_name: Yup.string().nullable(), use_custom_logo: Yup.boolean().required(), use_custom_logotype: Yup.boolean().required(), custom_header_content: Yup.string().nullable(), + two_lines_for_chat_header: Yup.boolean().nullable(), custom_popup_header: Yup.string().nullable(), custom_popup_content: Yup.string().nullable(), custom_lower_disclaimer_content: Yup.string().nullable(), + enable_consent_screen: Yup.boolean().nullable(), })} onSubmit={async (values, formikHelpers) => { formikHelpers.setSubmitting(true); @@ -204,28 +211,62 @@ export function WhitelabelingForm() { disabled={isSubmitting} /> + + + + li > p, +ul > li > p { + margin-top: 0; + margin-bottom: 0; + display: inline; + /* Make paragraphs inline to reduce vertical space */ +} diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index 7b367662df0..ff060ff72d1 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -3,21 +3,19 @@ import "./globals.css"; import { fetchEnterpriseSettingsSS, fetchSettingsSS, - SettingsError, } from "@/components/settings/lib"; import { CUSTOM_ANALYTICS_ENABLED, + EE_ENABLED, SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED, } from "@/lib/constants"; import { SettingsProvider } from "@/components/settings/SettingsProvider"; import { Metadata } from "next"; -import { buildClientUrl } from "@/lib/utilsSS"; +import { buildClientUrl, fetchSS } from "@/lib/utilsSS"; import { Inter } from "next/font/google"; import Head from "next/head"; import { EnterpriseSettings } from "./admin/settings/interfaces"; -import { redirect } from "next/navigation"; -import { Button, Card } from "@tremor/react"; -import LogoType from "@/components/header/LogoType"; +import { Card } from "@tremor/react"; import { HeaderTitle } from "@/components/header/HeaderTitle"; import { Logo } from "@/components/Logo"; import { UserProvider } from "@/components/user/UserProvider"; @@ -56,6 +54,7 @@ export default async function RootLayout({ children: React.ReactNode; }) { const combinedSettings = await fetchSettingsSS(); + if (!combinedSettings) { // Just display a simple full page error if fetching fails. 
@@ -76,9 +75,35 @@ export default async function RootLayout({

Error

- Your instance was not configured properly and your - settings could not be loaded. Please contact your admin to fix - this error. + Your EVE AI instance was not configured properly and your + settings could not be loaded. This could be due to an admin + configuration issue or an incomplete setup. +

+

+ If you're an admin, please check{" "} + + our docs + {" "} + to see how to configure EVE AI properly. If you're a user, + please contact your admin to fix this error. +

+

+ For additional support and guidance, you can reach out to our + community on{" "} + + Slack + + .

@@ -109,7 +134,7 @@ export default async function RootLayout({
( )} diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx index 6f6cef8c4f0..e317d271b0d 100644 --- a/web/src/app/search/page.tsx +++ b/web/src/app/search/page.tsx @@ -3,10 +3,8 @@ import { getAuthTypeMetadataSS, getCurrentUserSS, } from "@/lib/userSS"; -import { getSecondsUntilExpiration } from "@/lib/time"; import { redirect } from "next/navigation"; import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; import { fetchSS } from "@/lib/utilsSS"; import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types"; import { cookies } from "next/headers"; @@ -35,6 +33,8 @@ import { DISABLE_LLM_DOC_RELEVANCE, } from "@/lib/constants"; import WrappedSearch from "./WrappedSearch"; +import { SearchProvider } from "@/components/context/SearchContext"; +import { ProviderContextProvider } from "@/components/chat_search/ProviderContext"; export default async function Home() { // Disable caching so we always get the up to date connector / document set / persona info @@ -179,18 +179,13 @@ export default async function Home() { const agenticSearchEnabled = agenticSearchToggle ? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false : false; - const secondsUntilExpiration = getSecondsUntilExpiration(user); return ( <> - + {shouldShowWelcomeModal && } - {!shouldShowWelcomeModal && - !shouldDisplayNoSourcesModal && - !shouldDisplaySourcesIncompleteModal && } - {shouldDisplayNoSourcesModal && } {shouldDisplaySourcesIncompleteModal && ( @@ -201,18 +196,27 @@ export default async function Home() { Only used in the EE version of the app. */} - + + + + + ); } diff --git a/web/src/components/IsPublicGroupSelector.tsx b/web/src/components/IsPublicGroupSelector.tsx index 63d47e506a6..6c7aaa17097 100644 --- a/web/src/components/IsPublicGroupSelector.tsx +++ b/web/src/components/IsPublicGroupSelector.tsx @@ -1,3 +1,4 @@ +import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; import React, { useState, useEffect } from "react"; import { FormikProps, FieldArray, ArrayHelpers, ErrorMessage } from "formik"; import { Text, Divider } from "@tremor/react"; @@ -12,21 +13,28 @@ export type IsPublicGroupSelectorFormType = { groups: number[]; }; +// This should be included for all forms that require groups / public access +// to be set, and access to this / permissioning should be handled within this component itself. export const IsPublicGroupSelector = ({ formikProps, objectName, + publicToWhom = "Users", + removeIndent = false, enforceGroupSelection = true, }: { formikProps: FormikProps; objectName: string; + publicToWhom?: string; + removeIndent?: boolean; enforceGroupSelection?: boolean; }) => { const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); - const { isAdmin, user, isLoadingUser } = useUser(); + const { isAdmin, user, isLoadingUser, isCurator } = useUser(); + const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled(); const [shouldHideContent, setShouldHideContent] = useState(false); useEffect(() => { - if (user && userGroups) { + if (user && userGroups && isPaidEnterpriseFeaturesEnabled) { const isUserAdmin = user.role === UserRole.ADMIN; if (!isUserAdmin) { formikProps.setFieldValue("is_public", false); @@ -51,6 +59,9 @@ export const IsPublicGroupSelector = ({ if (isLoadingUser || userGroupsIsLoading) { return
Loading...
; } + if (!isPaidEnterpriseFeaturesEnabled) { + return null; + } if (shouldHideContent && enforceGroupSelection) { return ( @@ -72,57 +83,68 @@ export const IsPublicGroupSelector = ({ <> - If set, then this {objectName} will be visible to{" "} - all users. If turned off, then only users who explicitly - have been given access to this {objectName} (e.g. through a User - Group) will have access. + If set, then this {objectName} will be usable by{" "} + All {publicToWhom}. Otherwise, only Admins and{" "} + {publicToWhom} who have explicitly been given access to + this {objectName} (e.g. via a User Group) will have access. } /> )} - {(!formikProps.values.is_public || - !isAdmin || - formikProps.values.groups.length > 0) && ( - <> -
-
- Assign group access for this {objectName} + {(!formikProps.values.is_public || isCurator) && + userGroups && + userGroups?.length > 0 && ( + <> +
+
+ Assign group access for this {objectName} +
-
- - {isAdmin || !enforceGroupSelection ? ( - <> - This {objectName} will be visible/accessible by the groups - selected below - + {userGroupsIsLoading ? ( +
) : ( - <> - Curators must select one or more groups to give access to this{" "} - {objectName} - - )} -
- ( -
- {userGroupsIsLoading ? ( -
+ + {isAdmin || !enforceGroupSelection ? ( + <> + This {objectName} will be visible/accessible by the groups + selected below + ) : ( - userGroups && - userGroups.map((userGroup: UserGroup) => { - const ind = formikProps.values.groups.indexOf(userGroup.id); - let isSelected = ind !== -1; - return ( -
+ Curators must select one or more groups to give access to + this {objectName} + + )} + + )} + ( +
+ {userGroupsIsLoading ? ( +
+ ) : ( + userGroups && + userGroups.map((userGroup: UserGroup) => { + const ind = formikProps.values.groups.indexOf( + userGroup.id + ); + let isSelected = ind !== -1; + return ( +
({ cursor-pointer ${isSelected ? "bg-background-strong" : "hover:bg-hover"} `} - onClick={() => { - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push(userGroup.id); - } - }} - > -
- {userGroup.name} + onClick={() => { + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push(userGroup.id); + } + }} + > +
+ {" "} + {userGroup.name} +
-
- ); - }) - )} -
- )} - /> - - - )} + ); + }) + )} +
+ )} + /> + + + )}
); }; diff --git a/web/src/components/UserDropdown.tsx b/web/src/components/UserDropdown.tsx index 2ca71c577c2..3bc18b5241b 100644 --- a/web/src/components/UserDropdown.tsx +++ b/web/src/components/UserDropdown.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useRef, useContext } from "react"; +import { useState, useRef, useContext, useEffect } from "react"; import { FiLogOut } from "react-icons/fi"; import Link from "next/link"; import { useRouter } from "next/navigation"; @@ -15,6 +15,35 @@ import { UsersIcon, } from "./icons/icons"; import { pageType } from "@/app/chat/sessionSidebar/types"; +import { NavigationItem } from "@/app/admin/settings/interfaces"; +import DynamicFaIcon, { preloadIcons } from "./icons/DynamicFaIcon"; + +interface DropdownOptionProps { + href?: string; + onClick?: () => void; + icon: React.ReactNode; + label: string; +} + +const DropdownOption: React.FC = ({ + href, + onClick, + icon, + label, +}) => { + const content = ( +
+ {icon} + {label} +
+ ); + + return href ? ( + {content} + ) : ( +
{content}
+ ); +}; export function UserDropdown({ user, @@ -28,10 +57,17 @@ export function UserDropdown({ const router = useRouter(); const combinedSettings = useContext(SettingsContext); + const customNavItems: NavigationItem[] = + combinedSettings?.enterpriseSettings?.custom_nav_items || []; + + useEffect(() => { + const iconNames = customNavItems.map((item) => item.icon); + preloadIcons(iconNames); + }, [customNavItems]); + if (!combinedSettings) { return null; } - const settings = combinedSettings.settings; const handleLogout = () => { logout().then((isSuccess) => { @@ -100,44 +136,49 @@ export function UserDropdown({ overscroll-contain `} > - {showAdminPanel && ( - <> - - - Admin Panel - - - )} - {showCuratorPanel && ( - <> - ( + + } + label={item.title} + /> + ))} + + {showAdminPanel ? ( + } + label="Admin Panel" + /> + ) : ( + showCuratorPanel && ( + - - Curator Panel - - + icon={} + label="Curator Panel" + /> + ) )} + {showLogout && + (showCuratorPanel || + showAdminPanel || + customNavItems.length > 0) && ( +
+ )} + {showLogout && ( - <> - {(!(page == "search" || page == "chat") || showAdminPanel) && ( -
- )} -
- - Log out -
- + } + label="Log out" + /> )}
} diff --git a/web/src/components/admin/connectors/AdminSidebar.tsx b/web/src/components/admin/connectors/AdminSidebar.tsx index a319dd0d480..77b00179c5b 100644 --- a/web/src/components/admin/connectors/AdminSidebar.tsx +++ b/web/src/components/admin/connectors/AdminSidebar.tsx @@ -48,13 +48,13 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) { : "/search" } > -
+
-
+
{enterpriseSettings && enterpriseSettings.application_name ? ( -
+
{enterpriseSettings.application_name} @@ -76,9 +76,9 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
- -
- )}
{explanationText && ( @@ -396,17 +393,27 @@ export const BooleanFormField = ({ alignTop, checked, }: BooleanFormFieldProps) => { + const [field, meta, helpers] = useField(name); + const { setValue } = helpers; + + const handleChange = (e: React.ChangeEvent) => { + setValue(e.target.checked); + if (onChange) { + onChange(e); + } + }; + return (