[6/N][Refactor] torchair model runner refactor #2220

Open
wants to merge 5 commits into base: main
281 changes: 281 additions & 0 deletions vllm_ascend/torchair/torchair_model_runner.py
@@ -16,14 +16,295 @@
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
import types
from typing import Optional

import torch
import torch.nn as nn
import torch_npu
import torchair
import vllm.envs as envs_vllm
from torchair import patch_for_hcom
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import logger

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
write_kv_cache_bytes_to_file)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_310p, maybe_converting_weight_acl_format)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


class NPUTorchairModelRunner(NPUModelRunner):

def __init__(self, vllm_config: VllmConfig, device: torch.device):
super().__init__(vllm_config, device)

def _select_torchair_padded_batch_size(self, batch_size: int):
selected_batch_size = self.max_num_reqs
for padded_batch_size in self.torchair_graph_batch_sizes:
if batch_size <= padded_batch_size < selected_batch_size:
selected_batch_size = padded_batch_size
return selected_batch_size

def _get_forward_metadata_across_dp_and_pad(
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
"""Override from NPUModelRunner to pad num_tokens"""
if self.dp_size == 1:
return num_tokens, None, with_prefill, enable_dbo

if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
device="cpu",
dtype=torch.int32)
return num_tokens, num_tokens_across_dp, True, enable_dbo

if self.is_kv_consumer and len(self.torchair_graph_batch_sizes
) == 1 and not self.in_profile_run:
max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
self.dp_size,
device="cpu",
dtype=torch.int32)
return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo

num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
num_tokens, with_prefill, enable_dbo)

if not with_prefill:
max_num_token = num_tokens_across_dp.max().item()
maybe_padded_num_tokens = self._select_torchair_padded_batch_size(
max_num_token)
num_tokens_across_dp = torch.full((self.dp_size, ),
maybe_padded_num_tokens,
dtype=torch.int32,
device="cpu")
else:
maybe_padded_num_tokens = num_tokens

return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo

def _build_attention_metadata(self, with_prefill, num_tokens, skip_attn):
"""Override from NPUModelRunner to build attention metadata."""
if not with_prefill:
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
else:
attn_metadata = super()._build_attention_metadata(
with_prefill, num_tokens, skip_attn)
return attn_metadata

def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.max_num_reqs:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
)

compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.use_cached_npu_graph else self.torchair_compiled_model

if compiled_model:
return compiled_model

patch_for_hcom()

if is_310p():
            # On the 300I Duo platform we need to patch broadcast; however, that patch
            # is overwritten by patch_for_hcom in torchair, so we re-apply it here.
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
communication_adaptation_310p
communication_adaptation_310p()

config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
        # Enabling tiling_schedule_optimize on the 300I Duo platform currently has
        # bugs, so we disable it there for now.
config.experimental_config.tiling_schedule_optimize = not is_310p()
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
return self.torchair_compiled_model
else:
            # Generate a new forward proxy code object so that Dynamo retracing does
            # not invalidate the compilation cache.
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )

modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)

self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]

def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
padded_num_tokens_across_dp,
input_ids, positions,
intermediate_tensors,
inputs_embeds):
"""Override from NPUModelRunner to use compiled_model"""
if not with_prefill:
model_kwargs = {}
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_NZ)

compiled_model = self._get_torchair_lazy_compiled_model(
padded_num_tokens_across_dp)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
else:
hidden_states = super()._generate_process_reqs_hidden_states(
attn_metadata, with_prefill, padded_num_tokens_across_dp,
input_ids, positions, intermediate_tensors, inputs_embeds)
return hidden_states

def _generate_dummy_run_hidden_states(self, with_prefill,
is_torchair_compile, input_ids,
positions, attn_metadata, num_tokens,
intermediate_tensors, inputs_embeds):
"""Override from NPUModelRunner to use compiled_model"""
model_kwargs = {}
        if not with_prefill:
# Only mark static while compiling
if is_torchair_compile:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
assert isinstance(kv, tuple), "kv_cache must be a tuple"
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])

maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_NZ)

compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
**model_kwargs,
)
else:
hidden_states = super()._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)

return hidden_states

def _convert_torch_foramt(self, kv_cache):
"""Override from NPUModelRunner to use torchair format."""
kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
return kv_cache

def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
# Trigger torchair graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens, is_torchair_compile=True)
self._dummy_run(num_tokens, is_torchair_compile=True)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes))

def _capture_model(self):
"""Override from NPUModelRunner to use torchair graph capture."""
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
# torchair graph capture can cause some issues, so now we just
# temporarily split the codepath for the two different graph patterns.
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
graph_num = len(torchair_graph_batch_sizes)

if self.use_cached_npu_graph and not check_torchair_cache_exist():
            # If caching is enabled but the cache does not exist, we compile the model
            # twice: the first pass generates the cache, and the second pass loads it to
            # skip the overhead of the Dynamo guard mechanism.
            logger.info(
                "Cached NPU graph is enabled but no cache exists. Compiling the graph to generate the torchair cache; this usually takes %.1f~%.1f mins.",
                0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
NPUPlatform.synchronize()
torch._dynamo.reset()
self.torchair_compiled_models.clear()
if self.use_cached_npu_graph:
logger.info(
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
0.3 * graph_num, 0.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)
else:
logger.info(
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
0.5 * graph_num, 1.5 * graph_num)
self._compile_torchair_graph(torchair_graph_batch_sizes)

if self.new_kv_cache_bytes > 0:
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
self.new_kv_cache_bytes)

    def _generate_extra_builder_kwargs(self, enable_dbo, num_reqs,
                                       with_prefill,
                                       padded_num_tokens_across_dp,
                                       total_num_scheduled_tokens) -> dict:
"""Override from NPUModelRunner to add graph_pad_size."""
extra_builder_kwargs = super()._generate_extra_builder_kwargs(
enable_dbo, num_reqs, with_prefill, padded_num_tokens_across_dp,
total_num_scheduled_tokens)
        if not with_prefill:
            extra_builder_kwargs[
                'graph_pad_size'] = padded_num_tokens_across_dp - total_num_scheduled_tokens
        return extra_builder_kwargs

def _update_input_ids_and_positions(self, input_ids, positions,
num_input_tokens, with_prefill,
padded_num_tokens_across_dp):
"""Override from NPUModelRunner to update input_ids and positions"""
input_ids, positions = super()._update_input_ids_and_positions(
input_ids, positions, num_input_tokens, with_prefill,
padded_num_tokens_across_dp)

if not with_prefill:
input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_num_tokens_across_dp]
return input_ids, positions
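
For reference, here is a minimal standalone sketch of the code-object renaming trick that _get_torchair_lazy_compiled_model uses to keep Dynamo retracing from invalidating torchair's compile cache. The Dummy class and the batch size below are hypothetical and not part of this diff.

import types

class Dummy:
    # Hypothetical stand-in for the real model; only the forward signature matters.
    def forward(self, x):
        return x + 1

model = Dummy()
forward_fn = Dummy.forward
proxy_name = "Dummy_forward_with_batch_size_8"  # hypothetical padded batch size

# Clone the forward code object under a new name and rebuild a function from it,
# so each padded batch size is traced as a distinct function.
modified_code = forward_fn.__code__.replace(co_name=proxy_name)
modified_func = types.FunctionType(modified_code, forward_fn.__globals__,
                                   name=proxy_name,
                                   argdefs=forward_fn.__defaults__)

# Bind the proxy onto the instance, mirroring model.__dict__[forward_proxy_name] = ...
model.__dict__[proxy_name] = modified_func.__get__(model, Dummy)
print(model.__dict__[proxy_name](3))  # prints 4

In the diff above, the bound proxy is then handed to torchair.inference.cache_compile and stored per batch size, so each padded shape keeps its own cache entry instead of re-tracing a single forward.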
20 changes: 10 additions & 10 deletions vllm_ascend/torchair/utils.py
@@ -35,16 +35,6 @@ def _get_torchair_current_work_dir(file_name=None):
return os.path.join(TORCHAIR_CACHE_DIR, file_name)


def check_torchair_cache_exist():
res = False
torch_air_abs_path = _get_torchair_current_work_dir()
if os.path.exists(torch_air_abs_path):
file_list = os.listdir(torch_air_abs_path)
if len(file_list) != 0:
res = True
return res


def check_kv_cache_bytes_cache_exist():
res = False
kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
@@ -96,3 +86,13 @@ def npu_wait_tensor(self: torch.Tensor,
*,
enabled: bool = True):
return _npu_wait_tensor(self, dependency) if enabled else self


def check_torchair_cache_exist():
res = False
torch_air_abs_path = _get_torchair_current_work_dir()
if os.path.exists(torch_air_abs_path):
file_list = os.listdir(torch_air_abs_path)
if len(file_list) != 0:
res = True
return res
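
On the runner side, the padding rule in _select_torchair_padded_batch_size (torchair_model_runner.py above) rounds a decode batch up to the smallest captured graph size that fits. A minimal sketch with hypothetical graph sizes:

def select_torchair_padded_batch_size(batch_size, graph_batch_sizes, max_num_reqs):
    # Pick the smallest captured graph batch size that still fits the request,
    # falling back to max_num_reqs when none of the captured sizes is large enough.
    selected = max_num_reqs
    for padded in graph_batch_sizes:
        if batch_size <= padded < selected:
            selected = padded
    return selected

# Hypothetical configuration: graphs captured for 4/8/16 tokens, max_num_reqs = 32.
assert select_torchair_padded_batch_size(5, [4, 8, 16], 32) == 8
assert select_torchair_padded_batch_size(20, [4, 8, 16], 32) == 32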