Skip to content

Commit fcae2d2

Browse files
reedwmcopybara-github
authored andcommitted
Update ml_dtypes version to 0fa5313b65efe848c5968a15dd37dd220cc29567.
Also add mxfloat as a dependency to TensorFlow and TSL. This is needed to merge openxla/xla#19096. Previously this was done in the merge commit for that PR, but the PR was rolled back since the new types caused an internal TF Android build to fail. Now it's being done in this separate, smaller change so its easier to rollback if issues occur. PiperOrigin-RevId: 713856483
1 parent cbce6a9 commit fcae2d2

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

third_party/py/ml_dtypes/workspace.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ float8 varieties, and int4.
77
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
88

99
def repo():
10-
ML_DTYPES_COMMIT = "215c9f02a121e6286662b2efd30546c71054d5e5"
11-
ML_DTYPES_SHA256 = "4a03237ef6345e1467a33d126176b9c6a7539b0f60a34b344f39b3c9e8b82438"
10+
ML_DTYPES_COMMIT = "0fa5313b65efe848c5968a15dd37dd220cc29567"
11+
ML_DTYPES_SHA256 = "69c562bb961a21d92357c7709430553c226caac75a751c0aa52955ca14ce8641"
1212
tf_http_archive(
1313
name = "ml_dtypes",
1414
build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD",

tsl/platform/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,7 @@ cc_library(
984984
deps = [
985985
"@ml_dtypes//:float8",
986986
"@ml_dtypes//:intn",
987+
"@ml_dtypes//:mxfloat",
987988
],
988989
)
989990

tsl/platform/ml_dtypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@ limitations under the License.
1818

1919
#include "ml_dtypes/include/float8.h" // from @ml_dtypes
2020
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
21+
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes
2122

2223
namespace tsl {
24+
using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn;
2325
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
2426
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
2527
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
2628
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
2729
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
2830
using float8_e5m2 = ::ml_dtypes::float8_e5m2;
2931
using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz;
32+
using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu;
3033

3134
using int1 = ::ml_dtypes::int1;
3235
using uint1 = ::ml_dtypes::uint1;

0 commit comments

Comments
 (0)