208 changes: 136 additions & 72 deletions yourbench/pipeline/chunking.py
@@ -5,16 +5,10 @@
@module chunking
@author @sumukshashidhar

This module implements the "Semantic Chunking" stage of the YourBench pipeline.
It takes ingested and optionally summarized documents, partitions them into
multiple coherent segments (single-hop chunks), and optionally creates multi-hop
chunks by sampling and concatenating various single-hop segments.

Preserves semantic relationships among sentences by leveraging embeddings from a
transformer-based model (e.g., "intfloat/multilingual-e5-large-instruct"). This
stage helps downstream question generation avoid handling entire long documents
at once, improving coverage and reducing the risk of overlooking important but
less prominent content.
This module implements two modes of chunking for the YourBench pipeline:
1) "fast_chunking" (the new default), which splits text using purely length-based rules.
2) "semantic_chunking" (requires explicit config), which uses sentence embeddings
and a similarity threshold to decide chunk boundaries.

Usage:
------
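
The usage block is collapsed in this view. As a hedged sketch (the import path is inferred from the file location yourbench/pipeline/chunking.py and is not confirmed by this diff), invoking the stage looks roughly like:

```python
# Hedged sketch -- import path inferred from the file location, not confirmed here.
from yourbench.pipeline import chunking

config = {}  # pipeline configuration dict; see the config sketch after the step list below
chunking.run(config)
```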
@@ -24,18 +18,19 @@

The run(config) function:
1. Loads a dataset specified by the pipeline configuration.
2. Splits each document into single-hop chunks, guided by user-defined token
length constraints (l_min_tokens, l_max_tokens) and a similarity threshold
(tau_threshold).
3. Creates multi-hop chunks by sampling subsets of single-hop chunks and
concatenating them.
4. Computes optional readability and perplexity metrics for each chunk if debug
mode is enabled.
2. Depending on the configured chunking mode:
- fast_chunking (default): Chunks text solely based on maximum token length,
ignoring sentence similarity.
- semantic_chunking (requires pipeline.chunking.chunking_configuration.chunking_mode="semantic_chunking"):
Splits each document into single-hop chunks, guided by user-defined token
length constraints (l_min_tokens, l_max_tokens) and a similarity threshold (tau_threshold).
3. Creates multi-hop chunks by sampling subsets of single-hop chunks and concatenating them.
4. Computes optional readability and perplexity metrics for each chunk if debug mode is enabled.
5. Saves the dataset containing new columns:
- "chunks" (list of single-hop segments)
- "multihop_chunks" (list of multi-hop segment groups)
- "chunk_info_metrics" (various statistics)
- "chunking_model" (the model used for embeddings).
- "chunking_model" (the model used for embeddings; blank or default if fast_chunking).

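For orientation, here is a minimal sketch of the configuration shape this stage reads. The key paths match the code in this diff; the token-length and threshold values are illustrative, not prescribed defaults:

```python
# Illustrative config sketch -- key paths match the code in this diff,
# values marked "illustrative" are examples only.
config = {
    "settings": {"debug": False},
    "model_roles": {
        # Only consulted in semantic_chunking mode.
        "chunking": ["intfloat/multilingual-e5-large-instruct"],
    },
    "pipeline": {
        "chunking": {
            "chunking_configuration": {
                "chunking_mode": "fast_chunking",  # or "semantic_chunking"
                "l_min_tokens": 128,   # illustrative
                "l_max_tokens": 512,   # illustrative
                "tau_threshold": 0.8,  # illustrative; semantic mode only
                "h_min": 2,
                "h_max": 3,
                "num_multihops_factor": 2,
            },
        },
    },
}
```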
Error Handling and Logging:
---------------------------
@@ -97,6 +92,7 @@ class ChunkingParameters:
h_min: int = 2
h_max: int = 3
num_multihops_factor: int = 2
chunking_mode: str = "fast_chunking" # "fast_chunking" or "semantic_chunking"


@dataclass
@@ -125,7 +121,8 @@ class ChunkInfoMetrics:
def _parse_chunking_parameters(config: Dict[str, Any]) -> ChunkingParameters:
"""
Extracts the chunking parameters from the config dictionary, falling back
to default values if keys are missing.
to default values if keys are missing. The chunking_mode defaults to
"fast_chunking" unless explicitly set to "semantic_chunking."
"""
chunking_params = config.get("pipeline", {}).get("chunking", {}).get("chunking_configuration", {})
return ChunkingParameters(
@@ -135,6 +132,7 @@ def _parse_chunking_parameters(config: Dict[str, Any]) -> ChunkingParameters:
h_min=chunking_params.get("h_min", 2),
h_max=chunking_params.get("h_max", 3),
num_multihops_factor=chunking_params.get("num_multihops_factor", 2),
chunking_mode=chunking_params.get("chunking_mode", "fast_chunking"),
)


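A quick sanity check of the fallback behavior (the defaults for h_min, h_max, num_multihops_factor, and chunking_mode are visible above; the token-length defaults are collapsed in this view):

```python
params = _parse_chunking_parameters({})  # missing keys fall back to defaults
assert params.chunking_mode == "fast_chunking"
assert (params.h_min, params.h_max, params.num_multihops_factor) == (2, 3, 2)
```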
@@ -173,6 +171,7 @@ def run(config: Dict[str, Any]) -> None:
h_min = params.h_min
h_max = params.h_max
num_multihops_factor = params.num_multihops_factor
chunking_mode = params.chunking_mode.lower().strip()

# Check debug setting
debug_mode: bool = config.get("settings", {}).get("debug", False)
@@ -185,27 +184,35 @@
local_perplexity_metric = _perplexity_metric
local_use_textstat = _use_textstat

# Load chunking model
try:
# Extract model name from config if available
model_name_list = config.get("model_roles", {}).get("chunking", [])
if model_name_list is None or len(model_name_list) == 0:
logger.info(
"No chunking model specified in config['model_roles']['chunking']. "
"Using default 'intfloat/multilingual-e5-large-instruct'."
)
model_name = "intfloat/multilingual-e5-large-instruct"
else:
model_name = model_name_list[0]

logger.info(f"Using chunking model: '{model_name}'")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device).eval()
except Exception as model_error:
logger.error(f"Error loading tokenizer/model '{model_name}': {model_error}")
logger.warning("Chunking stage cannot proceed. Exiting.")
return
# We'll only load the chunking model if in semantic_chunking mode
tokenizer = None
model = None
device = "cpu"
model_name = "no_model_for_fast_chunking"

if chunking_mode == "semantic_chunking":
try:
# Extract model name from config if available
model_name_list = config.get("model_roles", {}).get("chunking", [])
if model_name_list is None or len(model_name_list) == 0:
logger.info(
"No chunking model specified in config['model_roles']['chunking']. "
"Using default 'intfloat/multilingual-e5-large-instruct'."
)
model_name = "intfloat/multilingual-e5-large-instruct"
else:
model_name = model_name_list[0]

logger.info(f"Using chunking model: '{model_name}'")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device).eval()
except Exception as model_error:
logger.error(f"Error loading tokenizer/model '{model_name}': {model_error}")
logger.warning("Chunking stage cannot proceed with semantic_chunking. Exiting.")
return
else:
logger.info("Using fast_chunking mode: purely length-based chunking with no embeddings.")

# Prepare data structures
all_single_hop_chunks: list[list[SingleHopChunk]] = []
@@ -235,34 +242,42 @@ def run(config: Dict[str, Any]) -> None:
all_chunk_info_metrics.append([])
continue

# Compute embeddings for sentences
sentence_embeddings = _compute_embeddings(tokenizer, model, texts=sentences, device=device, max_len=512)

# Compute consecutive sentence similarities
consecutive_sims: list[float] = []
for sentence_index in range(len(sentences) - 1):
cos_sim = float(
F.cosine_similarity(
sentence_embeddings[sentence_index].unsqueeze(0),
sentence_embeddings[sentence_index + 1].unsqueeze(0),
dim=1,
)[0]
# Depending on the chunking mode:
if chunking_mode == "semantic_chunking":
# 1) Compute embeddings for sentences
sentence_embeddings = _compute_embeddings(tokenizer, model, texts=sentences, device=device, max_len=512)
# 2) Compute consecutive sentence similarities
consecutive_sims: list[float] = []
for sentence_index in range(len(sentences) - 1):
cos_sim = float(
F.cosine_similarity(
sentence_embeddings[sentence_index].unsqueeze(0),
sentence_embeddings[sentence_index + 1].unsqueeze(0),
dim=1,
)[0]
)
consecutive_sims.append(cos_sim)
if consecutive_sims:
all_similarities.append(consecutive_sims)

# 3) Create single-hop chunks with semantic logic
single_hop_chunks = _chunk_document_semantic(
sentences=sentences,
similarities=consecutive_sims,
l_min_tokens=l_min_tokens,
l_max_tokens=l_max_tokens,
tau=tau_threshold,
doc_id=doc_id,
)
else:
# Fast chunking: purely length-based
single_hop_chunks = _chunk_document_fast(
sentences=sentences,
l_max_tokens=l_max_tokens,
doc_id=doc_id,
)
consecutive_sims.append(cos_sim)
if consecutive_sims:
all_similarities.append(consecutive_sims)

# Create single-hop chunks
single_hop_chunks = _chunk_document(
sentences=sentences,
similarities=consecutive_sims,
l_min_tokens=l_min_tokens,
l_max_tokens=l_max_tokens,
tau=tau_threshold,
doc_id=doc_id,
)

# Create multi-hop chunks (modified to ensure no duplicates)
# Create multi-hop chunks
multihop = _multihop_chunking(
single_hop_chunks,
h_min=h_min,
@@ -278,8 +293,8 @@ def run(config: Dict[str, Any]) -> None:
all_multihop_chunks.append(multihop)
all_chunk_info_metrics.append(chunk_metrics)

# Optional: Save aggregated similarity plot
if all_similarities is not None and len(all_similarities) > 0 and debug_mode:
# Optional: Save aggregated similarity plot only if in semantic_chunking and debug
if chunking_mode == "semantic_chunking" and all_similarities and debug_mode:
_plot_aggregated_similarities(all_similarities)

# Convert dataclasses back to dicts for safe addition to the dataset
@@ -386,7 +401,7 @@ def _compute_embeddings(
return embeddings

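The body of _compute_embeddings is collapsed above (only its return statement is visible). For orientation, a typical mean-pooling implementation for e5-style encoders looks roughly like the sketch below; this is an assumption about the collapsed code, not a quote of it:

```python
import torch
import torch.nn.functional as F

def _compute_embeddings_sketch(tokenizer, model, texts, device, max_len=512):
    # Hypothetical stand-in for the collapsed _compute_embeddings body.
    batch = tokenizer(
        texts, padding=True, truncation=True, max_length=max_len, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state          # (B, T, H)
    mask = batch["attention_mask"].unsqueeze(-1).float()   # (B, T, 1)
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # mean over real tokens
    return F.normalize(pooled, p=2, dim=1)                 # unit-length rows
```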

def _chunk_document(
def _chunk_document_semantic(
sentences: list[str],
similarities: list[float],
l_min_tokens: int,
@@ -395,9 +410,9 @@ def _chunk_document(
doc_id: str,
) -> list[SingleHopChunk]:
"""
Creates single-hop chunks from sentences, ensuring each chunk is at least
l_min_tokens in length and at most l_max_tokens, and introducing a chunk
boundary when consecutive sentence similarity is below a threshold tau.
Creates single-hop chunks from sentences using semantic guidance. Ensures each
chunk is at least l_min_tokens in length and at most l_max_tokens, introducing
a chunk boundary when consecutive sentence similarity is below threshold tau.

Args:
sentences (list[str]): The list of sentences for a single document.
@@ -463,6 +478,55 @@ def _chunk_document(
return chunks

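The loop body of _chunk_document_semantic is collapsed above. Based on the docstring, the per-sentence boundary decision plausibly reduces to something like the following sketch of the stated rule (not the actual diff content):

```python
# Sketch of the semantic boundary rule described in the docstring above.
# Break when the chunk is long enough AND cohesion drops below tau,
# or when adding the next sentence would exceed the hard length cap.
def _should_break(current_len, next_len, sim_to_next,
                  l_min_tokens, l_max_tokens, tau):
    if current_len + next_len > l_max_tokens:
        return True   # hard length cap
    if current_len >= l_min_tokens and sim_to_next < tau:
        return True   # semantic boundary
    return False
```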

def _chunk_document_fast(
sentences: list[str],
l_max_tokens: int,
doc_id: str,
) -> list[SingleHopChunk]:
"""
Creates chunks based purely on a maximum token length, with tokens approximated
by whitespace-separated word counts. Each sentence is appended to the current
chunk as long as the running count stays within l_max_tokens; otherwise the
chunk is finalized and a new one is started.

Args:
sentences (list[str]): The list of sentences for a single document.
l_max_tokens (int): Maximum tokens per chunk.
doc_id (str): Unique identifier for the document.

Returns:
list[SingleHopChunk]: A list of SingleHopChunk objects.
"""
chunks: list[SingleHopChunk] = []
current_chunk: list[str] = []
current_len: int = 0
chunk_index: int = 0

for sentence in sentences:
sentence_token_count = len(sentence.split())

# If adding this sentence would exceed l_max_tokens, finalize current chunk
if current_len + sentence_token_count > l_max_tokens:
if current_chunk:
chunk_str = " ".join(current_chunk)
chunks.append(SingleHopChunk(chunk_id=f"{doc_id}_{chunk_index}", chunk_text=chunk_str))
chunk_index += 1

# Start a new chunk with the current sentence
current_chunk = [sentence]
current_len = sentence_token_count
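# Note: a sentence that alone exceeds l_max_tokens still becomes its
# own oversized chunk here -- sentences are never split mid-sentence.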
else:
# Add sentence to current chunk
current_chunk.append(sentence)
current_len += sentence_token_count

# Any leftover chunk
if current_chunk:
chunk_str = " ".join(current_chunk)
chunks.append(SingleHopChunk(chunk_id=f"{doc_id}_{chunk_index}", chunk_text=chunk_str))

return chunks

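A small worked example of the fast path, using the whitespace word counts the function relies on:

```python
chunks = _chunk_document_fast(
    sentences=["a b c", "d e", "f g h i"],  # word counts: 3, 2, 4
    l_max_tokens=5,
    doc_id="doc1",
)
# -> [SingleHopChunk(chunk_id="doc1_0", chunk_text="a b c d e"),
#     SingleHopChunk(chunk_id="doc1_1", chunk_text="f g h i")]
```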

def _multihop_chunking(
single_hop_chunks: list[SingleHopChunk],
h_min: int,