|
| 1 | +// SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import React, { type ReactNode } from 'react' |
| 5 | +import { |
| 6 | + Svg, |
| 7 | + G, |
| 8 | + Path, |
| 9 | + Rect, |
| 10 | + Circle, |
| 11 | + Ellipse, |
| 12 | + Line, |
| 13 | + Polygon, |
| 14 | + Polyline, |
| 15 | + Text as PdfText, |
| 16 | + View, |
| 17 | +} from '@react-pdf/renderer' |
| 18 | +import { parse, type SvgElementNode, type SvgNode, type SvgTextNode } from 'svg-parser' |
| 19 | +import { mathjax } from 'mathjax-full/js/mathjax.js' |
| 20 | +import { TeX } from 'mathjax-full/js/input/tex.js' |
| 21 | +import { AllPackages } from 'mathjax-full/js/input/tex/AllPackages.js' |
| 22 | +import { SVG } from 'mathjax-full/js/output/svg.js' |
| 23 | +import { liteAdaptor } from 'mathjax-full/js/adaptors/liteAdaptor.js' |
| 24 | +import { RegisterHTMLHandler } from 'mathjax-full/js/handlers/html.js' |
| 25 | + |
| 26 | +const EX_TO_PT = 4.3 |
| 27 | +const PX_TO_PT = 0.75 |
| 28 | +const renderCache = new Map<string, string>() |
| 29 | + |
| 30 | +const adaptor = liteAdaptor() |
| 31 | +RegisterHTMLHandler(adaptor) |
| 32 | + |
| 33 | +const tex = new TeX({ packages: AllPackages }) |
| 34 | +const svg = new SVG({ fontCache: 'none' }) |
| 35 | +const html = mathjax.document('', { |
| 36 | + InputJax: tex, |
| 37 | + OutputJax: svg, |
| 38 | +}) |
| 39 | + |
| 40 | +interface MathSvgProps { |
| 41 | + latex: string |
| 42 | + display?: boolean |
| 43 | + fontSize?: number |
| 44 | +} |
| 45 | + |
| 46 | +interface PdfSvgAttributes { |
| 47 | + [key: string]: string | number | undefined |
| 48 | +} |
| 49 | + |
| 50 | +function renderMathToSvgMarkup(latex: string, display: boolean): string { |
| 51 | + const cacheKey = `${display ? 'display' : 'inline'}:${latex}` |
| 52 | + const cached = renderCache.get(cacheKey) |
| 53 | + |
| 54 | + if (cached) { |
| 55 | + return cached |
| 56 | + } |
| 57 | + |
| 58 | + const node = html.convert(latex, { display }) |
| 59 | + const markup = adaptor.outerHTML(node) |
| 60 | + renderCache.set(cacheKey, markup) |
| 61 | + |
| 62 | + return markup |
| 63 | +} |
| 64 | + |
| 65 | +function extractSvgNode(markup: string): SvgElementNode | null { |
| 66 | + const root = parse(markup) |
| 67 | + const stack: SvgNode[] = [...root.children] |
| 68 | + |
| 69 | + while (stack.length > 0) { |
| 70 | + const current = stack.shift() |
| 71 | + |
| 72 | + if (!current || current.type !== 'element') { |
| 73 | + continue |
| 74 | + } |
| 75 | + |
| 76 | + if (current.tagName === 'svg') { |
| 77 | + return current |
| 78 | + } |
| 79 | + |
| 80 | + stack.unshift(...current.children) |
| 81 | + } |
| 82 | + |
| 83 | + return null |
| 84 | +} |
| 85 | + |
| 86 | +function toPoints(value: string | undefined, fontSize: number): number | undefined { |
| 87 | + if (!value) return undefined |
| 88 | + |
| 89 | + const trimmed = value.trim() |
| 90 | + if (!trimmed) return undefined |
| 91 | + |
| 92 | + const numericValue = Number.parseFloat(trimmed) |
| 93 | + if (Number.isNaN(numericValue)) return undefined |
| 94 | + |
| 95 | + if (trimmed.endsWith('ex')) return numericValue * EX_TO_PT |
| 96 | + if (trimmed.endsWith('em')) return numericValue * fontSize |
| 97 | + if (trimmed.endsWith('px')) return numericValue * PX_TO_PT |
| 98 | + |
| 99 | + return numericValue |
| 100 | +} |
| 101 | + |
| 102 | +function toCamelCase(value: string): string { |
| 103 | + return value.replace(/-([a-z])/g, (_, char: string) => char.toUpperCase()) |
| 104 | +} |
| 105 | + |
| 106 | +function parseStyleAttribute(style: string): Record<string, string> { |
| 107 | + return style |
| 108 | + .split(';') |
| 109 | + .map((entry) => entry.trim()) |
| 110 | + .filter(Boolean) |
| 111 | + .reduce<Record<string, string>>((acc, entry) => { |
| 112 | + const [name, ...rest] = entry.split(':') |
| 113 | + if (!name || rest.length === 0) return acc |
| 114 | + |
| 115 | + acc[toCamelCase(name.trim())] = rest.join(':').trim() |
| 116 | + return acc |
| 117 | + }, {}) |
| 118 | +} |
| 119 | + |
| 120 | +function getPdfSvgAttributes(node: SvgElementNode): PdfSvgAttributes { |
| 121 | + const attributes: PdfSvgAttributes = {} |
| 122 | + const entries = Object.entries(node.properties) |
| 123 | + |
| 124 | + for (const [rawName, rawValue] of entries) { |
| 125 | + if (rawValue == null) continue |
| 126 | + if ( |
| 127 | + rawName === 'xmlns' || |
| 128 | + rawName === 'role' || |
| 129 | + rawName === 'focusable' || |
| 130 | + rawName === 'class' || |
| 131 | + rawName.startsWith('data-') |
| 132 | + ) { |
| 133 | + continue |
| 134 | + } |
| 135 | + |
| 136 | + if (rawName === 'style') { |
| 137 | + Object.assign(attributes, parseStyleAttribute(rawValue)) |
| 138 | + continue |
| 139 | + } |
| 140 | + |
| 141 | + const name = rawName === 'stroke-width' ? 'strokeWidth' : toCamelCase(rawName) |
| 142 | + attributes[name] = rawValue |
| 143 | + } |
| 144 | + |
| 145 | + if (attributes.fill === 'currentColor') { |
| 146 | + attributes.fill = '#111111' |
| 147 | + } |
| 148 | + |
| 149 | + if (attributes.stroke === 'currentColor') { |
| 150 | + attributes.stroke = '#111111' |
| 151 | + } |
| 152 | + |
| 153 | + return attributes |
| 154 | +} |
| 155 | + |
| 156 | +function renderSvgChild(node: SvgElementNode | SvgTextNode, key: string): ReactNode { |
| 157 | + if (node.type === 'text') { |
| 158 | + const text = node.value.trim() |
| 159 | + return text ? <PdfText key={key}>{text}</PdfText> : null |
| 160 | + } |
| 161 | + |
| 162 | + const attributes = getPdfSvgAttributes(node) |
| 163 | + const children = node.children.map((child, index) => renderSvgChild(child, `${key}-${index}`)) |
| 164 | + |
| 165 | + switch (node.tagName) { |
| 166 | + case 'g': |
| 167 | + return ( |
| 168 | + <G key={key} {...attributes}> |
| 169 | + {children} |
| 170 | + </G> |
| 171 | + ) |
| 172 | + case 'path': |
| 173 | + return <Path key={key} d={String(attributes.d ?? '')} {...attributes} /> |
| 174 | + case 'rect': |
| 175 | + return ( |
| 176 | + <Rect |
| 177 | + key={key} |
| 178 | + width={String(attributes.width ?? 0)} |
| 179 | + height={String(attributes.height ?? 0)} |
| 180 | + x={attributes.x} |
| 181 | + y={attributes.y} |
| 182 | + rx={attributes.rx} |
| 183 | + ry={attributes.ry} |
| 184 | + {...attributes} |
| 185 | + /> |
| 186 | + ) |
| 187 | + case 'circle': |
| 188 | + return <Circle key={key} r={String(attributes.r ?? 0)} cx={attributes.cx} cy={attributes.cy} {...attributes} /> |
| 189 | + case 'ellipse': |
| 190 | + return ( |
| 191 | + <Ellipse |
| 192 | + key={key} |
| 193 | + rx={String(attributes.rx ?? 0)} |
| 194 | + ry={String(attributes.ry ?? 0)} |
| 195 | + cx={attributes.cx} |
| 196 | + cy={attributes.cy} |
| 197 | + {...attributes} |
| 198 | + /> |
| 199 | + ) |
| 200 | + case 'line': |
| 201 | + return ( |
| 202 | + <Line |
| 203 | + key={key} |
| 204 | + x1={String(attributes.x1 ?? 0)} |
| 205 | + y1={String(attributes.y1 ?? 0)} |
| 206 | + x2={String(attributes.x2 ?? 0)} |
| 207 | + y2={String(attributes.y2 ?? 0)} |
| 208 | + {...attributes} |
| 209 | + /> |
| 210 | + ) |
| 211 | + case 'polygon': |
| 212 | + return <Polygon key={key} points={String(attributes.points ?? '')} {...attributes} /> |
| 213 | + case 'polyline': |
| 214 | + return <Polyline key={key} points={String(attributes.points ?? '')} {...attributes} /> |
| 215 | + default: |
| 216 | + return null |
| 217 | + } |
| 218 | +} |
| 219 | + |
| 220 | +export const MathSvg: React.FC<MathSvgProps> = ({ latex, display = false, fontSize = 10 }) => { |
| 221 | + try { |
| 222 | + const svgMarkup = renderMathToSvgMarkup(latex, display) |
| 223 | + const svgNode = extractSvgNode(svgMarkup) |
| 224 | + |
| 225 | + if (!svgNode) { |
| 226 | + return <PdfText>{latex}</PdfText> |
| 227 | + } |
| 228 | + |
| 229 | + const width = toPoints(svgNode.properties.width, fontSize) |
| 230 | + const height = toPoints(svgNode.properties.height, fontSize) |
| 231 | + const inlineStyle = getPdfSvgAttributes(svgNode).verticalAlign |
| 232 | + const verticalAlign = typeof inlineStyle === 'string' ? toPoints(inlineStyle, fontSize) ?? 0 : 0 |
| 233 | + const children = svgNode.children.map((child, index) => renderSvgChild(child, `svg-${index}`)) |
| 234 | + |
| 235 | + return ( |
| 236 | + <View |
| 237 | + style={ |
| 238 | + display |
| 239 | + ? { alignItems: 'center', marginTop: 8, marginBottom: 10 } |
| 240 | + : { marginLeft: 1, marginRight: 1, marginBottom: verticalAlign } |
| 241 | + } |
| 242 | + > |
| 243 | + <Svg |
| 244 | + width={width ?? fontSize * 2} |
| 245 | + height={height ?? fontSize} |
| 246 | + viewBox={svgNode.properties.viewBox} |
| 247 | + preserveAspectRatio={svgNode.properties.preserveAspectRatio} |
| 248 | + > |
| 249 | + {children} |
| 250 | + </Svg> |
| 251 | + </View> |
| 252 | + ) |
| 253 | + } catch { |
| 254 | + return <PdfText>{latex}</PdfText> |
| 255 | + } |
| 256 | +} |
0 commit comments