Skip to content

Commit 726907c

Browse files
authored
Merge pull request #32 from yupbank/add-libsvm
Add libsvm dataset support
2 parents 24eb10d + ccc6b10 commit 726907c

File tree

8 files changed

+464
-0
lines changed

8 files changed

+464
-0
lines changed

tensorflow_io/libsvm/BUILD

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
licenses(["notice"]) # Apache 2.0
2+
3+
package(default_visibility = ["//visibility:public"])
4+
5+
cc_binary(
6+
name = "python/ops/_libsvm_ops.so",
7+
srcs = [
8+
"kernels/decode_libsvm_op.cc",
9+
"ops/libsvm_ops.cc",
10+
],
11+
linkshared = 1,
12+
deps = [
13+
"@local_config_tf//:libtensorflow_framework",
14+
"@local_config_tf//:tf_header_lib",
15+
],
16+
copts = ["-pthread", "-std=c++11", "-D_GLIBCXX_USE_CXX11_ABI=0", "-DNDEBUG"]
17+
)
18+
19+
py_library(
20+
name = "libsvm_ops_py",
21+
srcs = [
22+
"python/ops/libsvm_dataset_ops.py",
23+
],
24+
data = [
25+
":python/ops/_libsvm_ops.so",
26+
],
27+
srcs_version = "PY2AND3",
28+
)
29+
30+
py_test(
31+
name = "decode_libsvm_op_test",
32+
srcs = [
33+
"python/kernel_tests/decode_libsvm_op_test.py"
34+
],
35+
main = "python/kernel_tests/decode_libsvm_op_test.py",
36+
deps = [
37+
":libsvm_ops_py",
38+
],
39+
srcs_version = "PY2AND3",
40+
)
41+
42+
py_library(
43+
name = "libsvm_py",
44+
srcs = ([
45+
"__init__.py",
46+
"python/__init__.py",
47+
"python/ops/__init__.py",
48+
]),
49+
deps = [
50+
":libsvm_ops_py"
51+
],
52+
srcs_version = "PY2AND3",
53+
)

