Skip to content
225 changes: 203 additions & 22 deletions Modules/_json.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ typedef struct _PyEncoderObject {
char sort_keys;
char skipkeys;
int allow_nan;
PyCFunction fast_encode;
int (*fast_encode)(PyUnicodeWriter *, PyObject *);
} PyEncoderObject;

#define PyEncoderObject_CAST(op) ((PyEncoderObject *)(op))
Expand Down Expand Up @@ -102,8 +102,10 @@ static PyObject *
_encoded_const(PyObject *obj);
static void
raise_errmsg(const char *msg, PyObject *s, Py_ssize_t end);
static PyObject *
encoder_encode_string(PyEncoderObject *s, PyObject *obj);
static int
_steal_accumulate(PyUnicodeWriter *writer, PyObject *stolen);
static int
encoder_write_string(PyEncoderObject *s, PyUnicodeWriter *writer, PyObject *obj);
static PyObject *
encoder_encode_float(PyEncoderObject *s, PyObject *obj);

Expand Down Expand Up @@ -209,6 +211,80 @@ ascii_escape_unicode(PyObject *pystr)
return rval;
}

static int
write_escaped_ascii(PyUnicodeWriter *writer, PyObject *pystr)
{
/* Take a PyUnicode pystr and return a new ASCII-only escaped PyUnicode */
Py_ssize_t i;
Py_ssize_t input_chars;
Py_ssize_t output_size;
Py_ssize_t chars;
PyObject *rval;
const void *input;
Py_UCS1 *output;
int kind;

input_chars = PyUnicode_GET_LENGTH(pystr);
input = PyUnicode_DATA(pystr);
kind = PyUnicode_KIND(pystr);

/* Compute the output size */
for (i = 0, output_size = 2; i < input_chars; i++) {
Py_UCS4 c = PyUnicode_READ(kind, input, i);
Py_ssize_t d;
if (S_CHAR(c)) {
d = 1;
}
else {
switch(c) {
case '\\': case '"': case '\b': case '\f':
case '\n': case '\r': case '\t':
d = 2; break;
default:
d = c >= 0x10000 ? 12 : 6;
}
}
if (output_size > PY_SSIZE_T_MAX - d) {
PyErr_SetString(PyExc_OverflowError, "string is too long to escape");
return -1;
}
output_size += d;
}

if (output_size == input_chars + 2) {
/* No need to escape anything */
if (PyUnicodeWriter_WriteChar(writer, '"') < 0) {
return -1;
}
if (PyUnicodeWriter_WriteStr(writer, pystr) < 0) {
return -1;
}
return PyUnicodeWriter_WriteChar(writer, '"');
}

rval = PyUnicode_New(output_size, 127);
if (rval == NULL) {
return -1;
}
output = PyUnicode_1BYTE_DATA(rval);
chars = 0;
output[chars++] = '"';
for (i = 0; i < input_chars; i++) {
Py_UCS4 c = PyUnicode_READ(kind, input, i);
if (S_CHAR(c)) {
output[chars++] = c;
}
else {
chars = ascii_escape_unichar(c, output, chars);
}
}
output[chars++] = '"';
#ifdef Py_DEBUG
assert(_PyUnicode_CheckConsistency(rval, 1));
#endif
return _steal_accumulate(writer, rval);
}

static PyObject *
escape_unicode(PyObject *pystr)
{
Expand Down Expand Up @@ -303,6 +379,111 @@ escape_unicode(PyObject *pystr)
return rval;
}

static int
write_escaped_unicode(PyUnicodeWriter *writer, PyObject *pystr)
{
/* Take a PyUnicode pystr and return a new escaped PyUnicode */
Py_ssize_t i;
Py_ssize_t input_chars;
Py_ssize_t output_size;
Py_ssize_t chars;
PyObject *rval;
const void *input;
int kind;
Py_UCS4 maxchar;

maxchar = PyUnicode_MAX_CHAR_VALUE(pystr);
input_chars = PyUnicode_GET_LENGTH(pystr);
input = PyUnicode_DATA(pystr);
kind = PyUnicode_KIND(pystr);

/* Compute the output size */
for (i = 0, output_size = 2; i < input_chars; i++) {
Py_UCS4 c = PyUnicode_READ(kind, input, i);
Py_ssize_t d;
switch (c) {
case '\\': case '"': case '\b': case '\f':
case '\n': case '\r': case '\t':
d = 2;
break;
default:
if (c <= 0x1f)
d = 6;
else
d = 1;
}
if (output_size > PY_SSIZE_T_MAX - d) {
PyErr_SetString(PyExc_OverflowError, "string is too long to escape");
return -1;
}
output_size += d;
}

if (output_size == input_chars + 2) {
/* No need to escape anything */
if (PyUnicodeWriter_WriteChar(writer, '"') < 0) {
return -1;
}
if (PyUnicodeWriter_WriteStr(writer, pystr) < 0) {
return -1;
}
return PyUnicodeWriter_WriteChar(writer, '"');
}

rval = PyUnicode_New(output_size, maxchar);
if (rval == NULL)
return -1;

kind = PyUnicode_KIND(rval);

#define ENCODE_OUTPUT do { \
chars = 0; \
output[chars++] = '"'; \
for (i = 0; i < input_chars; i++) { \
Py_UCS4 c = PyUnicode_READ(kind, input, i); \
switch (c) { \
case '\\': output[chars++] = '\\'; output[chars++] = c; break; \
case '"': output[chars++] = '\\'; output[chars++] = c; break; \
case '\b': output[chars++] = '\\'; output[chars++] = 'b'; break; \
case '\f': output[chars++] = '\\'; output[chars++] = 'f'; break; \
case '\n': output[chars++] = '\\'; output[chars++] = 'n'; break; \
case '\r': output[chars++] = '\\'; output[chars++] = 'r'; break; \
case '\t': output[chars++] = '\\'; output[chars++] = 't'; break; \
default: \
if (c <= 0x1f) { \
output[chars++] = '\\'; \
output[chars++] = 'u'; \
output[chars++] = '0'; \
output[chars++] = '0'; \
output[chars++] = Py_hexdigits[(c >> 4) & 0xf]; \
output[chars++] = Py_hexdigits[(c ) & 0xf]; \
} else { \
output[chars++] = c; \
} \
} \
} \
output[chars++] = '"'; \
} while (0)

