Skip to content

Commit b92dc74

Browse files
hyeontaek, Google-ML-Automation
authored and committed
[IFRT] Use SerDes versioning in most IFRT and PjRt-IFRT types
The current serialization format is anchored as `SerDesVersionNumber(0)`. Existing serialization protos created by the previous version of IFRT will be compatible because the version field will be read as 0 (proto default value), which will be recognized as `SerDesVersionNumber(0)` and use the identical deserialization logic as before. All `ToProto()`s get a `version` parameter (defaults to `SerDesVersion::current()`), and use it for internal format selection and forward it to nested serialization. All `Serialize()`s recognize `options` and get a version if provided (if omitted, `SerDesVersion::current()` is assumed), and behave similarly to `ToProto()`. `Deserialize()`/`FromProto()` will explicitly reject the serialized data if its version is no longer supported, if the updated serialization format tracks the serialization version. Some places have limited versioning support: `HloProgram` and `IfrtIRProgram` currently pin their internal serialization including VHLO and IFRT IR to an at-least-4-week-old version. This may change to accept a broader range of versions in case it is beneficial to be able to use the latest SerDes version instead of the 4-week-old version within a single-versioned system. `PluginProgram` ignores versioning. Any change to its format currently is a breaking change because its serialization format is opaque/not extensible. Its use will also be replaced with `CustomCallProgram`, so it seems not very beneficial to add versioning to it. IFRT Proxy does not yet use SerDes versioning. It will be done in a subsequent change. PiperOrigin-RevId: 771280809
1 parent 0dfa7a5 commit b92dc74

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+925
-179
lines changed

