Skip to content

Commit f1b4368

Browse files
committed
Add REST API endpoints for task management
1 parent 53edf57 commit f1b4368

File tree

13 files changed

+995
-39
lines changed

13 files changed

+995
-39
lines changed

chromadb/api/__init__.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,3 +770,53 @@ def _delete(
770770
database: str = DEFAULT_DATABASE,
771771
) -> None:
772772
pass
773+
774+
@abstractmethod
775+
def create_task(
776+
self,
777+
task_name: str,
778+
operator_id: str,
779+
input_collection_id: UUID,
780+
output_collection_name: str,
781+
params: Optional[str] = None,
782+
tenant: str = DEFAULT_TENANT,
783+
database: str = DEFAULT_DATABASE,
784+
) -> tuple[bool, str]:
785+
"""Create a recurring task on a collection.
786+
787+
Args:
788+
task_name: Unique name for this task instance
789+
operator_id: Built-in operator identifier
790+
input_collection_id: Source collection that triggers the task
791+
output_collection_name: Target collection where task output is stored
792+
params: Optional JSON string with operator-specific parameters
793+
tenant: The tenant name
794+
database: The database name
795+
796+
Returns:
797+
tuple: (success: bool, task_id: str)
798+
"""
799+
pass
800+
801+
@abstractmethod
802+
def remove_task(
803+
self,
804+
task_name: str,
805+
input_collection_id: UUID,
806+
delete_output: bool = False,
807+
tenant: str = DEFAULT_TENANT,
808+
database: str = DEFAULT_DATABASE,
809+
) -> bool:
810+
"""Delete a task and prevent any further runs.
811+
812+
Args:
813+
task_name: Name of the task to remove
814+
input_collection_id: Id of the input collection the task is registered on
815+
delete_output: Whether to also delete the output collection
816+
tenant: The tenant name
817+
database: The database name
818+
819+
Returns:
820+
bool: True if successful
821+
"""
822+
pass

chromadb/api/fastapi.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,3 +695,51 @@ def get_max_batch_size(self) -> int:
695695
pre_flight_checks = self.get_pre_flight_checks()
696696
max_batch_size = cast(int, pre_flight_checks.get("max_batch_size", -1))
697697
return max_batch_size
698+
699+
@trace_method("FastAPI.create_task", OpenTelemetryGranularity.ALL)
700+
@override
701+
def create_task(
702+
self,
703+
task_name: str,
704+
operator_id: str,
705+
input_collection_id: UUID,
706+
output_collection_name: str,
707+
params: Optional[str] = None,
708+
tenant: str = DEFAULT_TENANT,
709+
database: str = DEFAULT_DATABASE,
710+
) -> tuple[bool, str]:
711+
"""Register a recurring task on a collection."""
712+
resp_json = self._make_request(
713+
"post",
714+
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/tasks/create",
715+
json={
716+
"task_name": task_name,
717+
"operator_id": operator_id,
718+
"input_collection_id": str(input_collection_id),
719+
"output_collection_name": output_collection_name,
720+
"params": params,
721+
},
722+
)
723+
return cast(bool, resp_json["success"]), cast(str, resp_json["task_id"])
724+
725+
@trace_method("FastAPI.remove_task", OpenTelemetryGranularity.ALL)
726+
@override
727+
def remove_task(
728+
self,
729+
task_name: str,
730+
input_collection_id: UUID,
731+
delete_output: bool = False,
732+
tenant: str = DEFAULT_TENANT,
733+
database: str = DEFAULT_DATABASE,
734+
) -> bool:
735+
"""Delete a task and prevent any further runs."""
736+
resp_json = self._make_request(
737+
"post",
738+
f"/tenants/{tenant}/databases/{database}/collections/{input_collection_id}/tasks/delete",
739+
json={
740+
"input_collection_id": str(input_collection_id),
741+
"task_name": task_name,
742+
"delete_output": delete_output,
743+
},
744+
)
745+
return cast(bool, resp_json["success"])

