Skip to content

[Platform] Use JSON Path to convert responses #136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"symfony/event-dispatcher": "^6.4|^7.0",
"symfony/filesystem": "^6.4|^7.0",
"symfony/finder": "^6.4|^7.0",
"symfony/json-path": "7.3.*",
"symfony/process": "^6.4|^7.0",
"symfony/var-dumper": "^6.4|^7.0"
},
Expand Down
1 change: 1 addition & 0 deletions src/platform/composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"psr/log": "^3.0",
"symfony/clock": "^6.4 || ^7.1",
"symfony/http-client": "^6.4 || ^7.1",
"symfony/json-path": "7.3.*",
"symfony/property-access": "^6.4 || ^7.1",
"symfony/property-info": "^6.4 || ^7.1",
"symfony/serializer": "^6.4 || ^7.1",
Expand Down
3 changes: 1 addition & 2 deletions src/platform/src/Bridge/Albert/PlatformFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

namespace Symfony\AI\Platform\Bridge\Albert;

use Symfony\AI\Platform\Bridge\OpenAI\Embeddings;
use Symfony\AI\Platform\Bridge\OpenAI\GPT;
use Symfony\AI\Platform\Contract;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
Expand Down Expand Up @@ -40,7 +39,7 @@ public static function create(
new GPTModelClient($httpClient, $apiKey, $baseUrl),
new EmbeddingsModelClient($httpClient, $apiKey, $baseUrl),
],
[new GPT\ResultConverter(), new Embeddings\ResultConverter()],
[new GPT\ResultConverter()],
Contract::create(),
);
}
Expand Down
2 changes: 1 addition & 1 deletion src/platform/src/Bridge/Azure/OpenAI/PlatformFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public static function create(

return new Platform(
[$GPTModelClient, $embeddingsModelClient, $whisperModelClient],
[new GPT\ResultConverter(), new Embeddings\ResultConverter(), new Whisper\ResultConverter()],
[new GPT\ResultConverter(), new Whisper\ResultConverter()],
$contract ?? Contract::create(new AudioNormalizer()),
);
}
Expand Down
2 changes: 1 addition & 1 deletion src/platform/src/Bridge/Google/Embeddings.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ class Embeddings extends Model
*/
public function __construct(string $name = self::GEMINI_EMBEDDING_EXP_03_07, array $options = [])
{
parent::__construct($name, [Capability::INPUT_MULTIPLE], $options);
parent::__construct($name, [Capability::INPUT_MULTIPLE, Capability::OUTPUT_VECTOR], $options);
}
}
30 changes: 4 additions & 26 deletions src/platform/src/Bridge/Google/Embeddings/ResultConverter.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,15 @@

namespace Symfony\AI\Platform\Bridge\Google\Embeddings;

use Symfony\AI\Platform\Bridge\Google\Embeddings;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\Result\RawResultInterface;
use Symfony\AI\Platform\Result\VectorResult;
use Symfony\AI\Platform\ResultConverterInterface;
use Symfony\AI\Platform\Vector\Vector;
use Symfony\AI\Platform\Contract\ResultConverter\VectorResultConverter;

/**
* @author Valtteri R <[email protected]>
*/
final readonly class ResultConverter implements ResultConverterInterface
final readonly class ResultConverter extends VectorResultConverter
{
public function supports(Model $model): bool
public function __construct()
{
return $model instanceof Embeddings;
}

public function convert(RawResultInterface $result, array $options = []): VectorResult
{
$data = $result->getData();

if (!isset($data['embeddings'])) {
throw new RuntimeException('Response does not contain data');
}

return new VectorResult(
...array_map(
static fn (array $item): Vector => new Vector($item['values']),
$data['embeddings'],
),
);
parent::__construct('$.embeddings[*].values');
}
}
2 changes: 1 addition & 1 deletion src/platform/src/Bridge/Mistral/Embeddings.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ public function __construct(
string $name = self::MISTRAL_EMBED,
array $options = [],
) {
parent::__construct($name, [Capability::INPUT_MULTIPLE], $options);
parent::__construct($name, [Capability::INPUT_MULTIPLE, Capability::OUTPUT_VECTOR], $options);
}
}
2 changes: 1 addition & 1 deletion src/platform/src/Bridge/Mistral/PlatformFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public static function create(

return new Platform(
[new Embeddings\ModelClient($httpClient, $apiKey), new Llm\ModelClient($httpClient, $apiKey)],
[new Embeddings\ResultConverter(), new Llm\ResultConverter()],
[new Llm\ResultConverter()],
$contract ?? Contract::create(new ToolNormalizer()),
);
}
Expand Down
3 changes: 2 additions & 1 deletion src/platform/src/Bridge/OpenAI/Embeddings.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

namespace Symfony\AI\Platform\Bridge\OpenAI;

use Symfony\AI\Platform\Capability;
use Symfony\AI\Platform\Model;

/**
Expand All @@ -27,6 +28,6 @@ class Embeddings extends Model
*/
public function __construct(string $name = self::TEXT_3_SMALL, array $options = [])
{
parent::__construct($name, [], $options);
parent::__construct($name, [Capability::OUTPUT_VECTOR], $options);
}
}
47 changes: 0 additions & 47 deletions src/platform/src/Bridge/OpenAI/Embeddings/ResultConverter.php

This file was deleted.