xla/python/ifrt/BUILD

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ cc_library(
8585
":ref_wrapper",
8686
":remap_plan_proto_cc",
8787
":serdes",
88+
":serdes_version",
8889
":shape_proto_cc",
8990
":sharding_proto_cc",
9091
":user_context",
@@ -152,7 +153,9 @@ cc_library(
152153
]),
153154
deps = [
154155
":attribute_map_proto_cc",
156+
":serdes_version",
155157
"@com_google_absl//absl/container:flat_hash_map",
158+
"@com_google_absl//absl/log:check",
156159
"@com_google_absl//absl/status",
157160
"@com_google_absl//absl/status:statusor",
158161
"@com_google_absl//absl/strings",
@@ -165,6 +168,8 @@ xla_cc_test(
165168
srcs = ["attribute_map_test.cc"],
166169
deps = [
167170
":attribute_map",
171+
":serdes_test_util",
172+
":serdes_version",
168173
"//xla/tsl/platform:statusor",
169174
"@com_google_googletest//:gtest_main",
170175
],
@@ -211,6 +216,8 @@ xla_cc_test(
211216
":attribute_map",
212217
":execute_options_proto_cc",
213218
":ifrt",
219+
":serdes_test_util",
220+
":serdes_version",
214221
"//xla/tsl/platform:statusor",
215222
"//xla/tsl/platform:test",
216223
"@com_google_googletest//:gtest_main",
@@ -296,8 +303,10 @@ cc_library(
296303
":ifrt",
297304
":layout_serdes_proto_cc",
298305
":serdes",
306+
":serdes_version",
299307
"@com_google_absl//absl/status",
300308
"@com_google_absl//absl/status:statusor",
309+
"@com_google_absl//absl/strings",
301310
"@com_google_absl//absl/strings:string_view",
302311
"@llvm-project//llvm:Support",
303312
],
@@ -312,6 +321,8 @@ xla_cc_test(
312321
":ifrt",
313322
":layout_serdes",
314323
":serdes",
324+
":serdes_test_util",
325+
":serdes_version",
315326
"//xla/tsl/platform:statusor",
316327
"@com_google_googletest//:gtest_main",
317328
"@llvm-project//llvm:Support",
@@ -363,6 +374,8 @@ xla_cc_test(
363374
srcs = ["shape_test.cc"],
364375
deps = [
365376
":ifrt",
377+
":serdes_test_util",
378+
":serdes_version",
366379
":shape_proto_cc",
367380
"//xla/tsl/platform:status_matchers",
368381
"//xla/tsl/platform:statusor",
@@ -611,6 +624,7 @@ cc_library(
611624
deps = [
612625
":serdes_any_version_accessor",
613626
":serdes_version",
627+
":serdes_week_4_old_version_accessor",
614628
],
615629
)
616630

@@ -697,11 +711,13 @@ cc_library(
697711
deps = [
698712
":ifrt",
699713
":serdes",
714+
":serdes_version",
700715
":sharding_serdes_proto_cc",
701716
"//xla/python/ifrt/ir:sharding_param",
702717
"//xla/tsl/platform:statusor",
703718
"@com_google_absl//absl/status",
704719
"@com_google_absl//absl/status:statusor",
720+
"@com_google_absl//absl/strings",
705721
"@com_google_absl//absl/strings:string_view",
706722
"@llvm-project//llvm:Support",
707723
],
@@ -716,6 +732,8 @@ xla_cc_test(
716732
":ifrt",
717733
":serdes",
718734
":serdes_proto_cc",
735+
":serdes_test_util",
736+
":serdes_version",
719737
":sharding_serdes",
720738
"//xla/python/ifrt/ir:sharding_param",
721739
"//xla/tsl/platform:statusor",
@@ -747,6 +765,8 @@ xla_cc_test(
747765
":array_spec_proto_cc",
748766
":device_test_util",
749767
":ifrt",
768+
":serdes_test_util",
769+
":serdes_version",
750770
":sharding_serdes",
751771
"//xla:shape_util",
752772
"//xla/pjrt:pjrt_layout",
@@ -776,6 +796,8 @@ xla_cc_test(
776796
":device_proto_cc",
777797
":device_test_util",
778798
":ifrt",
799+
":serdes_test_util",
800+
":serdes_version",
779801
"//xla/tsl/platform:env",
780802
"//xla/tsl/platform:statusor",
781803
"@com_google_absl//absl/types:span",
@@ -824,6 +846,8 @@ xla_cc_test(
824846
deps = [
825847
":dtype_proto_cc",
826848
":ifrt",
849+
":serdes_test_util",
850+
":serdes_version",
827851
"//xla/tsl/platform:statusor",
828852
"//xla/tsl/platform:test",
829853
"@com_google_googletest//:gtest_main",
@@ -849,6 +873,8 @@ xla_cc_test(
849873
":device_test_util",
850874
":ifrt",
851875
":remap_plan_proto_cc",
876+
":serdes_test_util",
877+
":serdes_version",
852878
":sharding_serdes",
853879
"//xla:shape_util",
854880
"//xla/pjrt:pjrt_layout",
@@ -970,6 +996,7 @@ xla_cc_test(
970996
":plugin_program_serdes",
971997
":serdes",
972998
":serdes_proto_cc",
999+
":serdes_version",
9731000
"//xla/tsl/lib/core:status_test_util",
9741001
"//xla/tsl/platform:statusor",
9751002
"//xla/tsl/protobuf:error_codes_proto_impl_cc",
@@ -1026,12 +1053,13 @@ cc_library(
10261053
":ifrt",
10271054
":program_serdes",
10281055
":serdes",
1056+
":serdes_version",
10291057
":sharding_proto_cc",
10301058
":sharding_serdes",
1031-
"//xla/tsl/concurrency:ref_count",
10321059
"//xla/tsl/platform:statusor",
10331060
"@com_google_absl//absl/status",
10341061
"@com_google_absl//absl/status:statusor",
1062+
"@com_google_absl//absl/strings",
10351063
"@com_google_absl//absl/strings:cord",
10361064
"@com_google_absl//absl/strings:string_view",
10371065
"@llvm-project//llvm:Support",
@@ -1065,6 +1093,8 @@ xla_cc_test(
10651093
":program_serdes",
10661094
":serdes",
10671095
":serdes_proto_cc",
1096+
":serdes_test_util",
1097+
":serdes_version",
10681098
"//xla/tsl/lib/core:status_test_util",
10691099
"//xla/tsl/platform:status_matchers",
10701100
"//xla/tsl/platform:statusor",

xla/python/ifrt/array_spec.cc

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ limitations under the License.
1919
#include <string>
2020
#include <utility>
2121

22+
#include "absl/status/status.h"
2223
#include "absl/status/statusor.h"
2324
#include "absl/strings/str_cat.h"
2425
#include "xla/pjrt/pjrt_layout.h"
2526
#include "xla/python/ifrt/array_spec.pb.h"
2627
#include "xla/python/ifrt/client.h"
27-
#include "xla/python/ifrt/device_list.h"
2828
#include "xla/python/ifrt/dtype.h"
29+
#include "xla/python/ifrt/serdes_version.h"
2930
#include "xla/python/ifrt/shape.h"
3031
#include "xla/python/ifrt/sharding.h"
3132
#include "xla/tsl/platform/statusor.h"
@@ -35,6 +36,12 @@ namespace ifrt {
3536

3637
absl::StatusOr<ArraySpec> ArraySpec::FromProto(Client* client,
3738
const ArraySpecProto& proto) {
39+
const SerDesVersionNumber version_number(proto.version_number());
40+
if (version_number != SerDesVersionNumber(0)) {
41+
return absl::FailedPreconditionError(absl::StrCat(
42+
"Unsupported ", version_number, " for ArraySpec deserialization"));
43+
}
44+
3845
TF_ASSIGN_OR_RETURN(auto dtype, DType::FromProto(proto.dtype()));
3946
TF_ASSIGN_OR_RETURN(auto shape, Shape::FromProto(proto.shape()));
4047
TF_ASSIGN_OR_RETURN(auto sharding,
@@ -51,11 +58,18 @@ absl::StatusOr<ArraySpec> ArraySpec::FromProto(Client* client,
5158
};
5259
}
5360

54-
absl::StatusOr<ArraySpecProto> ArraySpec::ToProto() const {
61+
absl::StatusOr<ArraySpecProto> ArraySpec::ToProto(SerDesVersion version) const {
62+
if (version.version_number() < SerDesVersionNumber(0)) {
63+
return absl::FailedPreconditionError(
64+
absl::StrCat("Unsupported ", version.version_number(),
65+
" for ArraySpec serialization"));
66+
}
67+
5568
ArraySpecProto proto;
56-
*proto.mutable_dtype() = dtype.ToProto();
57-
*proto.mutable_shape() = shape.ToProto();
58-
TF_ASSIGN_OR_RETURN(*proto.mutable_sharding(), sharding->ToProto());
69+
proto.set_version_number(SerDesVersionNumber(0).value());
70+
*proto.mutable_dtype() = dtype.ToProto(version);
71+
*proto.mutable_shape() = shape.ToProto(version);
72+
TF_ASSIGN_OR_RETURN(*proto.mutable_sharding(), sharding->ToProto(version));
5973
if (layout != nullptr) {
6074
proto.set_layout(layout->Serialize());
6175
}

xla/python/ifrt/array_spec.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "xla/pjrt/pjrt_layout.h"
2626
#include "xla/python/ifrt/array_spec.pb.h"
2727
#include "xla/python/ifrt/dtype.h"
28+
#include "xla/python/ifrt/serdes_version.h"
2829
#include "xla/python/ifrt/shape.h"
2930
#include "xla/python/ifrt/sharding.h"
3031

@@ -77,7 +78,8 @@ struct ArraySpec {
7778
const ArraySpecProto& proto);
7879

7980
// Returns a `ArraySpecProto` representation.
80-
absl::StatusOr<ArraySpecProto> ToProto() const;
81+
absl::StatusOr<ArraySpecProto> ToProto(
82+
SerDesVersion version = SerDesVersion::current()) const;
8183

8284
// TODO(hyeontaek): Remove this method in favor of AbslStringify.
8385
std::string DebugString() const;

xla/python/ifrt/array_spec.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import "xla/python/ifrt/sharding.proto";
2323

2424
// Proto equivalent of C++ `ArraySpec`.
2525
message ArraySpecProto {
26+
int32 version_number = 5;
27+
2628
DTypeProto dtype = 1;
2729
ShapeProto shape = 2;
2830
ShardingProto sharding = 3;

xla/python/ifrt/array_spec_test.cc

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "xla/python/ifrt/array_spec.h"
1717

1818
#include <memory>
19+
#include <tuple>
1920

2021
#include <gtest/gtest.h>
2122
#include "absl/hash/hash_testing.h"
@@ -28,6 +29,8 @@ limitations under the License.
2829
#include "xla/python/ifrt/device_test_util.h"
2930
#include "xla/python/ifrt/dtype.h"
3031
#include "xla/python/ifrt/memory.h"
32+
#include "xla/python/ifrt/serdes_test_util.h"
33+
#include "xla/python/ifrt/serdes_version.h"
3134
#include "xla/python/ifrt/shape.h"
3235
#include "xla/python/ifrt/sharding.h"
3336
#include "xla/tsl/platform/statusor.h"
@@ -36,18 +39,23 @@ namespace xla {
3639
namespace ifrt {
3740
namespace {
3841

39-
using ArraySpecTestParam = test_util::DeviceTestParam;
42+
using ArraySpecTestParam =
43+
std::tuple<SerDesVersion, test_util::DeviceTestParam>;
4044

4145
class ArraySpecTest : public testing::TestWithParam<ArraySpecTestParam> {
4246
public:
43-
ArraySpecTest() : fixture_(GetParam()) {}
47+
ArraySpecTest()
48+
: version_(std::get<0>(GetParam())), fixture_(std::get<1>(GetParam())) {}
49+
50+
SerDesVersion version() const { return version_; }
4451

4552
Client* client() { return fixture_.client(); }
4653
DeviceListRef GetDevices(absl::Span<const int> device_indices) {
4754
return fixture_.GetDevices(device_indices);
4855
}
4956

5057
private:
58+
SerDesVersion version_;
5159
test_util::DeviceTestFixture fixture_;
5260
};
5361

@@ -77,7 +85,7 @@ TEST_P(ArraySpecTest, ToFromProto) {
7785
/*shape=*/shape,
7886
/*shard_shape=*/shard_shape)};
7987

80-
TF_ASSERT_OK_AND_ASSIGN(const ArraySpecProto proto, spec.ToProto());
88+
TF_ASSERT_OK_AND_ASSIGN(const ArraySpecProto proto, spec.ToProto(version()));
8189
TF_ASSERT_OK_AND_ASSIGN(const ArraySpec array_spec_copy,
8290
ArraySpec::FromProto(client(), proto));
8391

@@ -93,10 +101,12 @@ TEST_P(ArraySpecTest, ToFromProto) {
93101
EXPECT_EQ(sharding->shard_shape(), shard_shape);
94102
}
95103

96-
INSTANTIATE_TEST_SUITE_P(NumDevices, ArraySpecTest,
97-
testing::Values(test_util::DeviceTestParam{
98-
/*num_devices=*/2,
99-
/*num_addressable_devices=*/2}));
104+
INSTANTIATE_TEST_SUITE_P(
105+
SerDesVersion_NumDevices, ArraySpecTest,
106+
testing::Combine(testing::ValuesIn(test_util::AllSupportedSerDesVersions()),
107+
testing::Values(test_util::DeviceTestParam{
108+
/*num_devices=*/2,
109+
/*num_addressable_devices=*/2})));
100110

101111
} // namespace
102112
} // namespace ifrt

xla/python/ifrt/attribute_map.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,25 @@ limitations under the License.
2323
#include <variant>
2424
#include <vector>
2525

26+
#include "absl/log/check.h"
2627
#include "absl/status/status.h"
2728
#include "absl/status/statusor.h"
2829
#include "absl/strings/str_cat.h"
2930
#include "absl/strings/str_join.h"
3031
#include "xla/python/ifrt/attribute_map.pb.h"
32+
#include "xla/python/ifrt/serdes_version.h"
3133

3234
namespace xla {
3335
namespace ifrt {
3436

3537
absl::StatusOr<AttributeMap> AttributeMap::FromProto(
3638
const AttributeMapProto& proto) {
39+
const SerDesVersionNumber version_number(proto.version_number());
40+
if (version_number != SerDesVersionNumber(0)) {
41+
return absl::FailedPreconditionError(absl::StrCat(
42+
"Unsupported ", version_number, " for AttributeMap deserialization"));
43+
}
44+
3745
AttributeMap::Map map;
3846
map.reserve(proto.attributes_size());
3947
for (const auto& [key, value] : proto.attributes()) {
@@ -63,8 +71,16 @@ absl::StatusOr<AttributeMap> AttributeMap::FromProto(
6371
return AttributeMap(std::move(map));
6472
}
6573

66-
AttributeMapProto AttributeMap::ToProto() const {
74+
AttributeMapProto AttributeMap::ToProto(SerDesVersion version) const {
75+
// TODO(b/423702568): Change the return type to `absl::StatusOr<...>` for
76+
// graceful error handling.
77+
CHECK_GE(version.version_number(), SerDesVersionNumber(0))
78+
<< "Unsupported " << version.version_number()
79+
<< " for AttributeMap serialization";
80+
6781
AttributeMapProto proto;
82+
proto.set_version_number(SerDesVersionNumber(0).value());
83+
6884
for (const auto& [key, value] : map_) {
6985
AttributeMapProto::Value value_proto;
7086
std::visit(

xla/python/ifrt/attribute_map.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "absl/container/flat_hash_map.h"
2727
#include "absl/status/statusor.h"
2828
#include "xla/python/ifrt/attribute_map.pb.h"
29+
#include "xla/python/ifrt/serdes_version.h"
2930

3031
namespace xla {
3132
namespace ifrt {
@@ -88,7 +89,8 @@ class AttributeMap {
8889
static absl::StatusOr<AttributeMap> FromProto(const AttributeMapProto& proto);
8990

9091
// Serializes `AttributeMap` into `AttributeMapProto`.
91-
AttributeMapProto ToProto() const;
92+
AttributeMapProto ToProto(
93+
SerDesVersion version = SerDesVersion::current()) const;
9294

9395
std::string DebugString(size_t max_string_length = 64,
9496
size_t max_int64_list_size = 16) const;

xla/python/ifrt/attribute_map.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package xla.ifrt;
1919

2020
// Proto equivalent of C++ `AttributeMap`.
2121
message AttributeMapProto {
22+
int32 version_number = 2;
23+
2224
message Value {
2325
message Int64List {
2426
repeated sfixed64 elements = 1;

0 commit comments

Comments
 (0)