diff --git a/ddpo_pytorch/rewards.py b/ddpo_pytorch/rewards.py index 9409a2f..b3aca11 100644 --- a/ddpo_pytorch/rewards.py +++ b/ddpo_pytorch/rewards.py @@ -56,7 +56,10 @@ def llava_strict_satisfaction(): from io import BytesIO import pickle + # this batch size is related to sample_batch_size + # set it with the sample_batch_size in your config batch_size = 4 + url = "http://127.0.0.1:8085" sess = requests.Session() retries = Retry( @@ -124,7 +127,10 @@ def llava_bertscore(): from io import BytesIO import pickle + # this batch size is related to sample_batch_size + # set it with the sample_batch_size in your config batch_size = 16 + url = "http://127.0.0.1:8085" sess = requests.Session() retries = Retry(