1919import abc
2020import functools
2121from types import ModuleType
22- from typing import Any , Sequence
22+ from typing import TYPE_CHECKING , Any , Callable , Sequence , cast
23+
24+ if TYPE_CHECKING :
25+ import numpy .typing as npt
26+
27+ from . import _imaging
2328
2429
2530class Filter :
2631 @abc .abstractmethod
27- def filter (self , image ) :
32+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
2833 pass
2934
3035
@@ -33,7 +38,9 @@ class MultibandFilter(Filter):
3338
3439
3540class BuiltinFilter (MultibandFilter ):
36- def filter (self , image ):
41+ filterargs : tuple [Any , ...]
42+
43+ def filter (self , image : _imaging .ImagingCore ) -> _imaging .ImagingCore :
3744 if image .mode == "P" :
3845 msg = "cannot filter palette images"
3946 raise ValueError (msg )
@@ -91,7 +98,7 @@ def __init__(self, size: int, rank: int) -> None:
9198 self .size = size
9299 self .rank = rank
93100
94- def filter (self , image ) :
101+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
95102 if image .mode == "P" :
96103 msg = "cannot filter palette images"
97104 raise ValueError (msg )
@@ -158,7 +165,7 @@ class ModeFilter(Filter):
158165 def __init__ (self , size : int = 3 ) -> None :
159166 self .size = size
160167
161- def filter (self , image ) :
168+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
162169 return image .modefilter (self .size )
163170
164171
@@ -176,9 +183,9 @@ class GaussianBlur(MultibandFilter):
176183 def __init__ (self , radius : float | Sequence [float ] = 2 ) -> None :
177184 self .radius = radius
178185
179- def filter (self , image ) :
186+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
180187 xy = self .radius
181- if not isinstance (xy , (tuple , list )):
188+ if isinstance (xy , (int , float )):
182189 xy = (xy , xy )
183190 if xy == (0 , 0 ):
184191 return image .copy ()
@@ -208,9 +215,9 @@ def __init__(self, radius: float | Sequence[float]) -> None:
208215 raise ValueError (msg )
209216 self .radius = radius
210217
211- def filter (self , image ) :
218+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
212219 xy = self .radius
213- if not isinstance (xy , (tuple , list )):
220+ if isinstance (xy , (int , float )):
214221 xy = (xy , xy )
215222 if xy == (0 , 0 ):
216223 return image .copy ()
@@ -241,7 +248,7 @@ def __init__(
241248 self .percent = percent
242249 self .threshold = threshold
243250
244- def filter (self , image ) :
251+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
245252 return image .unsharp_mask (self .radius , self .percent , self .threshold )
246253
247254
@@ -387,8 +394,13 @@ class Color3DLUT(MultibandFilter):
387394 name = "Color 3D LUT"
388395
389396 def __init__ (
390- self , size , table , channels : int = 3 , target_mode : str | None = None , ** kwargs
391- ):
397+ self ,
398+ size : int | tuple [int , int , int ],
399+ table ,
400+ channels : int = 3 ,
401+ target_mode : str | None = None ,
402+ ** kwargs : bool ,
403+ ) -> None :
392404 if channels not in (3 , 4 ):
393405 msg = "Only 3 or 4 output channels are supported"
394406 raise ValueError (msg )
@@ -410,15 +422,16 @@ def __init__(
410422 pass
411423
412424 if numpy and isinstance (table , numpy .ndarray ):
425+ numpy_table : npt .NDArray [Any ] = table
413426 if copy_table :
414- table = table .copy ()
427+ numpy_table = numpy_table .copy ()
415428
416- if table .shape in [
429+ if numpy_table .shape in [
417430 (items * channels ,),
418431 (items , channels ),
419432 (size [2 ], size [1 ], size [0 ], channels ),
420433 ]:
421- table = table .reshape (items * channels )
434+ table = numpy_table .reshape (items * channels )
422435 else :
423436 wrong_size = True
424437
@@ -428,15 +441,17 @@ def __init__(
428441
429442 # Convert to a flat list
430443 if table and isinstance (table [0 ], (list , tuple )):
431- table , raw_table = [], table
444+ raw_table = cast (Sequence [Sequence [int ]], table )
445+ flat_table : list [int ] = []
432446 for pixel in raw_table :
433447 if len (pixel ) != channels :
434448 msg = (
435449 "The elements of the table should "
436450 f"have a length of { channels } ."
437451 )
438452 raise ValueError (msg )
439- table .extend (pixel )
453+ flat_table .extend (pixel )
454+ table = flat_table
440455
441456 if wrong_size or len (table ) != items * channels :
442457 msg = (
@@ -449,23 +464,29 @@ def __init__(
449464 self .table = table
450465
451466 @staticmethod
452- def _check_size (size : Any ) -> list [ int ]:
467+ def _check_size (size : Any ) -> tuple [ int , int , int ]:
453468 try :
454469 _ , _ , _ = size
455470 except ValueError as e :
456471 msg = "Size should be either an integer or a tuple of three integers."
457472 raise ValueError (msg ) from e
458473 except TypeError :
459474 size = (size , size , size )
460- size = [ int (x ) for x in size ]
475+ size = tuple ( int (x ) for x in size )
461476 for size_1d in size :
462477 if not 2 <= size_1d <= 65 :
463478 msg = "Size should be in [2, 65] range."
464479 raise ValueError (msg )
465480 return size
466481
467482 @classmethod
468- def generate (cls , size , callback , channels = 3 , target_mode = None ):
483+ def generate (
484+ cls ,
485+ size : int | tuple [int , int , int ],
486+ callback : Callable [[float , float , float ], tuple [float , ...]],
487+ channels : int = 3 ,
488+ target_mode : str | None = None ,
489+ ) -> Color3DLUT :
469490 """Generates new LUT using provided callback.
470491
471492 :param size: Size of the table. Passed to the constructor.
@@ -482,7 +503,7 @@ def generate(cls, size, callback, channels=3, target_mode=None):
482503 msg = "Only 3 or 4 output channels are supported"
483504 raise ValueError (msg )
484505
485- table = [0 ] * (size_1d * size_2d * size_3d * channels )
506+ table : list [ float ] = [0 ] * (size_1d * size_2d * size_3d * channels )
486507 idx_out = 0
487508 for b in range (size_3d ):
488509 for g in range (size_2d ):
@@ -500,7 +521,13 @@ def generate(cls, size, callback, channels=3, target_mode=None):
500521 _copy_table = False ,
501522 )
502523
503- def transform (self , callback , with_normals = False , channels = None , target_mode = None ):
524+ def transform (
525+ self ,
526+ callback : Callable [..., tuple [float , ...]],
527+ with_normals : bool = False ,
528+ channels : int | None = None ,
529+ target_mode : str | None = None ,
530+ ) -> Color3DLUT :
504531 """Transforms the table values using provided callback and returns
505532 a new LUT with altered values.
506533
@@ -564,7 +591,7 @@ def __repr__(self) -> str:
564591 r .append (f"target_mode={ self .mode } " )
565592 return "<{}>" .format (" " .join (r ))
566593
567- def filter (self , image ) :
594+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
568595 from . import Image
569596
570597 return image .color_lut_3d (
0 commit comments