chromadb/api/models/Collection.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,29 +327,29 @@ def search(
327327
from chromadb.execution.expression import (
328328
Search, Key, K, Knn, Val
329329
)
330-
330+
331331
# Note: K is an alias for Key, so K.DOCUMENT == Key.DOCUMENT
332332
search = (Search()
333333
.where((K("category") == "science") & (K("score") > 0.5))
334334
.rank(Knn(query=[0.1, 0.2, 0.3]) * 0.8 + Val(0.5) * 0.2)
335335
.limit(10, offset=0)
336336
.select(K.DOCUMENT, K.SCORE, "title"))
337-
337+
338338
# Direct construction
339339
from chromadb.execution.expression import (
340340
Search, Eq, And, Gt, Knn, Limit, Select, Key
341341
)
342-
342+
343343
search = Search(
344344
where=And([Eq("category", "science"), Gt("score", 0.5)]),
345345
rank=Knn(query=[0.1, 0.2, 0.3]),
346346
limit=Limit(offset=0, limit=10),
347347
select=Select(keys={Key.DOCUMENT, Key.SCORE, "title"})
348348
)
349-
349+
350350
# Single search
351351
result = collection.search(search)
352-
352+
353353
# Multiple searches at once
354354
searches = [
355355
Search().where(K("type") == "article").rank(Knn(query=[0.1, 0.2])),
@@ -490,3 +490,64 @@ def delete(
490490
tenant=self.tenant,
491491
database=self.database,
492492
)
493+
494+
def create_task(
495+
self,
496+
task_name: str,
497+
operator_id: str,
498+
output_collection_name: str,
499+
params: Optional[str] = None,
500+
) -> tuple[bool, str]:
501+
"""Create a recurring task that processes this collection.
502+
503+
Args:
504+
task_name: Unique name for this task instance
505+
operator_id: Built-in operator identifier (e.g., "record_counter")
506+
output_collection_name: Name of the collection where task output will be stored
507+
params: Optional JSON string with operator-specific parameters
508+
509+
Returns:
510+
tuple: (success: bool, task_id: str)
511+
512+
Example:
513+
>>> success, task_id = collection.create_task(
514+
... name="count_docs",
515+
... operator_id="record_counter",
516+
... output_collection="doc_counts",
517+
... params=None
518+
... )
519+
"""
520+
return self._client.create_task(
521+
task_name=task_name,
522+
operator_id=operator_id,
523+
input_collection_id=self.id,
524+
output_collection_name=output_collection_name,
525+
params=params,
526+
tenant=self.tenant,
527+
database=self.database,
528+
)
529+
530+
def remove_task(
531+
self,
532+
task_name: str,
533+
delete_output: bool = False,
534+
) -> bool:
535+
"""Delete a task and prevent any further runs.
536+
537+
Args:
538+
task_name: Name of the task to remove
539+
delete_output: Whether to also delete the output collection. Defaults to False.
540+
541+
Returns:
542+
bool: True if successful
543+
544+
Example:
545+
>>> success = collection.remove_task("count_docs", delete_output=True)
546+
"""
547+
return self._client.remove_task(
548+
task_name=task_name,
549+
input_collection_id=self.id,
550+
delete_output=delete_output,
551+
tenant=self.tenant,
552+
database=self.database,
553+
)

chromadb/api/rust.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,7 @@ def _search(
320320
tenant: str = DEFAULT_TENANT,
321321
database: str = DEFAULT_DATABASE,
322322
) -> SearchResult:
323-
raise NotImplementedError(
324-
"Search is not implemented for Local Chroma"
325-
)
323+
raise NotImplementedError("Search is not implemented for Local Chroma")
326324

327325
@override
328326
def _count(
@@ -583,6 +581,38 @@ def get_settings(self) -> Settings:
583581
def get_max_batch_size(self) -> int:
584582
return self.bindings.get_max_batch_size()
585583

584+
@override
585+
def create_task(
586+
self,
587+
task_name: str,
588+
operator_id: str,
589+
input_collection_id: UUID,
590+
output_collection_name: str,
591+
params: Optional[str] = None,
592+
tenant: str = DEFAULT_TENANT,
593+
database: str = DEFAULT_DATABASE,
594+
) -> tuple[bool, str]:
595+
"""Tasks are not supported in the Rust bindings (local embedded mode)."""
596+
raise NotImplementedError(
597+
"Tasks are only supported when connecting to a Chroma server via HttpClient. "
598+
"The Rust bindings (embedded mode) do not support task operations."
599+
)
600+
601+
@override
602+
def remove_task(
603+
self,
604+
task_name: str,
605+
input_collection_id: UUID,
606+
delete_output: bool = False,
607+
tenant: str = DEFAULT_TENANT,
608+
database: str = DEFAULT_DATABASE,
609+
) -> bool:
610+
"""Tasks are not supported in the Rust bindings (local embedded mode)."""
611+
raise NotImplementedError(
612+
"Tasks are only supported when connecting to a Chroma server via HttpClient. "
613+
"The Rust bindings (embedded mode) do not support task operations."
614+
)
615+
586616
# TODO: Remove this if it's not planned to be used
587617
@override
588618
def get_user_identity(self) -> UserIdentity:

chromadb/api/segment.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,7 @@ def _search(
427427
tenant: str = DEFAULT_TENANT,
428428
database: str = DEFAULT_DATABASE,
429429
) -> SearchResult:
430-
raise NotImplementedError(
431-
"Seach is not implemented for SegmentAPI"
432-
)
430+
raise NotImplementedError("Seach is not implemented for SegmentAPI")
433431

434432
@trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
435433
@override
@@ -902,6 +900,38 @@ def get_settings(self) -> Settings:
902900
def get_max_batch_size(self) -> int:
903901
return self._producer.max_batch_size
904902

903+
@override
904+
def create_task(
905+
self,
906+
task_name: str,
907+
operator_id: str,
908+
input_collection_id: UUID,
909+
output_collection_name: str,
910+
params: Optional[str] = None,
911+
tenant: str = DEFAULT_TENANT,
912+
database: str = DEFAULT_DATABASE,
913+
) -> tuple[bool, str]:
914+
"""Tasks are not supported in the Segment API (local embedded mode)."""
915+
raise NotImplementedError(
916+
"Tasks are only supported when connecting to a Chroma server via HttpClient. "
917+
"The Segment API (embedded mode) does not support task operations."
918+
)
919+
920+
@override
921+
def remove_task(
922+
self,
923+
task_name: str,
924+
input_collection_id: UUID,
925+
delete_output: bool = False,
926+
tenant: str = DEFAULT_TENANT,
927+
database: str = DEFAULT_DATABASE,
928+
) -> bool:
929+
"""Tasks are not supported in the Segment API (local embedded mode)."""
930+
raise NotImplementedError(
931+
"Tasks are only supported when connecting to a Chroma server via HttpClient. "
932+
"The Segment API (embedded mode) does not support task operations."
933+
)
934+
905935
# TODO: This could potentially cause race conditions in a distributed version of the
906936
# system, since the cache is only local.
907937
# TODO: promote collection -> topic to a base class method so that it can be

examples/task_api_example.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Example: Using Chroma's Task API to process collections automatically
4+
5+
This demonstrates how to register tasks that automatically process
6+
collections as new records are added.
7+
"""
8+
9+
import chromadb
10+
11+
# Connect to Chroma server
12+
client = chromadb.HttpClient(host="localhost", port=8000)
13+
# ignore error if collection does not exist
14+
try:
15+
client.delete_collection("my_documents_counts")
16+
except Exception:
17+
pass
18+
# Create or get a collection
19+
collection = client.get_or_create_collection(
20+
name="my_document", metadata={"description": "Sample documents for task processing"}
21+
)
22+
23+
# Add some sample documents
24+
collection.add(
25+
ids=["doc1", "doc2", "doc3"],
26+
documents=[
27+
"The quick brown fox jumps over the lazy dog",
28+
"Machine learning is a subset of artificial intelligence",
29+
"Python is a popular programming language",
30+
],
31+
metadatas=[{"source": "proverb"}, {"source": "tech"}, {"source": "tech"}],
32+
)
33+
34+
print(f"✅ Created collection '{collection.name}' with {collection.count()} documents")
35+
36+
# Create a task that counts records in the collection
37+
# The 'record_counter' operator processes each record and outputs {"count": N}
38+
success, task_id = collection.create_task(
39+
task_name="count_my_docs",
40+
operator_id="record_counter", # Built-in operator that counts records
41+
output_collection_name="my_documents_counts", # Auto-created
42+
params=None, # No additional parameters needed
43+
)
44+
45+
if success:
46+
print("✅ Task created successfully!")
47+
print(f" Task ID: {task_id}")
48+
print(" Task name: count_my_docs")
49+
print(f" Input collection: {collection.name}")
50+
print(" Output collection: my_documents_counts")
51+
print(" Operator: record_counter")
52+
else:
53+
print("❌ Failed to create task")
54+
55+
# The task will now run automatically when:
56+
# 1. New documents are added to 'my_documents'
57+
# 2. The number of new records >= min_records_for_task (default: 100)
58+
59+
print("\n" + "=" * 60)
60+
print("Task is now registered and will run on new data!")
61+
print("=" * 60)
62+
63+
# Add more documents to trigger task execution
64+
print("\nAdding more documents...")
65+
collection.add(
66+
ids=["doc4", "doc5"],
67+
documents=["Chroma is a vector database", "Tasks automate data processing"],
68+
)
69+
70+
print(f"Collection now has {collection.count()} documents")
71+
72+
# Later, you can remove the task
73+
print("\n" + "=" * 60)
74+
input("Press Enter to remove the task...")
75+
76+
success = collection.remove_task(
77+
task_name="count_my_docs", delete_output=True # Also delete the output collection
78+
)
79+
80+
if success:
81+
print("✅ Task removed successfully!")
82+
else:
83+
print("❌ Failed to remove task")

0 commit comments

Comments
 (0)