Skip to content

Commit

Permalink
fix
Browse files — browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Jan 12, 2025
1 parent e7b98f5 commit fcc699a
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion exo/inference/dummy_inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) ->
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
return self.tokenizer.decode(tokens)

async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: dict = {}) -> tuple[np.ndarray, Optional[dict]]:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
return input_data + 1 if self.shard.is_last_layer() else input_data, None

Expand Down
4 changes: 2 additions & 2 deletions exo/networking/grpc/grpc_peer_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional
n_layers=shard.n_layers,
),
request_id=request_id,
inference_state=self.serialize_inference_state(inference_state)
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response = await self.stub.SendPrompt(request)

Expand All @@ -101,7 +101,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O
),
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
request_id=request_id,
inference_state=self.serialize_inference_state(inference_state)
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response = await self.stub.SendTensor(request)

Expand Down
4 changes: 2 additions & 2 deletions exo/networking/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def SendPrompt(self, request, context):
)
prompt = request.prompt
request_id = request.request_id
inference_state = self.deserialize_inference_state(request.inference_state)
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
Expand All @@ -68,7 +68,7 @@ async def SendTensor(self, request, context):
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
request_id = request.request_id

inference_state = self.deserialize_inference_state(request.inference_state)
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)

result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
Expand Down

0 comments on commit fcc699a

Please sign in to comment.