|
| 1 | +from __future__ import print_function, unicode_literals, absolute_import, division |
| 2 | +import logging |
| 3 | +from typing import Union, Tuple |
| 4 | + |
| 5 | +logger = logging.getLogger(__name__) |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from mako.template import Template |
| 9 | +from numbers import Number |
| 10 | +from gputools import OCLArray, OCLProgram |
| 11 | +from gputools.core.ocltypes import cl_buffer_datatype_dict |
| 12 | +from ._abspath import abspath |
| 13 | + |
| 14 | + |
| 15 | +def _rank_downscale_2d(data_g, size=(3, 3), rank=None, cval = 0, res_g=None): |
| 16 | + if not data_g.dtype.type in cl_buffer_datatype_dict: |
| 17 | + raise ValueError("dtype %s not supported" % data_g.dtype.type) |
| 18 | + |
| 19 | + DTYPE = cl_buffer_datatype_dict[data_g.dtype.type] |
| 20 | + |
| 21 | + size = tuple(map(int, size)) |
| 22 | + |
| 23 | + if rank is None: |
| 24 | + rank = np.prod(size) // 2 |
| 25 | + |
| 26 | + with open(abspath("kernels/rank_downscale.cl"), "r") as f: |
| 27 | + tpl = Template(f.read()) |
| 28 | + |
| 29 | + rendered = tpl.render(DTYPE = DTYPE,FSIZE_Z=0, FSIZE_X=size[1], FSIZE_Y=size[0],CVAL = cval) |
| 30 | + |
| 31 | + prog = OCLProgram(src_str=rendered) |
| 32 | + |
| 33 | + if res_g is None: |
| 34 | + res_g = OCLArray.empty(tuple(s0//s for s, s0 in zip(size,data_g.shape)), data_g.dtype) |
| 35 | + |
| 36 | + prog.run_kernel("rank_2", res_g.shape[::-1], None, data_g.data, res_g.data, |
| 37 | + np.int32(data_g.shape[1]), np.int32(data_g.shape[0]), |
| 38 | + np.int32(rank)) |
| 39 | + return res_g |
| 40 | + |
| 41 | +def _rank_downscale_3d(data_g, size=(3, 3, 3), rank=None, cval = 0, res_g=None): |
| 42 | + if not data_g.dtype.type in cl_buffer_datatype_dict: |
| 43 | + raise ValueError("dtype %s not supported" % data_g.dtype.type) |
| 44 | + |
| 45 | + DTYPE = cl_buffer_datatype_dict[data_g.dtype.type] |
| 46 | + |
| 47 | + size = tuple(map(int, size)) |
| 48 | + |
| 49 | + if rank is None: |
| 50 | + rank = np.prod(size) // 2 |
| 51 | + |
| 52 | + with open(abspath("kernels/rank_downscale.cl"), "r") as f: |
| 53 | + tpl = Template(f.read()) |
| 54 | + |
| 55 | + rendered = tpl.render(DTYPE = DTYPE,FSIZE_X=size[2], FSIZE_Y=size[1], FSIZE_Z=size[0],CVAL = cval) |
| 56 | + |
| 57 | + prog = OCLProgram(src_str=rendered) |
| 58 | + |
| 59 | + if res_g is None: |
| 60 | + res_g = OCLArray.empty(tuple(s0//s for s, s0 in zip(size,data_g.shape)), data_g.dtype) |
| 61 | + |
| 62 | + prog.run_kernel("rank_3", res_g.shape[::-1], None, data_g.data, res_g.data, |
| 63 | + np.int32(data_g.shape[2]), np.int32(data_g.shape[1]), np.int32(data_g.shape[0]), |
| 64 | + np.int32(rank)) |
| 65 | + return res_g |
| 66 | + |
| 67 | + |
| 68 | +def rank_downscale(data:np.ndarray, size:Union[int, Tuple[int]]=3, rank=None): |
| 69 | + """ |
| 70 | + downscales an image by the given factor and returns the rank-th element in each block |
| 71 | +
|
| 72 | + Parameters |
| 73 | + ---------- |
| 74 | + data: numpy.ndarray |
| 75 | + input data (2d or 3d) |
| 76 | + size: int or tuple |
| 77 | + downsampling factors |
| 78 | + rank: int |
| 79 | + rank of element to retain |
| 80 | + rank = 0 -> minimum |
| 81 | + rank = -1 -> maximum |
| 82 | + rank = None -> median |
| 83 | + |
| 84 | + Returns |
| 85 | + ------- |
| 86 | + downscaled image |
| 87 | + """ |
| 88 | + |
| 89 | + if not (isinstance(data, np.ndarray) and data.ndim in (2,3)): |
| 90 | + raise ValueError("input data has to be a 2d or 3d numpy array!") |
| 91 | + |
| 92 | + if isinstance(size, Number): |
| 93 | + size = (int(size),)*data.ndim |
| 94 | + |
| 95 | + if len(size) != data.ndim: |
| 96 | + raise ValueError("size has to be a tuple of 3 ints") |
| 97 | + |
| 98 | + if rank is None: |
| 99 | + rank = np.prod(size) // 2 |
| 100 | + else: |
| 101 | + rank = rank % np.prod(size) |
| 102 | + |
| 103 | + data_g = OCLArray.from_array(data) |
| 104 | + |
| 105 | + if data.ndim==2: |
| 106 | + res_g = _rank_downscale_2d(data_g, size=size, rank=rank) |
| 107 | + elif data.ndim==3: |
| 108 | + res_g = _rank_downscale_3d(data_g, size=size, rank=rank) |
| 109 | + else: |
| 110 | + raise ValueError("data has to be 2d or 3d") |
| 111 | + |
| 112 | + return res_g.get() |
| 113 | + |
| 114 | + |
0 commit comments