if (kind == PyUnicode_1BYTE_KIND) {
Py_UCS1 *output = PyUnicode_1BYTE_DATA(rval);
ENCODE_OUTPUT;
} else if (kind == PyUnicode_2BYTE_KIND) {
Py_UCS2 *output = PyUnicode_2BYTE_DATA(rval);
ENCODE_OUTPUT;
} else {
Py_UCS4 *output = PyUnicode_4BYTE_DATA(rval);
assert(kind == PyUnicode_4BYTE_KIND);
ENCODE_OUTPUT;
}
#undef ENCODE_OUTPUT

#ifdef Py_DEBUG
assert(_PyUnicode_CheckConsistency(rval, 1));
#endif
return _steal_accumulate(writer, rval);
}

static void
raise_errmsg(const char *msg, PyObject *s, Py_ssize_t end)
{
Expand Down Expand Up @@ -1255,8 +1436,11 @@ encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)

if (PyCFunction_Check(s->encoder)) {
PyCFunction f = PyCFunction_GetFunction(s->encoder);
if (f == py_encode_basestring_ascii || f == py_encode_basestring) {
s->fast_encode = f;
if (f == py_encode_basestring_ascii) {
s->fast_encode = write_escaped_ascii;
}
else if (f == py_encode_basestring) {
s->fast_encode = write_escaped_unicode;
}
}

Expand Down Expand Up @@ -1437,24 +1621,27 @@ encoder_encode_float(PyEncoderObject *s, PyObject *obj)
return PyFloat_Type.tp_repr(obj);
}

static PyObject *
encoder_encode_string(PyEncoderObject *s, PyObject *obj)
static int
encoder_write_string(PyEncoderObject *s, PyUnicodeWriter *writer, PyObject *obj)
{
/* Return the JSON representation of a string */
PyObject *encoded;

if (s->fast_encode) {
return s->fast_encode(NULL, obj);
return s->fast_encode(writer, obj);
}
encoded = PyObject_CallOneArg(s->encoder, obj);
if (encoded != NULL && !PyUnicode_Check(encoded)) {
if (encoded == NULL) {
return -1;
}
if (!PyUnicode_Check(encoded)) {
PyErr_Format(PyExc_TypeError,
"encoder() must return a string, not %.80s",
Py_TYPE(encoded)->tp_name);
Py_DECREF(encoded);
return NULL;
return -1;
}
return encoded;
return _steal_accumulate(writer, encoded);
}

static int
Expand Down Expand Up @@ -1485,10 +1672,7 @@ encoder_listencode_obj(PyEncoderObject *s, PyUnicodeWriter *writer,
return PyUnicodeWriter_WriteASCII(writer, "false", 5);
}
else if (PyUnicode_Check(obj)) {
PyObject *encoded = encoder_encode_string(s, obj);
if (encoded == NULL)
return -1;
return _steal_accumulate(writer, encoded);
return encoder_write_string(s, writer, obj);
}
else if (PyLong_Check(obj)) {
if (PyLong_CheckExact(obj)) {
Expand Down Expand Up @@ -1577,7 +1761,7 @@ encoder_encode_key_value(PyEncoderObject *s, PyUnicodeWriter *writer, bool *firs
PyObject *item_separator)
{
PyObject *keystr = NULL;
PyObject *encoded;
int rv;

if (PyUnicode_Check(key)) {
keystr = Py_NewRef(key);
Expand Down Expand Up @@ -1617,14 +1801,11 @@ encoder_encode_key_value(PyEncoderObject *s, PyUnicodeWriter *writer, bool *firs
}
}

encoded = encoder_encode_string(s, keystr);
rv = encoder_write_string(s, writer, keystr);
Py_DECREF(keystr);
if (encoded == NULL) {
return -1;
}

if (_steal_accumulate(writer, encoded) < 0) {
return -1;
if (rv < 0) {
return rv;
}
if (PyUnicodeWriter_WriteStr(writer, s->key_separator) < 0) {
return -1;
Expand Down
Loading