1 change: 0 additions & 1 deletion src/platform/src/Bridge/OpenAI/PlatformFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ public static function create(
],
[
new GPT\ResultConverter(),
new Embeddings\ResultConverter(),
new DallE\ResultConverter(),
new WhisperResponseConverter(),
],
Expand Down
2 changes: 1 addition & 1 deletion src/platform/src/Bridge/Voyage/PlatformFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ public static function create(
): Platform {
$httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);

return new Platform([new ModelClient($httpClient, $apiKey)], [new ResultConverter()], $contract);
return new Platform([new ModelClient($httpClient, $apiKey)], [], $contract);
}
}
44 changes: 0 additions & 44 deletions src/platform/src/Bridge/Voyage/ResultConverter.php

This file was deleted.

2 changes: 1 addition & 1 deletion src/platform/src/Bridge/Voyage/Voyage.php
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ class Voyage extends Model
*/
public function __construct(string $name = self::V3, array $options = [])
{
parent::__construct($name, [Capability::INPUT_MULTIPLE], $options);
parent::__construct($name, [Capability::INPUT_MULTIPLE, Capability::OUTPUT_VECTOR], $options);
}
}
1 change: 1 addition & 0 deletions src/platform/src/Capability.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ enum Capability: string
case OUTPUT_STREAMING = 'output-streaming';
case OUTPUT_STRUCTURED = 'output-structured';
case OUTPUT_TEXT = 'output-text';
case OUTPUT_VECTOR = 'output-vector';

// FUNCTIONALITY
case TOOL_CALLING = 'tool-calling';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,44 @@
* file that was distributed with this source code.
*/

namespace Symfony\AI\Platform\Bridge\Mistral\Embeddings;
namespace Symfony\AI\Platform\Contract\ResultConverter;

use Symfony\AI\Platform\Bridge\Mistral\Embeddings;
use Symfony\AI\Platform\Capability;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\Result\RawHttpResult;
use Symfony\AI\Platform\Result\RawResultInterface;
use Symfony\AI\Platform\Result\ResultInterface;
use Symfony\AI\Platform\Result\VectorResult;
use Symfony\AI\Platform\ResultConverterInterface;
use Symfony\AI\Platform\Vector\Vector;
use Symfony\Component\JsonPath\JsonCrawler;
use Symfony\Component\JsonPath\JsonPath;

/**
* @author Christopher Hertel <[email protected]>
*/
final readonly class ResultConverter implements ResultConverterInterface
readonly class VectorResultConverter implements ResultConverterInterface
{
public function __construct(
private string|JsonPath $query = '$.data[*].embedding',
) {
}

public function supports(Model $model): bool
{
return $model instanceof Embeddings;
// TODO: THIS IS NOT ENOUGH TO GET ACTIVATED
return $model->supports(Capability::OUTPUT_VECTOR);
}

public function convert(RawResultInterface|RawHttpResult $result, array $options = []): VectorResult
public function convert(RawResultInterface|RawHttpResult $result, array $options = []): ResultInterface
{
$httpResponse = $result->getObject();

if (200 !== $httpResponse->getStatusCode()) {
throw new RuntimeException(\sprintf('Unexpected response code %d: %s', $httpResponse->getStatusCode(), $httpResponse->getContent(false)));
}

$data = $result->getData();
$crawler = new JsonCrawler($result->getObject()->getContent(false));
$vectors = $crawler->find($this->query);

if (!isset($data['data'])) {
throw new RuntimeException('Response does not contain data');
if (empty($vectors)) {
throw new RuntimeException('Response does not contain any vectors');
}

return new VectorResult(
...array_map(
static fn (array $item): Vector => new Vector($item['embedding']),
$data['data']
),
...array_map(static fn (array $vector): Vector => new Vector($vector), $vectors),
);
}
}
6 changes: 5 additions & 1 deletion src/platform/src/Platform.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

namespace Symfony\AI\Platform;

use Symfony\AI\Platform\Contract\ResultConverter\VectorResultConverter;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Result\RawResultInterface;
use Symfony\AI\Platform\Result\ResultPromise;
Expand Down Expand Up @@ -41,7 +42,10 @@ public function __construct(
) {
$this->contract = $contract ?? Contract::create();
$this->modelClients = $modelClients instanceof \Traversable ? iterator_to_array($modelClients) : $modelClients;
$this->resultConverters = $resultConverters instanceof \Traversable ? iterator_to_array($resultConverters) : $resultConverters;
$this->resultConverters = array_merge(
$resultConverters instanceof \Traversable ? iterator_to_array($resultConverters) : $resultConverters,
[new VectorResultConverter()],
);
}

public function invoke(Model $model, array|string|object $input, array $options = []): ResultPromise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@
#[UsesClass(Vector::class)]
#[UsesClass(VectorResult::class)]
#[UsesClass(Embeddings::class)]
final class ResponseConverterTest extends TestCase
final class ResultConverterTest extends TestCase
{
#[Test]
public function itConvertsAResponseToAVectorResponse(): void
{
$result = $this->createStub(ResponseInterface::class);
$result
->method('toArray')
->willReturn(json_decode($this->getEmbeddingStub(), true));
$response = $this->createStub(ResponseInterface::class);
$response
->method('getContent')
->willReturn($this->getEmbeddingStub());

$vectorResponse = (new ResultConverter())->convert(new RawHttpResult($result));
$vectorResponse = (new ResultConverter())->convert(new RawHttpResult($response));
$convertedContent = $vectorResponse->getContent();

self::assertCount(2, $convertedContent);
Expand Down
Loading
Loading