diff --git a/README.md b/README.md index ad6b910..d9df2a8 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,20 @@ generations = dalle.generate_from_masked_image( 🙌 Task completed! ``` +## Dry runs + +If you are playing around and don't want to waste credits, you can use the dry run option, eg + +``` +dalle.generate(..., dry=True) +dalle.generate_amount(...., dry=True) +dalle.generate_and_download(...., dry=True) +``` + +This will circumvent the call to openAI but will still generate some random data so if you use the +output of `generate` to `download`, it will download random images. + + # Other languages [Nodejs Package](https://github.com/ezzcodeezzlife/dalle-node) diff --git a/src/dalle2/dalle2.py b/src/dalle2/dalle2.py index 44bd126..505496b 100644 --- a/src/dalle2/dalle2.py +++ b/src/dalle2/dalle2.py @@ -6,6 +6,7 @@ import time import urllib import urllib.request +import uuid from pathlib import Path @@ -16,7 +17,7 @@ def __init__(self, bearer): self.inpainting_batch_size = 3 self.task_sleep_seconds = 3 - def generate(self, prompt): + def generate(self, prompt, dry=False): body = { "task_type": "text2im", "prompt": { @@ -25,22 +26,22 @@ def generate(self, prompt): } } - return self.get_task_response(body) + return self.get_task_response(body, dry=dry) - def generate_and_download(self, prompt, image_dir=os.getcwd()): - generations = self.generate(prompt) + def generate_and_download(self, prompt, image_dir=os.getcwd(), dry=False): + generations = self.generate(prompt, dry=dry) if not generations: return None return self.download(generations, image_dir) - def generate_amount(self, prompt, amount): + def generate_amount(self, prompt, amount, dry=False): if amount < self.batch_size: raise ValueError(f"passed amount of {amount} cannot be smaller than the batch size of {self.batch_size}") - return [self.generate(prompt) for _ in range(math.ceil(amount / self.batch_size))] + return [self.generate(prompt, dry=dry) for _ in range(math.ceil(amount / self.batch_size))] - def generate_from_masked_image(self, prompt, image_path): + def generate_from_masked_image(self, prompt, image_path, dry=False): with open(image_path, "rb") as f: image_base64 = base64.b64encode(f.read()) @@ -54,16 +55,22 @@ def generate_from_masked_image(self, prompt, image_path): } } - return self.get_task_response(body) + return self.get_task_response(body, dry=dry) - def get_task_response(self, body): + def get_task_response(self, body, dry=False): url = "https://labs.openai.com/api/labs/tasks" headers = { 'Authorization': "Bearer " + self.bearer, 'Content-Type': "application/json", } - response = requests.post(url, headers=headers, data=json.dumps(body)) + if dry: + return [{ # fake openai response so download() still works + 'id': str(uuid.uuid4()), + 'generation': {'image_path': 'https://picsum.photos/200'} + } for _ in range(4)] + else: + response = requests.post(url, headers=headers, data=json.dumps(body)) if response.status_code != 200: print(response.text) return None