Skip to content

Commit 5e3dc40

Browse files
committed
Do not count frames on image open
1 parent 762235c commit 5e3dc40

File tree

3 files changed

+42
-60
lines changed

3 files changed

+42
-60
lines changed

Tests/test_file_jxl_metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ class JpegXlDecoder:
9797
def __init__(self, b: bytes) -> None:
9898
pass
9999

100-
def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int, int]:
101-
return ((1, 1), "L", 0, 0, 0, 0, 0)
100+
def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int]:
101+
return ((1, 1), "L", 0, 0, 0, 0)
102102

103103
def get_icc(self) -> None:
104104
pass

src/PIL/JpegXlImagePlugin.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,6 @@
1313
SUPPORTED = False
1414

1515

16-
## Future idea:
17-
## it's not known how many frames an animated image has
18-
## by default, _jxl_decoder_new will iterate over all frames without decoding them
19-
## then libjxl decoder is rewinded and we're ready to decode frame by frame
20-
## if OPEN_COUNTS_FRAMES is False, n_frames will be None until the last frame is decoded
21-
## it only applies to animated jpeg xl images
22-
# OPEN_COUNTS_FRAMES = True
23-
24-
2516
def _accept(prefix: bytes) -> bool | str:
2617
is_jxl = prefix.startswith(
2718
(b"\xff\x0a", b"\x00\x00\x00\x0c\x4a\x58\x4c\x20\x0d\x0a\x87\x0a")
@@ -34,8 +25,9 @@ def _accept(prefix: bytes) -> bool | str:
3425
class JpegXlImageFile(ImageFile.ImageFile):
3526
format = "JPEG XL"
3627
format_description = "JPEG XL image"
37-
__loaded = 0
28+
__loaded = -1
3829
__logical_frame = 0
30+
__physical_frame = 0
3931

4032
def _open(self) -> None:
4133
self._decoder = _jpegxl.JpegXlDecoder(self.fp.read())
@@ -47,16 +39,10 @@ def _open(self) -> None:
4739
tps_num,
4840
tps_denom,
4941
self.info["loop"],
50-
n_frames,
5142
) = self._decoder.get_info()
5243

53-
self._tps_dur_secs = 1
54-
self.n_frames: int | None = 1
55-
if self.is_animated:
56-
self.n_frames = None
57-
if n_frames > 0:
58-
self.n_frames = n_frames
59-
self._tps_dur_secs = tps_num / tps_denom
44+
self._n_frames = None if self.is_animated else 1
45+
self._tps_dur_secs = tps_num / tps_denom if tps_denom != 0 else 1
6046

6147
# TODO: handle libjxl time codes
6248
self.__timestamp = 0
@@ -72,7 +58,14 @@ def _open(self) -> None:
7258
if xmp := self._decoder.get_xmp():
7359
self.info["xmp"] = xmp
7460

75-
self._rewind()
61+
@property
62+
def n_frames(self) -> int:
63+
if self._n_frames is None:
64+
current = self.tell()
65+
self._n_frames = current + self._decoder.get_frames_left()
66+
self.seek(current)
67+
68+
return self._n_frames
7669

7770
def _get_next(self) -> tuple[bytes, float, float]:
7871
# Get next frame
@@ -85,9 +78,9 @@ def _get_next(self) -> tuple[bytes, float, float]:
8578
raise EOFError(msg)
8679

8780
data, tps_duration, is_last = next_frame
88-
if is_last and self.n_frames is None:
81+
if is_last and self._n_frames is None:
8982
# libjxl said this frame is the last one
90-
self.n_frames = self.__physical_frame
83+
self._n_frames = self.__physical_frame
9184

9285
# duration in milliseconds
9386
duration = 1000 * tps_duration * (1 / self._tps_dur_secs)
@@ -96,24 +89,22 @@ def _get_next(self) -> tuple[bytes, float, float]:
9689

9790
return data, timestamp, duration
9891

99-
def _rewind(self, hard: bool = False) -> None:
100-
if hard:
101-
self._decoder.rewind()
102-
self.__physical_frame = 0
103-
self.__loaded = -1
104-
self.__timestamp = 0
105-
10692
def _seek(self, frame: int) -> None:
10793
if frame == self.__physical_frame:
10894
return # Nothing to do
10995
if frame < self.__physical_frame:
11096
# also rewind libjxl decoder instance
111-
self._rewind(hard=True)
97+
self._decoder.rewind()
98+
self.__physical_frame = 0
99+
self.__loaded = -1
100+
self.__timestamp = 0
112101

113102
while self.__physical_frame < frame:
114103
self._get_next() # Advance to the requested frame
115104

116105
def seek(self, frame: int) -> None:
106+
if self._n_frames is None:
107+
self.n_frames
117108
if not self._seek_check(frame):
118109
return
119110

src/_jpegxl.c

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,6 @@ typedef struct {
9898
JxlBasicInfo basic_info;
9999
JxlPixelFormat pixel_format;
100100

101-
Py_ssize_t n_frames;
102-
103101
char *mode;
104102
} JpegXlDecoderObject;
105103

