Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve sec #114

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/validator.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ If you want to run validation APIs locally, check out [Setup validator endpoint]
3. (Optional) **Enable Auto Update Validator**
```
pm2 start auto_update.sh --name "auto-update"
```
```

22 changes: 20 additions & 2 deletions image_generation_subnet/protocol.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time
import json
import bittensor as bt
import pydantic
from generation_models.utils import base64_to_pil_image
Expand Down Expand Up @@ -99,7 +101,7 @@ def deserialize_response(self):
"response_dict": self.response_dict,
}

def store_response(self, storage_url: str, uid, validator_uid):
def store_response(self, storage_url: str, uid, validator_uid, keypair: bt.Keypair):
if self.model_name == "GoJourney":
storage_url = storage_url + "/upload-go-journey-item"
data = {
Expand Down Expand Up @@ -128,6 +130,14 @@ def store_response(self, storage_url: str, uid, validator_uid):
"pipeline_params": self.pipeline_params,
}
}
serialized_data = json.dumps(data, sort_keys=True, separators=(',', ':'))
nonce = str(time.time_ns())
# Calculate validator 's signature
message = f"{serialized_data}{keypair.ss58_address}{nonce}"
signature = f"0x{keypair.sign(message).hex()}"
# Add validator 's signature
data["nonce"] = nonce
data["signature"] = signature
try:
response = requests.post(storage_url, json=data)
response.raise_for_status()
Expand Down Expand Up @@ -301,7 +311,7 @@ def deserialize_response(self):
"model_name": self.model_name,
}

def store_response(self, storage_url: str, uid, validator_uid):
def store_response(self, storage_url: str, uid, validator_uid, keypair: bt.Keypair):
storage_url = storage_url + "/upload-multimodal-item"
minimized_prompt_output: dict = copy.deepcopy(self.prompt_output)
minimized_prompt_output['choices'][0].pop("logprobs")
Expand All @@ -317,6 +327,14 @@ def store_response(self, storage_url: str, uid, validator_uid):
"pipeline_params": self.pipeline_params,
}
}
serialized_data = json.dumps(data, sort_keys=True, separators=(',', ':'))
nonce = str(time.time_ns())
# Calculate validator 's signature
message = f"{serialized_data}{keypair.ss58_address}{nonce}"
signature = f"0x{keypair.sign(message).hex()}"
# Add validator 's signature
data["nonce"] = nonce
data["signature"] = signature
try:
response = requests.post(storage_url, json=data)
response.raise_for_status()
Expand Down
36 changes: 24 additions & 12 deletions image_generation_subnet/validator/miner_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import time
import bittensor as bt
from image_generation_subnet.protocol import ImageGenerating, Information
import torch
Expand Down Expand Up @@ -150,21 +152,31 @@ def get_model_specific_weights(self, model_name, normalize=True):
return model_specific_weights

def store_miner_info(self):
catalogue = {}
for k, v in self.validator.nicheimage_catalogue.items():
catalogue[k] = {
"model_incentive_weight": v.get("model_incentive_weight", 0),
"supporting_pipelines": v.get("supporting_pipelines", []),
}
data = {
"uid": self.validator.uid,
"info": self.all_uids_info,
"version": ig_subnet.__version__,
"catalogue": catalogue,
}
serialized_data = json.dumps(data, sort_keys=True, separators=(',', ':'))
nonce = str(time.time_ns())
# Calculate validator 's signature
keypair = self.validator.wallet.hotkey
message = f"{serialized_data}{keypair.ss58_address}{nonce}"
signature = f"0x{keypair.sign(message).hex()}"
# Add validator 's signature
data["nonce"] = nonce
data["signature"] = signature
try:
catalogue = {}
for k, v in self.validator.nicheimage_catalogue.items():
catalogue[k] = {
"model_incentive_weight": v.get("model_incentive_weight", 0),
"supporting_pipelines": v.get("supporting_pipelines", []),
}
requests.post(
self.validator.config.storage_url + "/store_miner_info",
json={
"uid": self.validator.uid,
"info": self.all_uids_info,
"version": ig_subnet.__version__,
"catalogue": catalogue,
},
json=data
)
self.reset_metadata()
except Exception as e:
Expand Down
10 changes: 5 additions & 5 deletions neurons/validator/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def async_query_and_reward(
)
store_thread = threading.Thread(
target=self.store_miner_output,
args=(self.config.storage_url, responses, uids, self.uid),
args=(self.config.storage_url, responses, uids),
daemon=True,
)
store_thread.start()
Expand Down Expand Up @@ -683,16 +683,16 @@ def prepare_challenge(self, uids_should_rewards, model_name, pipeline_type):
return synapses, batched_uids_should_rewards

def store_miner_output(
self, storage_url, responses: list[bt.Synapse], uids, validator_uid
self, storage_url, responses: list[bt.Synapse], uids
):
if not self.config.share_response:
return

for uid, response in enumerate(responses):
for uid, response in zip(uids, responses):
if not response.is_success:
continue
try:
response.store_response(storage_url, uid, validator_uid)
response.store_response(storage_url, uid, self.uid, self.wallet.hotkey)
break
except Exception as e:
bt.logging.error(f"Error in storing response: {e}")
Expand Down
29 changes: 18 additions & 11 deletions neurons/validator/validator_proxy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from fastapi import FastAPI, HTTPException, Depends
from concurrent.futures import ThreadPoolExecutor
import uvicorn
Expand All @@ -16,12 +17,8 @@
from starlette.concurrency import run_in_threadpool
import threading


class ValidatorProxy:
def __init__(
self,
validator,
):
def __init__(self, validator: "neurons.validator.validator.Validator"):
self.validator = validator
self.get_credentials()
self.miner_request_counter = {}
Expand All @@ -41,16 +38,26 @@ def __init__(
self.start_server()

def get_credentials(self):
postfix = (
f":{self.validator.config.proxy.port}/validator_proxy"
if self.validator.config.proxy.port
else ""
)
ss58_address = self.validator.wallet.hotkey.ss58_address
uid = self.validator.uid
nonce = str(time.time_ns())
# Calculate validator 's signature
message = f"{postfix}{ss58_address}{nonce}"
signature = f"0x{self.validator.wallet.hotkey.sign(message).hex()}"

with httpx.Client(timeout=httpx.Timeout(30)) as client:
response = client.post(
f"{self.validator.config.proxy.proxy_client_url}/get_credentials",
json={
"postfix": (
f":{self.validator.config.proxy.port}/validator_proxy"
if self.validator.config.proxy.port
else ""
),
"uid": self.validator.uid,
"postfix": postfix,
"uid": uid,
"signature": signature,
"nonce": nonce
},
)
response.raise_for_status()
Expand Down