11from __future__ import annotations
22
3- from typing import Any # undone
3+ from typing import Any , NamedTuple
44
55import pytest
66
1010 assert_deep_equal ,
1111 assert_image_equal ,
1212 hopper ,
13+ is_big_endian ,
1314)
1415
15- pyarrow = pytest .importorskip ("pyarrow" , reason = "PyArrow not installed" )
16+ TYPE_CHECKING = False
17+ if TYPE_CHECKING :
18+ import pyarrow
19+ else :
20+ pyarrow = pytest .importorskip ("pyarrow" , reason = "PyArrow not installed" )
1621
1722TEST_IMAGE_SIZE = (10 , 10 )
1823
1924
2025def _test_img_equals_pyarray (
21- img : Image .Image , arr : Any , mask : list [int ] | None
26+ img : Image .Image , arr : Any , mask : list [int ] | None , elts_per_pixel : int = 1
2227) -> None :
23- assert img .height * img .width == len (arr )
28+ assert img .height * img .width * elts_per_pixel == len (arr )
2429 px = img .load ()
2530 assert px is not None
31+ if elts_per_pixel > 1 and mask is None :
32+ # have to do element-wise comparison when we're comparing
33+ # flattened r,g,b,a to a pixel.
34+ mask = list (range (elts_per_pixel ))
2635 for x in range (0 , img .size [0 ], int (img .size [0 ] / 10 )):
2736 for y in range (0 , img .size [1 ], int (img .size [1 ] / 10 )):
2837 if mask :
38+ pixel = px [x , y ]
39+ assert isinstance (pixel , tuple )
2940 for ix , elt in enumerate (mask ):
30- pixel = px [x , y ]
31- assert isinstance (pixel , tuple )
32- assert pixel [ix ] == arr [y * img .width + x ].as_py ()[elt ]
41+ if elts_per_pixel == 1 :
42+ assert pixel [ix ] == arr [y * img .width + x ].as_py ()[elt ]
43+ else :
44+ assert (
45+ pixel [ix ]
46+ == arr [(y * img .width + x ) * elts_per_pixel + elt ].as_py ()
47+ )
3348 else :
3449 assert_deep_equal (px [x , y ], arr [y * img .width + x ].as_py ())
3550
3651
52+ def _test_img_equals_int32_pyarray (
53+ img : Image .Image , arr : Any , mask : list [int ] | None , elts_per_pixel : int = 1
54+ ) -> None :
55+ assert img .height * img .width * elts_per_pixel == len (arr )
56+ px = img .load ()
57+ assert px is not None
58+ if mask is None :
59+ # have to do element-wise comparison when we're comparing
60+ # flattened rgba in an uint32 to a pixel.
61+ mask = list (range (elts_per_pixel ))
62+ for x in range (0 , img .size [0 ], int (img .size [0 ] / 10 )):
63+ for y in range (0 , img .size [1 ], int (img .size [1 ] / 10 )):
64+ pixel = px [x , y ]
65+ assert isinstance (pixel , tuple )
66+ arr_pixel_int = arr [y * img .width + x ].as_py ()
67+ arr_pixel_tuple = (
68+ arr_pixel_int % 256 ,
69+ (arr_pixel_int // 256 ) % 256 ,
70+ (arr_pixel_int // 256 ** 2 ) % 256 ,
71+ (arr_pixel_int // 256 ** 3 ),
72+ )
73+ if is_big_endian ():
74+ arr_pixel_tuple = arr_pixel_tuple [::- 1 ]
75+
76+ for ix , elt in enumerate (mask ):
77+ assert pixel [ix ] == arr_pixel_tuple [elt ]
78+
79+
3780# really hard to get a non-nullable list type
3881fl_uint8_4_type = pyarrow .field (
3982 "_" , pyarrow .list_ (pyarrow .field ("_" , pyarrow .uint8 ()).with_nullable (False ), 4 )
@@ -55,14 +98,14 @@ def _test_img_equals_pyarray(
5598 ("HSV" , fl_uint8_4_type , [0 , 1 , 2 ]),
5699 ),
57100)
58- def test_to_array (mode : str , dtype : Any , mask : list [int ] | None ) -> None :
101+ def test_to_array (mode : str , dtype : pyarrow . DataType , mask : list [int ] | None ) -> None :
59102 img = hopper (mode )
60103
61104 # Resize to non-square
62105 img = img .crop ((3 , 0 , 124 , 127 ))
63106 assert img .size == (121 , 127 )
64107
65- arr = pyarrow .array (img )
108+ arr = pyarrow .array (img ) # type: ignore[call-overload]
66109 _test_img_equals_pyarray (img , arr , mask )
67110 assert arr .type == dtype
68111
@@ -79,8 +122,8 @@ def test_lifetime() -> None:
79122
80123 img = hopper ("L" )
81124
82- arr_1 = pyarrow .array (img )
83- arr_2 = pyarrow .array (img )
125+ arr_1 = pyarrow .array (img ) # type: ignore[call-overload]
126+ arr_2 = pyarrow .array (img ) # type: ignore[call-overload]
84127
85128 del img
86129
@@ -97,8 +140,8 @@ def test_lifetime2() -> None:
97140
98141 img = hopper ("L" )
99142
100- arr_1 = pyarrow .array (img )
101- arr_2 = pyarrow .array (img )
143+ arr_1 = pyarrow .array (img ) # type: ignore[call-overload]
144+ arr_2 = pyarrow .array (img ) # type: ignore[call-overload]
102145
103146 assert arr_1 .sum ().as_py () > 0
104147 del arr_1
@@ -110,3 +153,94 @@ def test_lifetime2() -> None:
110153 px = img2 .load ()
111154 assert px # make mypy happy
112155 assert isinstance (px [0 , 0 ], int )
156+
157+
158+ class DataShape (NamedTuple ):
159+ dtype : pyarrow .DataType
160+ # Strictly speaking, elt should be a pixel or pixel component, so
161+ # list[uint8][4], float, int, uint32, uint8, etc. But more
162+ # correctly, it should be exactly the dtype from the line above.
163+ elt : Any
164+ elts_per_pixel : int
165+
166+
167+ UINT_ARR = DataShape (
168+ dtype = fl_uint8_4_type ,
169+ elt = [1 , 2 , 3 , 4 ], # array of 4 uint8 per pixel
170+ elts_per_pixel = 1 , # only one array per pixel
171+ )
172+
173+ UINT = DataShape (
174+ dtype = pyarrow .uint8 (),
175+ elt = 3 , # one uint8,
176+ elts_per_pixel = 4 , # but repeated 4x per pixel
177+ )
178+
179+ UINT32 = DataShape (
180+ dtype = pyarrow .uint32 (),
181+ elt = 0xABCDEF45 , # one packed int, doesn't fit in a int32 > 0x80000000
182+ elts_per_pixel = 1 , # one per pixel
183+ )
184+
185+ INT32 = DataShape (
186+ dtype = pyarrow .uint32 (),
187+ elt = 0x12CDEF45 , # one packed int
188+ elts_per_pixel = 1 , # one per pixel
189+ )
190+
191+
192+ @pytest .mark .parametrize (
193+ "mode, data_tp, mask" ,
194+ (
195+ ("L" , DataShape (pyarrow .uint8 (), 3 , 1 ), None ),
196+ ("I" , DataShape (pyarrow .int32 (), 1 << 24 , 1 ), None ),
197+ ("F" , DataShape (pyarrow .float32 (), 3.14159 , 1 ), None ),
198+ ("LA" , UINT_ARR , [0 , 3 ]),
199+ ("LA" , UINT , [0 , 3 ]),
200+ ("RGB" , UINT_ARR , [0 , 1 , 2 ]),
201+ ("RGBA" , UINT_ARR , None ),
202+ ("CMYK" , UINT_ARR , None ),
203+ ("YCbCr" , UINT_ARR , [0 , 1 , 2 ]),
204+ ("HSV" , UINT_ARR , [0 , 1 , 2 ]),
205+ ("RGB" , UINT , [0 , 1 , 2 ]),
206+ ("RGBA" , UINT , None ),
207+ ("CMYK" , UINT , None ),
208+ ("YCbCr" , UINT , [0 , 1 , 2 ]),
209+ ("HSV" , UINT , [0 , 1 , 2 ]),
210+ ),
211+ )
212+ def test_fromarray (mode : str , data_tp : DataShape , mask : list [int ] | None ) -> None :
213+ (dtype , elt , elts_per_pixel ) = data_tp
214+
215+ ct_pixels = TEST_IMAGE_SIZE [0 ] * TEST_IMAGE_SIZE [1 ]
216+ arr = pyarrow .array ([elt ] * (ct_pixels * elts_per_pixel ), type = dtype )
217+ img = Image .fromarrow (arr , mode , TEST_IMAGE_SIZE )
218+
219+ _test_img_equals_pyarray (img , arr , mask , elts_per_pixel )
220+
221+
222+ @pytest .mark .parametrize (
223+ "mode, data_tp, mask" ,
224+ (
225+ ("LA" , UINT32 , [0 , 3 ]),
226+ ("RGB" , UINT32 , [0 , 1 , 2 ]),
227+ ("RGBA" , UINT32 , None ),
228+ ("CMYK" , UINT32 , None ),
229+ ("YCbCr" , UINT32 , [0 , 1 , 2 ]),
230+ ("HSV" , UINT32 , [0 , 1 , 2 ]),
231+ ("LA" , INT32 , [0 , 3 ]),
232+ ("RGB" , INT32 , [0 , 1 , 2 ]),
233+ ("RGBA" , INT32 , None ),
234+ ("CMYK" , INT32 , None ),
235+ ("YCbCr" , INT32 , [0 , 1 , 2 ]),
236+ ("HSV" , INT32 , [0 , 1 , 2 ]),
237+ ),
238+ )
239+ def test_from_int32array (mode : str , data_tp : DataShape , mask : list [int ] | None ) -> None :
240+ (dtype , elt , elts_per_pixel ) = data_tp
241+
242+ ct_pixels = TEST_IMAGE_SIZE [0 ] * TEST_IMAGE_SIZE [1 ]
243+ arr = pyarrow .array ([elt ] * (ct_pixels * elts_per_pixel ), type = dtype )
244+ img = Image .fromarrow (arr , mode , TEST_IMAGE_SIZE )
245+
246+ _test_img_equals_int32_pyarray (img , arr , mask , elts_per_pixel )
0 commit comments