1919import abc
2020import functools
2121from types import ModuleType
22- from typing import Any , Sequence
22+ from typing import TYPE_CHECKING , Any , Callable , Sequence , cast
23+
24+ if TYPE_CHECKING :
25+ from . import _imaging
26+ from ._typing import NumpyArray
2327
2428
2529class Filter :
2630 @abc .abstractmethod
27- def filter (self , image ) :
31+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
2832 pass
2933
3034
@@ -33,7 +37,9 @@ class MultibandFilter(Filter):
3337
3438
3539class BuiltinFilter (MultibandFilter ):
36- def filter (self , image ):
40+ filterargs : tuple [Any , ...]
41+
42+ def filter (self , image : _imaging .ImagingCore ) -> _imaging .ImagingCore :
3743 if image .mode == "P" :
3844 msg = "cannot filter palette images"
3945 raise ValueError (msg )
@@ -91,7 +97,7 @@ def __init__(self, size: int, rank: int) -> None:
9197 self .size = size
9298 self .rank = rank
9399
94- def filter (self , image ) :
100+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
95101 if image .mode == "P" :
96102 msg = "cannot filter palette images"
97103 raise ValueError (msg )
@@ -158,7 +164,7 @@ class ModeFilter(Filter):
158164 def __init__ (self , size : int = 3 ) -> None :
159165 self .size = size
160166
161- def filter (self , image ) :
167+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
162168 return image .modefilter (self .size )
163169
164170
@@ -176,9 +182,9 @@ class GaussianBlur(MultibandFilter):
176182 def __init__ (self , radius : float | Sequence [float ] = 2 ) -> None :
177183 self .radius = radius
178184
179- def filter (self , image ) :
185+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
180186 xy = self .radius
181- if not isinstance (xy , (tuple , list )):
187+ if isinstance (xy , (int , float )):
182188 xy = (xy , xy )
183189 if xy == (0 , 0 ):
184190 return image .copy ()
@@ -208,9 +214,9 @@ def __init__(self, radius: float | Sequence[float]) -> None:
208214 raise ValueError (msg )
209215 self .radius = radius
210216
211- def filter (self , image ) :
217+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
212218 xy = self .radius
213- if not isinstance (xy , (tuple , list )):
219+ if isinstance (xy , (int , float )):
214220 xy = (xy , xy )
215221 if xy == (0 , 0 ):
216222 return image .copy ()
@@ -241,7 +247,7 @@ def __init__(
241247 self .percent = percent
242248 self .threshold = threshold
243249
244- def filter (self , image ) :
250+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
245251 return image .unsharp_mask (self .radius , self .percent , self .threshold )
246252
247253
@@ -387,8 +393,13 @@ class Color3DLUT(MultibandFilter):
387393 name = "Color 3D LUT"
388394
389395 def __init__ (
390- self , size , table , channels : int = 3 , target_mode : str | None = None , ** kwargs
391- ):
396+ self ,
397+ size : int | tuple [int , int , int ],
398+ table : Sequence [float ] | Sequence [Sequence [int ]] | NumpyArray ,
399+ channels : int = 3 ,
400+ target_mode : str | None = None ,
401+ ** kwargs : bool ,
402+ ) -> None :
392403 if channels not in (3 , 4 ):
393404 msg = "Only 3 or 4 output channels are supported"
394405 raise ValueError (msg )
@@ -410,15 +421,16 @@ def __init__(
410421 pass
411422
412423 if numpy and isinstance (table , numpy .ndarray ):
424+ numpy_table : NumpyArray = table
413425 if copy_table :
414- table = table .copy ()
426+ numpy_table = numpy_table .copy ()
415427
416- if table .shape in [
428+ if numpy_table .shape in [
417429 (items * channels ,),
418430 (items , channels ),
419431 (size [2 ], size [1 ], size [0 ], channels ),
420432 ]:
421- table = table .reshape (items * channels )
433+ table = numpy_table .reshape (items * channels )
422434 else :
423435 wrong_size = True
424436
@@ -428,15 +440,17 @@ def __init__(
428440
429441 # Convert to a flat list
430442 if table and isinstance (table [0 ], (list , tuple )):
431- table , raw_table = [], table
443+ raw_table = cast (Sequence [Sequence [int ]], table )
444+ flat_table : list [int ] = []
432445 for pixel in raw_table :
433446 if len (pixel ) != channels :
434447 msg = (
435448 "The elements of the table should "
436449 f"have a length of { channels } ."
437450 )
438451 raise ValueError (msg )
439- table .extend (pixel )
452+ flat_table .extend (pixel )
453+ table = flat_table
440454
441455 if wrong_size or len (table ) != items * channels :
442456 msg = (
@@ -449,23 +463,29 @@ def __init__(
449463 self .table = table
450464
451465 @staticmethod
452- def _check_size (size : Any ) -> list [ int ]:
466+ def _check_size (size : Any ) -> tuple [ int , int , int ]:
453467 try :
454468 _ , _ , _ = size
455469 except ValueError as e :
456470 msg = "Size should be either an integer or a tuple of three integers."
457471 raise ValueError (msg ) from e
458472 except TypeError :
459473 size = (size , size , size )
460- size = [ int (x ) for x in size ]
474+ size = tuple ( int (x ) for x in size )
461475 for size_1d in size :
462476 if not 2 <= size_1d <= 65 :
463477 msg = "Size should be in [2, 65] range."
464478 raise ValueError (msg )
465479 return size
466480
467481 @classmethod
468- def generate (cls , size , callback , channels = 3 , target_mode = None ):
482+ def generate (
483+ cls ,
484+ size : int | tuple [int , int , int ],
485+ callback : Callable [[float , float , float ], tuple [float , ...]],
486+ channels : int = 3 ,
487+ target_mode : str | None = None ,
488+ ) -> Color3DLUT :
469489 """Generates new LUT using provided callback.
470490
471491 :param size: Size of the table. Passed to the constructor.
@@ -482,7 +502,7 @@ def generate(cls, size, callback, channels=3, target_mode=None):
482502 msg = "Only 3 or 4 output channels are supported"
483503 raise ValueError (msg )
484504
485- table = [0 ] * (size_1d * size_2d * size_3d * channels )
505+ table : list [ float ] = [0 ] * (size_1d * size_2d * size_3d * channels )
486506 idx_out = 0
487507 for b in range (size_3d ):
488508 for g in range (size_2d ):
@@ -500,7 +520,13 @@ def generate(cls, size, callback, channels=3, target_mode=None):
500520 _copy_table = False ,
501521 )
502522
503- def transform (self , callback , with_normals = False , channels = None , target_mode = None ):
523+ def transform (
524+ self ,
525+ callback : Callable [..., tuple [float , ...]],
526+ with_normals : bool = False ,
527+ channels : int | None = None ,
528+ target_mode : str | None = None ,
529+ ) -> Color3DLUT :
504530 """Transforms the table values using provided callback and returns
505531 a new LUT with altered values.
506532
@@ -564,7 +590,7 @@ def __repr__(self) -> str:
564590 r .append (f"target_mode={ self .mode } " )
565591 return "<{}>" .format (" " .join (r ))
566592
567- def filter (self , image ) :
593+ def filter (self , image : _imaging . ImagingCore ) -> _imaging . ImagingCore :
568594 from . import Image
569595
570596 return image .color_lut_3d (
0 commit comments