-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpipeline.py
49 lines (39 loc) · 1.7 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import numpy as np
import PIL.Image as Image
from abc import ABC, abstractmethod
from diffusion.base import BaseSampler
from methods.base import BaseGuidance
from evaluations.base import BaseEvaluator
from utils.configs import Arguments
import logger
class BasePipeline(object):
    """Glue a sampler, a guidance method, and an evaluator into one workflow.

    The pipeline can optionally skip work that appears to be finished
    already: when ``check_done`` is truthy it looks for marker files inside
    ``logging_dir`` before sampling or evaluating.
    """

    def __init__(self,
                 args: Arguments,
                 network: BaseSampler,
                 guider: BaseGuidance,
                 evaluator: BaseEvaluator):
        """Store the collaborators and the settings read from ``args``.

        Args:
            args: Configuration object; only ``logging_dir`` and
                ``check_done`` are read here.
            network: Sampler used to draw samples and to convert them via
                ``tensor_to_obj``.
            guider: Guidance object forwarded to ``network.sample``.
            evaluator: Object whose ``evaluate`` method scores samples.
        """
        # Settings taken from the configuration.
        self.check_done = args.check_done
        self.logging_dir = args.logging_dir

        # Collaborating components.
        self.evaluator = evaluator
        self.guider = guider
        self.network = network

    # NOTE(review): this class derives from ``object`` rather than ``ABC``, so
    # the decorator below does not prevent instantiation, and the method has
    # a concrete body. Kept as-is to avoid changing behavior for subclasses.
    @abstractmethod
    def sample(self, sample_size: int):
        """Return previously stored samples if available, else draw new ones.

        Args:
            sample_size: Number of samples requested from the network.

        Returns:
            The result of ``check_done_and_load_sample`` when it is not
            ``None``; otherwise freshly drawn samples passed through
            ``network.tensor_to_obj``.
        """
        cached = self.check_done_and_load_sample()
        if cached is not None:
            # Samples loaded from disk are returned untouched.
            return cached

        drawn = self.network.sample(sample_size=sample_size, guidance=self.guider)
        return self.network.tensor_to_obj(drawn)

    def evaluate(self, samples):
        """Evaluate ``samples``; thin wrapper over ``check_done_and_evaluate``."""
        return self.check_done_and_evaluate(samples)

    def check_done_and_evaluate(self, samples):
        """Run the evaluator unless metrics were already written.

        Returns:
            ``None`` when ``check_done`` is truthy and ``metrics.json`` exists
            in ``logging_dir``; otherwise whatever ``evaluator.evaluate``
            returns for ``samples``.
        """
        if self.check_done:
            # The path is only built when the check is enabled.
            metrics_file = os.path.join(self.logging_dir, 'metrics.json')
            if os.path.exists(metrics_file):
                logger.log("Metrics already generated. To regenerate, please set `check_done` to `False`.")
                return None

        return self.evaluator.evaluate(samples)

    def check_done_and_load_sample(self):
        """Load stored samples when the sampling-finished marker is present.

        Returns:
            The value of ``logger.load_samples()`` when ``check_done`` is
            truthy and a ``finished_sampling`` entry exists in
            ``logging_dir``; otherwise ``None``.
        """
        if self.check_done:
            # The path is only built when the check is enabled.
            done_marker = os.path.join(self.logging_dir, "finished_sampling")
            if os.path.exists(done_marker):
                logger.log("found tags for generated samples, should load directly. To regenerate, please set `check_done` to `False`.")
                return logger.load_samples()

        return None