
Commit 1dbc46d

Enhanced API structuring to support common AI API interfaces
1 parent 644837e commit 1dbc46d

File tree

3 files changed (+191 -61 lines)


tensorlink/api/models.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+from pydantic import BaseModel
+from typing import Optional, List, Literal
+
+
+class NodeRequest(BaseModel):
+    address: str
+
+
+class JobRequest(BaseModel):
+    hf_name: str
+    time: int = 1800
+    payment: int = 0
+
+
+class GenerationRequest(BaseModel):
+    hf_name: str
+    message: str
+    prompt: str = None
+    max_length: int = 2048
+    max_new_tokens: int = 2048
+    temperature: float = 0.4
+    do_sample: bool = True
+    num_beams: int = 4
+    history: Optional[List[dict]] = None
+    output: str = None
+    processing: bool = False
+    id: int = None
+    response_format: Literal["simple", "openai", "full"] = "full"
+
+
+class ModelStatusResponse(BaseModel):
+    model_name: str
+    status: str  # "loaded", "loading", "not_loaded", "error"
+    message: str

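The new response_format field is what lets a single request model serve all three interface styles. A minimal sketch of constructing the model directly (the model name is hypothetical; only GenerationRequest and its fields come from this commit):

    from tensorlink.api.models import GenerationRequest

    # "simple" -> bare text, "openai" -> chat.completion shape,
    # "full" (default) -> raw output plus timing and generation params
    request = GenerationRequest(
        hf_name="Qwen/Qwen2.5-7B-Instruct",  # hypothetical model name
        message="Explain pipeline parallelism in one sentence.",
        response_format="openai",
    )
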
tensorlink/api/node.py

Lines changed: 156 additions & 60 deletions
@@ -1,8 +1,12 @@
 from tensorlink.ml.utils import get_popular_model_stats
+from tensorlink.ml.validator import extract_assistant_response
+from tensorlink.api.models import (
+    JobRequest,
+    GenerationRequest,
+    ModelStatusResponse,
+)
 
 from fastapi import FastAPI, HTTPException, APIRouter, Request, Query
-from pydantic import BaseModel
-from typing import Optional, List
 from collections import defaultdict
 from threading import Thread
 import logging
@@ -13,35 +17,73 @@
 import time
 
 
-class NodeRequest(BaseModel):
-    address: str
-
-
-class JobRequest(BaseModel):
-    hf_name: str
-    time: int = 1800
-    payment: int = 0
-
-
-class GenerationRequest(BaseModel):
-    hf_name: str
-    message: str
-    prompt: str = None
-    max_length: int = 2048
-    max_new_tokens: int = 2048
-    temperature: float = 0.4
-    do_sample: bool = True
-    num_beams: int = 4
-    history: Optional[List[dict]] = None
-    output: str = None
-    processing: bool = False
-    id: int = None
-
-
-class ModelStatusResponse(BaseModel):
-    model_name: str
-    status: str  # "loaded", "loading", "not_loaded", "error"
-    message: str
+def _format_response(
+    request: GenerationRequest,
+    processing_time: float,
+    request_id: str,
+):
+    """
+    Format the response based on the requested format type.
+
+    Args:
+        request: The original generation request with output
+        processing_time: Time taken to process the request
+        request_id: Unique identifier for this request
+
+    Returns:
+        Dictionary formatted according to response_format
+    """
+    timestamp = int(time.time())
+
+    # Extract clean text from output
+    clean_output = extract_assistant_response(request.output, request.hf_name)
+
+    if request.response_format == "simple":
+        # Minimal response - just the text
+        return {"response": clean_output}
+
+    elif request.response_format == "openai":
+        # OpenAI-compatible format
+        return {
+            "id": request_id,
+            "object": "chat.completion",
+            "created": timestamp,
+            "model": request.hf_name,
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {"role": "assistant", "content": clean_output},
+                    "finish_reason": "stop",
+                }
+            ],
+            "usage": {
+                "prompt_tokens": -1,  # Not tracked in current implementation
+                "completion_tokens": -1,
+                "total_tokens": -1,
+            },
+        }
+
+    else:  # "full" format (default, comprehensive response with all metadata)
+        return {
+            "id": request_id,
+            "model": request.hf_name,
+            "response": clean_output,
+            "raw_output": request.output,
+            "created": timestamp,
+            "processing_time": round(processing_time, 3),
+            "generation_params": {
+                "max_length": request.max_length,
+                "max_new_tokens": request.max_new_tokens,
+                "temperature": request.temperature,
+                "do_sample": request.do_sample,
+                "num_beams": request.num_beams,
+            },
+            "metadata": {
+                "has_history": bool(request.history),
+                "history_length": len(request.history) if request.history else 0,
+                "prompt_used": request.prompt is not None,
+            },
+        }
 
 
 class TensorlinkAPI:
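For orientation, a sketch of what _format_response returns for two of the formats. Illustrative values only; it assumes req is a completed GenerationRequest whose output the worker has already filled in:

    # Assumes `req` is a finished GenerationRequest with `req.output` populated,
    # so extract_assistant_response() can strip the prompt scaffolding.
    req.response_format = "simple"
    _format_response(req, processing_time=1.234, request_id="req_42")
    # -> {"response": "<assistant text>"}

    req.response_format = "openai"
    _format_response(req, processing_time=1.234, request_id="req_42")
    # -> {"id": "req_42", "object": "chat.completion", "created": <unix ts>,
    #     "model": req.hf_name,
    #     "choices": [{"index": 0, "message": {...}, "finish_reason": "stop"}],
    #     "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1}}
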
@@ -65,6 +107,8 @@ def _define_routes(self):
         @self.router.post("/generate")
         async def generate(request: GenerationRequest):
             try:
+                start_time = time.time()
+
                 # Log model request
                 current_time = time.time()
                 self.model_request_timestamps[request.hf_name].append(current_time)
@@ -82,7 +126,8 @@ async def generate(request: GenerationRequest):
                 self.model_name_to_request[request.hf_name] += 1
 
                 request.output = None
-                request.id = hash(random.random())
+                request_id = f"req_{hash(random.random())}"
+                request.id = hash(request_id)
 
                 # Check if model is loaded, if not trigger loading
                 model_status = self._check_model_status(request.hf_name)
@@ -105,8 +150,74 @@ async def generate(request: GenerationRequest):
                 # Wait for the result
                 request = await self._wait_for_result(request)
 
-                return_val = request.output
-                return {"response": return_val}
+                processing_time = time.time() - start_time
+
+                # Format response based on requested format
+                formatted_response = _format_response(
+                    request, processing_time, request_id
+                )
+
+                return formatted_response
+
+            except HTTPException:
+                raise
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
+        @self.router.post("/v1/chat/completions")
+        async def chat_completions(request: Request):
+            """
+            OpenAI-compatible chat completions endpoint.
+            Accepts OpenAI format and returns OpenAI format.
+            """
+            try:
+                body = await request.json()
+
+                # Extract OpenAI-style parameters
+                model = body.get("model")
+                messages = body.get("messages", [])
+                temperature = body.get("temperature", 0.7)
+                max_tokens = body.get("max_tokens", 2048)
+
+                # Convert to our internal format
+                history = []
+                current_message = ""
+
+                for msg in messages:
+                    role = msg.get("role")
+                    content = msg.get("content", "")
+
+                    if role == "system":
+                        # System messages added to history
+                        history.append({"role": "system", "content": content})
+                    elif role == "user":
+                        # If there was a previous user message, add it to history;
+                        # the last user message becomes current_message
+                        if current_message:
+                            history.append({"role": "user", "content": current_message})
+                        current_message = content
+                    elif role == "assistant":
+                        history.append({"role": "assistant", "content": content})
+
+                if not current_message:
+                    raise HTTPException(
+                        status_code=400,
+                        detail="No user message found in messages array",
+                    )
+
+                # Create our internal request with OpenAI format
+                gen_request = GenerationRequest(
+                    hf_name=model,
+                    message=current_message,
+                    history=history if history else None,
+                    temperature=temperature,
+                    max_new_tokens=max_tokens,
+                    response_format="openai",
+                )
+
+                # Reuse the generate logic
+                return await generate(gen_request)
 
             except HTTPException:
                 raise
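Because the endpoint speaks the OpenAI wire format, any plain HTTP client can exercise it. A usage sketch; the host, port, and model name are deployment-specific assumptions rather than part of the commit:

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # assumed local deployment
        json={
            "model": "Qwen/Qwen2.5-7B-Instruct",  # hypothetical model name
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"},
            ],
            "temperature": 0.7,
            "max_tokens": 256,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])
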
@@ -249,26 +360,8 @@ async def get_proposals(limit: int = Query(30, ge=1, le=180)):
         @self.app.get("/node-info")
         async def get_node_info(node_id: str):
             """
-            {
-                pubKeyHash: '0x742d35Cc6634C0532925a3b844Bc9e7595f0bEb1',
-                type: 'validator',
-                lastSeen: '2 minutes ago',
-                data: {
-                    peers: 12,
-                    rewards: 1000.5,
-                    is_active: true
-                }
-            },
-            {
-                pubKeyHash: '0x8f3B9c4A7E2D1F5C6A8B9D0E3F4A5B6C7D8E9F0A',
-                type: 'worker',
-                lastSeen: '5 minutes ago',
-                data: {
-                    jobs_completed: 47,
-                    rewards: 235.8,
-                    is_active: true
-                }
-            },
+            Get information about a specific node in the network.
+            Returns node type, last seen, and relevant data based on role.
             """
             node_info = self.smart_node.dht.query(node_id)
             if node_info:
@@ -280,9 +373,10 @@ async def get_node_info(node_id: str):
                 }
 
                 if node_info["role"] == "V":
-                    # node_info["peers"] = 1
+                    # Validator-specific data
                     pass
                 elif node_info["role"] == "W":
+                    # Worker-specific data
                     node_info["rewards"] = (
                         self.smart_node.contract_manager.get_worker_claim_data(
                             node_info["address"]
@@ -294,25 +388,31 @@ async def get_node_info(node_id: str):
 
         @self.app.get("/claim-info")
         async def get_worker_claims(node_address: str):
+            """Get claim information for a specific worker node"""
             return self.smart_node.contract_manager.get_worker_claim_data(node_address)
 
         self.app.include_router(self.router)
 
     def _check_model_status(self, model_name: str) -> dict:
         """Check if a model is loaded, loading, or not loaded"""
         status = "not_loaded"
+        message = "Model is not currently loaded"
 
         try:
             # Check if there is a public job with this module
             for module_id, module in self.smart_node.modules.items():
                 if module.get("name", "") == model_name:
                     if module.get("public", False):
                         status = "loaded"
+                        message = f"Model {model_name} is loaded and ready"
+                        break
 
         except Exception as e:
             logging.error(f"Error checking model status: {e}")
+            status = "error"
+            message = f"Error checking model status: {str(e)}"
 
-        return {"status": status, "message": "Model is not currently loaded"}
+        return {"status": status, "message": message}
@@ -321,10 +421,6 @@ def _trigger_model_load(self, model_name: str):
             self.api_requested_models.add(model_name)
             self.smart_node.create_hf_job(model_name)
 
-            # TODO Send load request to ML validator
-            # self.smart_node.request_queue.put(
-            #     {"type": "load_model", "args": (model_name,)}
-            # )
         except Exception as e:
             logging.error(f"Error triggering model load: {e}")

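One behavioral note on the chat_completions conversion loop: a prior user turn is only flushed to history when the next user turn arrives, so it lands after the intervening assistant reply. A trace of the loop with hypothetical messages:

    messages = [
        {"role": "system", "content": "Be terse."},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello."},
        {"role": "user", "content": "What is Tensorlink?"},
    ]
    # After the conversion loop:
    #   current_message == "What is Tensorlink?"
    #   history == [
    #       {"role": "system", "content": "Be terse."},
    #       {"role": "assistant", "content": "Hello."},  # appended on message 3
    #       {"role": "user", "content": "Hi"},           # flushed on message 4
    #   ]
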
tensorlink/ml/validator.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 from tensorlink.ml.worker import DistributedWorker
 from tensorlink.ml.module import DistributedModel
 from tensorlink.ml.utils import load_models_cache, save_models_cache
-from tensorlink.api.node import GenerationRequest
+from tensorlink.api.models import GenerationRequest
 
 from transformers import AutoTokenizer
 import torch
