Skip to content

Commit 759f169

Browse files
committed
separated out socket client for platform specific impl
1 parent 6c0fcd3 commit 759f169

File tree

5 files changed

+316
-77
lines changed

5 files changed

+316
-77
lines changed

tensorflow_io/arrow/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ cc_binary(
66
name = 'python/ops/_arrow_ops.so',
77
srcs = [
88
"kernels/arrow_dataset_ops.cc",
9-
"ops/dataset_ops.cc"
9+
"kernels/arrow_stream_client.h",
10+
"kernels/arrow_stream_client_unix.cc",
11+
"ops/dataset_ops.cc",
1012
],
1113
linkshared = 1,
1214
deps = [

tensorflow_io/arrow/kernels/arrow_dataset_ops.cc

Lines changed: 5 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
// TODO posix specific
17-
#include <arpa/inet.h>
18-
#include <sys/socket.h>
19-
2016
#include "arrow/api.h"
2117
#include "arrow/io/api.h"
2218
#include "arrow/ipc/api.h"
2319
#include "arrow/util/io-util.h"
20+
#include "tensorflow_io/arrow/kernels/arrow_stream_client.h"
2421
#include "tensorflow/core/framework/dataset.h"
2522

2623
#define CHECK_ARROW(arrow_status) \
@@ -33,50 +30,6 @@ limitations under the License.
3330

3431
namespace tensorflow {
3532

36-
// Class to wrap a socket as a readable Arrow InputStream
37-
class SocketStream : public arrow::io::InputStream {
38-
public:
39-
SocketStream(int sock) : sock_(sock), pos_(0) {}
40-
~SocketStream() override {}
41-
42-
arrow::Status Close() override {
43-
close(sock_);
44-
return arrow::Status::OK();
45-
}
46-
47-
arrow::Status Tell(int64_t* position) const override {
48-
*position = pos_;
49-
return arrow::Status::OK();
50-
}
51-
52-
arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override {
53-
int status = recv(sock_, out, nbytes, MSG_WAITALL);
54-
// if (status == 0) socket closed
55-
if (status == -1) {
56-
return arrow::Status::IOError("error reading from socket");
57-
}
58-
*bytes_read = nbytes;
59-
pos_ += *bytes_read;
60-
return arrow::Status::OK();
61-
}
62-
63-
arrow::Status Read(int64_t nbytes,
64-
std::shared_ptr<arrow::Buffer>* out) override {
65-
std::shared_ptr<arrow::ResizableBuffer> buffer;
66-
ARROW_RETURN_NOT_OK(arrow::AllocateResizableBuffer(nbytes, &buffer));
67-
int64_t bytes_read;
68-
ARROW_RETURN_NOT_OK(Read(nbytes, &bytes_read, buffer->mutable_data()));
69-
ARROW_RETURN_NOT_OK(buffer->Resize(bytes_read, false));
70-
buffer->ZeroPadding();
71-
*out = buffer;
72-
return arrow::Status::OK();
73-
}
74-
75-
private:
76-
int sock_;
77-
int64_t pos_;
78-
};
79-
8033
// Convert an element of an Arrow Array to a Tensor
8134
class ArrowConvertTensor : public arrow::ArrayVisitor {
8235
public:
@@ -593,34 +546,10 @@ class ArrowStreamDatasetOp : public ArrowOpKernelBase {
593546
if (dataset()->host_ == "STDIN") {
594547
in_stream_ = std::make_shared<arrow::io::StdinStream>();
595548
} else {
596-
size_t sep_pos = dataset()->host_.find(':');
597-
if (sep_pos == std::string::npos ||
598-
sep_pos == dataset()->host_.size()) {
599-
return errors::InvalidArgument(
600-
"Expected host to be in format <host>:<port> but got: " +
601-
dataset()->host_);
602-
}
603-
std::string host_str = dataset()->host_.substr(0, sep_pos);
604-
std::string port_str = dataset()->host_.substr(
605-
sep_pos + 1, dataset()->host_.size() - sep_pos);
606-
int port_num = std::stoi(port_str);
607-
int sock = 0;
608-
struct sockaddr_in serv_addr;
609-
if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
610-
return errors::InvalidArgument("Socket creation error");
611-
}
612-
bzero((char*)&serv_addr, sizeof(serv_addr));
613-
serv_addr.sin_addr.s_addr = inet_addr(host_str.c_str());
614-
serv_addr.sin_family = AF_INET;
615-
serv_addr.sin_port = htons(port_num);
616-
// if(inet_pton(AF_INET, "127.0.0.1", &serv_addr.sin_addr)<=0)
617-
// printf("\nInvalid address/ Address not supported \n");
618-
if (connect(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) <
619-
0) {
620-
return errors::InvalidArgument("Connection failed to host: " +
621-
dataset()->host_);
622-
}
623-
in_stream_ = std::make_shared<SocketStream>(sock);
549+
auto socket_stream =
550+
std::make_shared<ArrowStreamClient>(dataset()->host_);
551+
CHECK_ARROW(socket_stream->Connect());
552+
in_stream_ = socket_stream;
624553
}
625554

626555
CHECK_ARROW(arrow::ipc::RecordBatchStreamReader::Open(in_stream_.get(),
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_IO_ARROW_STREAM_CLIENT_H_
17+
#define TENSORFLOW_IO_ARROW_STREAM_CLIENT_H_
18+
19+
#include "arrow/io/api.h"
20+
21+
namespace tensorflow {
22+
23+
// Class to wrap a socket as a readable Arrow InputStream
24+
class ArrowStreamClient : public arrow::io::InputStream {
25+
public:
26+
ArrowStreamClient(const std::string& host);
27+
~ArrowStreamClient() override;
28+
29+
arrow::Status Connect();
30+
arrow::Status Close() override;
31+
arrow::Status Tell(int64_t* position) const override;
32+
arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override;
33+
arrow::Status Read(int64_t nbytes,
34+
std::shared_ptr<arrow::Buffer>* out) override;
35+
36+
private:
37+
const std::string host_;
38+
int sock_;
39+
int64_t pos_;
40+
};
41+
42+
} // namespace tensorflow
43+
44+
#endif // TENSORFLOW_IO_ARROW_STREAM_CLIENT_H_
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <arpa/inet.h>
17+
#include <sys/socket.h>
18+
#include <unistd.h>
19+
20+
#include "arrow/api.h"
21+
#include "arrow/io/api.h"
22+
23+
#include "tensorflow_io/arrow/kernels/arrow_stream_client.h"
24+
25+
namespace tensorflow {
26+
27+
ArrowStreamClient::ArrowStreamClient(const std::string& host)
28+
: host_(host), sock_(-1), pos_(0) {}
29+
30+
// Releases the socket on destruction when one is still recorded.
ArrowStreamClient::~ArrowStreamClient() {
  // -1 marks "no socket", so there is nothing to release in that case.
  if (sock_ == -1) {
    return;
  }
  Close();
}
35+
36+
// Opens a TCP connection to the "<host>:<port>" endpoint given at
// construction.
//
// Returns:
//   Status::Invalid  if the endpoint is not "<host>:<port>", the port is not
//                    a decimal number in [0, 65535], or the host is not a
//                    dotted-decimal IPv4 address.
//   Status::IOError  if a connection is already open, the socket cannot be
//                    created, or the connection attempt fails.
//   Status::OK       once the socket is connected.
arrow::Status ArrowStreamClient::Connect() {
  // A second connect() on a live socket can only fail; report it up front so
  // the failure handling below never closes a working connection.
  if (sock_ != -1) {
    return arrow::Status::IOError("Connection already open to host: " + host_);
  }

  // Require a separator with at least one character after it. (find() can
  // never return size(), so the port must be checked via sep_pos + 1.)
  const size_t sep_pos = host_.find(':');
  if (sep_pos == std::string::npos || sep_pos + 1 >= host_.size()) {
    return arrow::Status::Invalid(
        "Expected host to be in format <host>:<port> but got: " + host_);
  }
  const std::string host_str = host_.substr(0, sep_pos);
  const std::string port_str = host_.substr(sep_pos + 1);

  // Parse the port by hand: std::stoi throws on malformed input, and an
  // exception must not escape through a Status-returning API.
  int port_num = 0;
  for (const char c : port_str) {
    if (c < '0' || c > '9') {
      return arrow::Status::Invalid(
          "Expected port to be a number but got: " + port_str);
    }
    port_num = port_num * 10 + (c - '0');
    if (port_num > 65535) {
      return arrow::Status::Invalid(
          "Expected port to be in range 0-65535 but got: " + port_str);
    }
  }

  struct sockaddr_in serv_addr = {};  // zero-initialized
  serv_addr.sin_family = AF_INET;
  serv_addr.sin_port = htons(static_cast<uint16_t>(port_num));
  // inet_pton reports malformed addresses explicitly, unlike inet_addr whose
  // error value is indistinguishable from 255.255.255.255.
  if (inet_pton(AF_INET, host_str.c_str(), &serv_addr.sin_addr) != 1) {
    return arrow::Status::Invalid(
        "Expected host to be an IPv4 address but got: " + host_str);
  }

  const int fd = ::socket(AF_INET, SOCK_STREAM, 0);
  if (fd < 0) {
    return arrow::Status::IOError("Socket creation error");
  }

  if (::connect(fd, reinterpret_cast<const struct sockaddr*>(&serv_addr),
                sizeof(serv_addr)) < 0) {
    // Do not keep a socket whose connect failed; a later Connect() call
    // starts again from a fresh descriptor.
    ::close(fd);
    return arrow::Status::IOError("Connection failed to host: " + host_);
  }

  sock_ = fd;
  return arrow::Status::OK();
}
64+
65+
// Closes the socket and marks the client as holding no socket.
//
// Returns Status::OK when there is nothing to close or the close succeeds,
// and Status::IOError when close() reports a failure.
arrow::Status ArrowStreamClient::Close() {
  // Nothing is open: avoid handing an invalid descriptor to close().
  if (sock_ == -1) {
    return arrow::Status::OK();
  }

  const int status = ::close(sock_);
  // Reset to the "no socket" sentinel used by the constructor and checked by
  // the destructor, so the descriptor is never closed a second time.
  sock_ = -1;

  if (status != 0) {
    return arrow::Status::IOError("Failed to correctly close connection");
  }

  return arrow::Status::OK();
}
75+
76+
// Reports the current stream position, i.e. the byte count accumulated by
// Read(), through `position`. Always returns Status::OK.
arrow::Status ArrowStreamClient::Tell(int64_t* position) const {
  const int64_t current = pos_;
  *position = current;
  return arrow::Status::OK();
}
80+
81+
// Reads exactly `nbytes` from the socket into `out`.
//
// `bytes_read` is always written: it holds the number of bytes stored in
// `out`, which equals `nbytes` whenever Status::OK is returned.
//
// Returns:
//   Status::Invalid  if `nbytes` is negative.
//   Status::IOError  if recv() fails or the peer closes the connection
//                    before `nbytes` bytes have arrived.
//   Status::OK       otherwise (including a zero-byte request).
arrow::Status ArrowStreamClient::Read(int64_t nbytes,
                                      int64_t* bytes_read,
                                      void* out) {
  // Set the output first so no return path leaves it uninitialized.
  *bytes_read = 0;

  if (nbytes < 0) {
    return arrow::Status::Invalid("cannot read a negative number of bytes");
  }
  // TODO: look into why 0 bytes are requested
  if (nbytes == 0) {
    return arrow::Status::OK();
  }

  char* dest = static_cast<char*>(out);
  // MSG_WAITALL may still return fewer bytes than requested, so keep
  // receiving until the request is satisfied.
  while (*bytes_read < nbytes) {
    const ssize_t received =
        ::recv(sock_, dest + *bytes_read,
               static_cast<size_t>(nbytes - *bytes_read), MSG_WAITALL);
    if (received == 0) {
      return arrow::Status::IOError("connection closed unexpectedly");
    } else if (received < 0) {
      return arrow::Status::IOError("error reading from socket");
    }
    // Count what was actually received rather than what was asked for.
    *bytes_read += received;
    pos_ += received;
  }

  return arrow::Status::OK();
}
101+
102+
// Reads `nbytes` from the socket into a freshly allocated buffer.
//
// The buffer is allocated at `nbytes`, filled through the raw Read()
// overload, resized to the byte count that overload reports, zero-padded,
// and returned through `out`. Any failing Arrow status is propagated.
arrow::Status ArrowStreamClient::Read(int64_t nbytes,
                                      std::shared_ptr<arrow::Buffer>* out) {
  std::shared_ptr<arrow::ResizableBuffer> buffer;
  ARROW_RETURN_NOT_OK(arrow::AllocateResizableBuffer(nbytes, &buffer));

  // Start from zero so Resize() never sees an indeterminate value if the raw
  // Read() returns without reporting a count (e.g. a zero-byte request).
  int64_t bytes_read = 0;
  ARROW_RETURN_NOT_OK(Read(nbytes, &bytes_read, buffer->mutable_data()));
  ARROW_RETURN_NOT_OK(buffer->Resize(bytes_read, false));
  buffer->ZeroPadding();

  *out = buffer;
  return arrow::Status::OK();
}
113+
114+
} // namespace tensorflow

0 commit comments

Comments
 (0)