Skip to content

Commit 10f7353

Browse files
committed
fix bug
Signed-off-by: Yaoyao Ding <[email protected]>
1 parent fe4a9e6 commit 10f7353

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

examples/quantization/matmul_a16wx.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,4 +609,7 @@ def main():
609609

610610

611611
if __name__ == "__main__":
612-
main()
612+
tilus.utils.clear_cache()
613+
a = tilus.tensor.randn([1000], dtype=tilus.float3_e1m1)
614+
# b = a.to(tilus.float16)
615+
# main()

python/tilus/extensions/hidet/transforms/lower_subbyte_type.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,9 +440,22 @@ def visit_DeclareStmt(self, stmt: DeclareStmt) -> Stmt:
440440
return super().visit_DeclareStmt(stmt)
441441

442442
def visit_AssignStmt(self, stmt: AssignStmt) -> Stmt:
443-
if isinstance(stmt.var.type, PointerType) and is_subbyte(get_base_type(stmt.var.type)):
443+
if isinstance(stmt.var.type, (PointerType, TensorPointerType, TensorType)) and is_subbyte(
444+
get_base_type(stmt.var.type)
445+
):
446+
sb = StmtBuilder()
447+
lhs_uint8_pointer = self.uint8_pointer[stmt.var]
448+
lhs_bit_offset = self.bit_offset[stmt.var]
449+
assert isinstance(lhs_uint8_pointer, Var)
450+
assert isinstance(lhs_bit_offset, Var)
451+
value_uint8_pointer, value_bit_offset = self.get_byte_and_bit_offset(stmt.value)
452+
sb.assign(lhs_uint8_pointer, value=value_uint8_pointer)
453+
sb.assign(lhs_bit_offset, value=value_bit_offset)
454+
return sb.finish()
455+
elif isinstance(stmt.var.type, DataType) and is_subbyte(stmt.var.type):
444456
raise NotImplementedError()
445-
return super().visit_AssignStmt(stmt)
457+
else:
458+
return super().visit_AssignStmt(stmt)
446459

447460
def visit_LetStmt(self, stmt: LetStmt) -> Stmt:
448461
bind_vars = []

0 commit comments

Comments
 (0)