@@ -166,27 +164,26 @@ _jxl_decoder_rewind(PyObject *self) {
166164
Py_RETURN_NONE;
167165
}
168166

169-
bool
170-
_jxl_decoder_count_frames(PyObject *self) {
171-
JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self;
172-
173-
decp->n_frames = 0;
167+
PyObject *
168+
_jxl_decoder_get_frames_left(PyObject *self) {
169+
int frames_left = 0;
174170

175171
// count all JXL_DEC_NEED_IMAGE_OUT_BUFFER events
172+
JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self;
176173
while (decp->status != JXL_DEC_SUCCESS) {
177174
decp->status = JxlDecoderProcessInput(decp->decoder);
178175

179176
if (decp->status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) {
180177
if (JxlDecoderSkipCurrentFrame(decp->decoder) != JXL_DEC_SUCCESS) {
181-
return false;
178+
PyErr_SetString(PyExc_OSError, "Error when counting frames");
179+
break;
182180
}
183-
decp->n_frames++;
181+
frames_left++;
184182
}
185183
}
184+
JxlDecoderRewind(decp->decoder);
186185

187-
_jxl_decoder_rewind((PyObject *)decp);
188-
189-
return true;
186+
return Py_BuildValue("i", frames_left);
190187
}
191188

192189
PyObject *
@@ -206,7 +203,6 @@ _jxl_decoder_new(PyObject *self, PyObject *args) {
206203
decp->jxl_exif_len = 0;
207204
decp->jxl_xmp = NULL;
208205
decp->jxl_xmp_len = 0;
209-
decp->n_frames = 0;
210206

211207
// used for printing more detailed error messages
212208
char *jxl_call_name;
@@ -371,14 +367,6 @@ _jxl_decoder_new(PyObject *self, PyObject *args) {
371367
goto end_with_custom_error;
372368
}
373369

374-
if (decp->basic_info.have_animation) {
375-
// get frame count by iterating over image out events
376-
if (!_jxl_decoder_count_frames((PyObject *)decp)) {
377-
PyErr_SetString(PyExc_OSError, "something went wrong when counting frames");
378-
goto end_with_custom_error;
379-
}
380-
}
381-
382370
return (PyObject *)decp;
383371

384372
// on success we should never reach here
@@ -410,15 +398,14 @@ _jxl_decoder_get_info(PyObject *self) {
410398
JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self;
411399

412400
return Py_BuildValue(
413-
"(II)sOIIII",
401+
"(II)sOIII",
414402
decp->basic_info.xsize,
415403
decp->basic_info.ysize,
416404
decp->mode,
417405
decp->basic_info.have_animation ? Py_True : Py_False,
418406
decp->basic_info.animation.tps_numerator,
419407
decp->basic_info.animation.tps_denominator,
420-
decp->basic_info.animation.num_loops,
421-
decp->n_frames
408+
decp->basic_info.animation.num_loops
422409
);
423410
}
424411

@@ -432,6 +419,10 @@ _jxl_decoder_get_next(PyObject *self) {
432419
char *jxl_call_name;
433420

434421
// process events until next frame output is ready
422+
if (decp->status == JXL_DEC_FRAME) {
423+
decp->status = JxlDecoderGetFrameHeader(decp->decoder, &fhdr);
424+
_JXL_CHECK("JxlDecoderGetFrameHeader");
425+
}
435426
while (decp->status != JXL_DEC_NEED_IMAGE_OUT_BUFFER) {
436427
decp->status = JxlDecoderProcessInput(decp->decoder);
437428

@@ -444,14 +435,10 @@ _jxl_decoder_get_next(PyObject *self) {
444435
if (decp->status == JXL_DEC_NEED_MORE_INPUT) {
445436
_jxl_decoder_set_input((PyObject *)decp);
446437
_JXL_CHECK("JxlDecoderSetInput")
447-
continue;
448-
}
449-
450-
if (decp->status == JXL_DEC_FRAME) {
438+
} else if (decp->status == JXL_DEC_FRAME) {
451439
// decode frame header
452440
decp->status = JxlDecoderGetFrameHeader(decp->decoder, &fhdr);
453441
_JXL_CHECK("JxlDecoderGetFrameHeader");
454-
continue;
455442
}
456443
}
457444

@@ -573,6 +560,10 @@ static struct PyMethodDef _jpegxl_decoder_methods[] = {
573560
{"get_icc", (PyCFunction)_jxl_decoder_get_icc, METH_NOARGS, "get_icc"},
574561
{"get_exif", (PyCFunction)_jxl_decoder_get_exif, METH_NOARGS, "get_exif"},
575562
{"get_xmp", (PyCFunction)_jxl_decoder_get_xmp, METH_NOARGS, "get_xmp"},
563+
{"get_frames_left",
564+
(PyCFunction)_jxl_decoder_get_frames_left,
565+
METH_NOARGS,
566+
"get_frames_left"},
576567
{"rewind", (PyCFunction)_jxl_decoder_rewind, METH_NOARGS, "rewind"},
577568
{NULL, NULL} /* sentinel */
578569
};

0 commit comments

Comments
 (0)