tensorflow_io/libsvm/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""LibSVM Dataset.
16+
17+
@@make_libsvm_dataset
18+
"""
19+
20+
from __future__ import absolute_import
21+
from __future__ import division
22+
from __future__ import print_function
23+
24+
from tensorflow.contrib.libsvm.python.ops.libsvm_dataset_ops import make_libsvm_dataset
25+
26+
from tensorflow.python.util.all_util import remove_undocumented
27+
28+
_allowed_symbols = [
29+
"make_libsvm_dataset",
30+
]
31+
32+
remove_undocumented(__name__)
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/op_kernel.h"
17+
#include "tensorflow/core/framework/tensor.h"
18+
#include "tensorflow/core/framework/tensor_shape.h"
19+
#include "tensorflow/core/framework/types.h"
20+
#include "tensorflow/core/lib/core/errors.h"
21+
#include "tensorflow/core/lib/strings/numbers.h"
22+
#include "tensorflow/core/lib/strings/str_util.h"
23+
24+
namespace tensorflow {
25+
26+
template <typename T, typename Tlabel>
27+
class DecodeLibsvmOp : public OpKernel {
28+
public:
29+
explicit DecodeLibsvmOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
30+
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_features", &num_features_));
31+
OP_REQUIRES(ctx, (num_features_ >= 1),
32+
errors::InvalidArgument("Invalid number of features \"",
33+
num_features_, "\""));
34+
}
35+
36+
void Compute(OpKernelContext* ctx) override {
37+
const Tensor* input_tensor;
38+
OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
39+
const auto& input_flat = input_tensor->flat<string>();
40+
41+
Tensor* label_tensor;
42+
OP_REQUIRES_OK(
43+
ctx, ctx->allocate_output(0, input_tensor->shape(), &label_tensor));
44+
auto label = label_tensor->flat<Tlabel>();
45+
46+
std::vector<T> out_values;
47+
std::vector<std::pair<int64, int64>> out_indices;
48+
for (int i = 0; i < input_flat.size(); ++i) {
49+
StringPiece line(input_flat(i));
50+
str_util::RemoveWhitespaceContext(&line);
51+
52+
StringPiece piece;
53+
OP_REQUIRES(ctx, str_util::ConsumeNonWhitespace(&line, &piece),
54+
errors::InvalidArgument("No label found for input[", i,
55+
"]: \"", input_flat(i), "\""));
56+
57+
Tlabel label_value;
58+
OP_REQUIRES(ctx,
59+
strings::SafeStringToNumeric<Tlabel>(piece, &label_value),
60+
errors::InvalidArgument("Label format incorrect: ", piece));
61+
62+
label(i) = label_value;
63+
64+
str_util::RemoveLeadingWhitespace(&line);
65+
while (str_util::ConsumeNonWhitespace(&line, &piece)) {
66+
size_t p = piece.find(':');
67+
OP_REQUIRES(ctx, (p != StringPiece::npos),
68+
errors::InvalidArgument("Invalid feature \"", piece, "\""));
69+
70+
int64 feature_index;
71+
OP_REQUIRES(
72+
ctx, strings::safe_strto64(piece.substr(0, p), &feature_index),
73+
errors::InvalidArgument("Feature format incorrect: ", piece));
74+
OP_REQUIRES(ctx, (feature_index >= 0),
75+
errors::InvalidArgument(
76+
"Feature index should be >= 0, got ", feature_index));
77+
78+
T feature_value;
79+
OP_REQUIRES(
80+
81+
ctx,
82+
strings::SafeStringToNumeric<T>(piece.substr(p + 1),
83+
&feature_value),
84+
errors::InvalidArgument("Feature format incorrect: ", piece));
85+
86+
out_values.emplace_back(feature_value);
87+
out_indices.emplace_back(std::pair<int64, int64>(i, feature_index));
88+
89+
str_util::RemoveLeadingWhitespace(&line);
90+
}
91+
}
92+
93+
Tensor* indices_tensor;
94+
OP_REQUIRES_OK(ctx, ctx->allocate_output(
95+
1,
96+
TensorShape({static_cast<int64>(out_indices.size()),
97+
input_tensor->shape().dims() + 1}),
98+
&indices_tensor));
99+
auto indices = indices_tensor->matrix<int64>();
100+
// Translate flat index to shaped index like np.unravel_index
101+
// Calculate factors for each dimension
102+
std::vector<int64> factors(input_tensor->shape().dims());
103+
factors[input_tensor->shape().dims() - 1] = 1;
104+
for (int j = input_tensor->shape().dims() - 2; j >= 0; j--) {
105+
factors[j] = factors[j + 1] * input_tensor->shape().dim_size(j + 1);
106+
}
107+
for (int i = 0; i < out_indices.size(); i++) {
108+
indices(i, 0) = out_indices[i].first;
109+
int64 value = out_indices[i].first;
110+
for (int j = 0; j < input_tensor->shape().dims(); j++) {
111+
indices(i, j) = value / factors[j];
112+
value = value % factors[j];
113+
}
114+
indices(i, input_tensor->shape().dims()) = out_indices[i].second;
115+
}
116+
117+
Tensor* values_tensor;
118+
OP_REQUIRES_OK(ctx,
119+
ctx->allocate_output(
120+
2, TensorShape({static_cast<int64>(out_values.size())}),
121+
&values_tensor));
122+
auto values = values_tensor->vec<T>();
123+
std::copy_n(out_values.begin(), out_values.size(), &values(0));
124+
125+
Tensor* shape_tensor;
126+
OP_REQUIRES_OK(ctx, ctx->allocate_output(
127+
3, TensorShape({input_tensor->shape().dims() + 1}),
128+
&shape_tensor));
129+
auto shape = shape_tensor->flat<int64>();
130+
for (int i = 0; i < input_tensor->shape().dims(); i++) {
131+
shape(i) = input_tensor->shape().dim_size(i);
132+
}
133+
shape(input_tensor->shape().dims()) = num_features_;
134+
}
135+
136+
private:
137+
int64 num_features_;
138+
};
139+
140+
#define REGISTER_KERNEL(type) \
141+
REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
142+
.Device(DEVICE_CPU) \
143+
.TypeConstraint<type>("dtype") \
144+
.TypeConstraint<int32>("label_dtype"), \
145+
DecodeLibsvmOp<type, int32>); \
146+
REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
147+
.Device(DEVICE_CPU) \
148+
.TypeConstraint<type>("dtype") \
149+
.TypeConstraint<int64>("label_dtype"), \
150+
DecodeLibsvmOp<type, int64>); \
151+
REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
152+
.Device(DEVICE_CPU) \
153+
.TypeConstraint<type>("dtype") \
154+
.TypeConstraint<float>("label_dtype"), \
155+
DecodeLibsvmOp<type, float>); \
156+
REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \
157+
.Device(DEVICE_CPU) \
158+
.TypeConstraint<type>("dtype") \
159+
.TypeConstraint<double>("label_dtype"), \
160+
DecodeLibsvmOp<type, double>);
161+
162+
REGISTER_KERNEL(float);
163+
REGISTER_KERNEL(double);
164+
REGISTER_KERNEL(int32);
165+
REGISTER_KERNEL(int64);
166+
#undef REGISTER_KERNEL
167+
168+
} // namespace tensorflow
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/framework/common_shape_fns.h"
17+
#include "tensorflow/core/framework/op.h"
18+
#include "tensorflow/core/framework/shape_inference.h"
19+
20+
namespace tensorflow {
21+
22+
using shape_inference::InferenceContext;
23+
24+
REGISTER_OP("DecodeLibsvm")
25+
.Input("input: string")
26+
.Output("label: label_dtype")
27+
.Output("feature_indices: int64")
28+
.Output("feature_values: dtype")
29+
.Output("feature_shape: int64")
30+
.Attr("dtype: {float, double, int32, int64} = DT_FLOAT")
31+
.Attr("label_dtype: {float, double, int32, int64} = DT_INT64")
32+
.Attr("num_features: int >= 1")
33+
.SetShapeFn([](InferenceContext* c) {
34+
c->set_output(0, c->input(0));
35+
36+
c->set_output(1, c->Matrix(InferenceContext::kUnknownDim,
37+
InferenceContext::kUnknownDim));
38+
c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
39+
c->set_output(3, c->Vector(InferenceContext::kUnknownDim));
40+
41+
return Status::OK();
42+
})
43+
44+
.Doc(R"doc(
45+
Convert LibSVM input to tensors. The output consists of
46+
a label and a feature tensor. The shape of the label tensor
47+
is the same as input and the shape of the feature tensor is
48+
`[input_shape, num_features]`.
49+
50+
input: Each string is a record in the LibSVM.
51+
label: A tensor of the same shape as input.
52+
feature_indices: A 2-D int64 tensor of dense_shape [N, ndims].
53+
feature_values: A 1-D tensor of any type and dense_shape [N].
54+
feature_shape: A 1-D int64 tensor of dense_shape [ndims].
55+
num_features: The number of features.
56+
)doc");
57+
58+
} // namespace tensorflow

tensorflow_io/libsvm/python/__init__.py

Whitespace-only changes.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for DecodeLibsvm op."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import numpy as np
22+
23+
from tensorflow_io.libsvm.python.ops import libsvm_dataset_ops
24+
from tensorflow.python.framework import dtypes
25+
from tensorflow.python.ops import sparse_ops
26+
from tensorflow.python.platform import test
27+
28+
29+
class DecodeLibsvmOpTest(test.TestCase):
30+
31+
def testBasic(self):
32+
with self.cached_session() as sess:
33+
content = [
34+
"1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503",
35+
"2 3:2.5 2:nan 1:0.105"
36+
]
37+
sparse_features, labels = libsvm_dataset_ops.decode_libsvm(
38+
content, num_features=6)
39+
features = sparse_ops.sparse_tensor_to_dense(
40+
sparse_features, validate_indices=False)
41+
42+
self.assertAllEqual(labels.get_shape().as_list(), [3])
43+
44+
features, labels = sess.run([features, labels])
45+
self.assertAllEqual(labels, [1, 1, 2])
46+
self.assertAllClose(
47+
features, [[0, 3.4, 0.5, 0, 0.231, 0], [0, 0, 2.5, np.inf, 0, 0.503],
48+
[0, 0.105, np.nan, 2.5, 0, 0]])
49+
50+
def testNDimension(self):
51+
with self.cached_session() as sess:
52+
content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"],
53+
["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"],
54+
["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]]
55+
sparse_features, labels = libsvm_dataset_ops.decode_libsvm(
56+
content, num_features=6, label_dtype=dtypes.float64)
57+
features = sparse_ops.sparse_tensor_to_dense(
58+
sparse_features, validate_indices=False)
59+
60+
self.assertAllEqual(labels.get_shape().as_list(), [3, 2])
61+
62+
features, labels = sess.run([features, labels])
63+
self.assertAllEqual(labels, [[1, 1], [1, 1], [2, 2]])
64+
self.assertAllClose(
65+
features, [[[0, 3.4, 0.5, 0, 0.231, 0], [0, 3.4, 0.5, 0, 0.231, 0]], [
66+
[0, 0, 2.5, np.inf, 0, 0.503], [0, 0, 2.5, np.inf, 0, 0.503]
67+
], [[0, 0.105, np.nan, 2.5, 0, 0], [0, 0.105, np.nan, 2.5, 0, 0]]])
68+
69+
70+
if __name__ == "__main__":
71+
test.main()

tensorflow_io/libsvm/python/ops/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)