7
7
import asyncio
8
8
import time
9
9
import datetime
10
+ try :
11
+ from urllib import urlparse , unquote_plus , urlencode , quote_plus
12
+ except ImportError :
13
+ from urllib .parse import urlparse , unquote_plus , urlencode , quote_plus
10
14
11
15
from uamqp import authentication , constants , types , errors
12
16
from uamqp import (
13
17
Message ,
14
- Source ,
15
18
ConnectionAsync ,
16
19
AMQPClientAsync ,
17
20
SendClientAsync ,
@@ -37,7 +40,7 @@ class EventHubClientAsync(EventHubClient):
37
40
sending events to and receiving events from the Azure Event Hubs service.
38
41
"""
39
42
40
- def _create_auth (self , auth_uri , username , password ): # pylint: disable=no-self-use
43
+ def _create_auth (self , username = None , password = None ): # pylint: disable=no-self-use
41
44
"""
42
45
Create an ~uamqp.authentication.cbs_auth_async.SASTokenAuthAsync instance to authenticate
43
46
the session.
@@ -49,32 +52,13 @@ def _create_auth(self, auth_uri, username, password): # pylint: disable=no-self
49
52
:param password: The shared access key.
50
53
:type password: str
51
54
"""
55
+ username = username or self ._auth_config ['username' ]
56
+ password = password or self ._auth_config ['password' ]
52
57
if "@sas.root" in username :
53
- return authentication .SASLPlain (self .address .hostname , username , password )
54
- return authentication .SASTokenAsync .from_shared_access_key (auth_uri , username , password )
55
-
56
- def _create_connection_async (self ):
57
- """
58
- Create a new ~uamqp._async.connection_async.ConnectionAsync instance that will be shared between all
59
- AsyncSender/AsyncReceiver clients.
60
- """
61
- if not self .connection :
62
- log .info ("{}: Creating connection with address={}" .format (
63
- self .container_id , self .address .geturl ()))
64
- self .connection = ConnectionAsync (
65
- self .address .hostname ,
66
- self .auth ,
67
- container_id = self .container_id ,
68
- properties = self ._create_properties (),
69
- debug = self .debug )
70
-
71
- async def _close_connection_async (self ):
72
- """
73
- Close and destroy the connection async.
74
- """
75
- if self .connection :
76
- await self .connection .destroy_async ()
77
- self .connection = None
58
+ return authentication .SASLPlain (
59
+ self .address .hostname , username , password , http_proxy = self .http_proxy )
60
+ return authentication .SASTokenAsync .from_shared_access_key (
61
+ self .auth_uri , username , password , timeout = 60 , http_proxy = self .http_proxy )
78
62
79
63
async def _close_clients_async (self ):
80
64
"""
@@ -85,17 +69,13 @@ async def _close_clients_async(self):
85
69
async def _wait_for_client (self , client ):
86
70
try :
87
71
while client .get_handler_state ().value == 2 :
88
- await self . connection . work_async ()
72
+ await client . _handler . _connection . work_async () # pylint: disable=protected-access
89
73
except Exception as exp : # pylint: disable=broad-except
90
74
await client .close_async (exception = exp )
91
75
92
76
async def _start_client_async (self , client ):
93
77
try :
94
- await client .open_async (self .connection )
95
- started = await client .has_started ()
96
- while not started :
97
- await self .connection .work_async ()
98
- started = await client .has_started ()
78
+ await client .open_async ()
99
79
except Exception as exp : # pylint: disable=broad-except
100
80
await client .close_async (exception = exp )
101
81
@@ -108,9 +88,8 @@ async def _handle_redirect(self, redirects):
108
88
redirects = [c .redirected for c in self .clients if c .redirected ]
109
89
if not all (r .hostname == redirects [0 ].hostname for r in redirects ):
110
90
raise EventHubError ("Multiple clients attempting to redirect to different hosts." )
111
- self .auth = self ._create_auth (redirects [0 ].address .decode ('utf-8' ), ** self ._auth_config )
112
- await self .connection .redirect_async (redirects [0 ], self .auth )
113
- await asyncio .gather (* [c .open_async (self .connection ) for c in self .clients ])
91
+ self ._process_redirect_uri (redirects [0 ])
92
+ await asyncio .gather (* [c .open_async () for c in self .clients ])
114
93
115
94
async def run_async (self ):
116
95
"""
@@ -125,7 +104,6 @@ async def run_async(self):
125
104
:rtype: list[~azure.eventhub.common.EventHubError]
126
105
"""
127
106
log .info ("{}: Starting {} clients" .format (self .container_id , len (self .clients )))
128
- self ._create_connection_async ()
129
107
tasks = [self ._start_client_async (c ) for c in self .clients ]
130
108
try :
131
109
await asyncio .gather (* tasks )
@@ -153,18 +131,21 @@ async def stop_async(self):
153
131
log .info ("{}: Stopping {} clients" .format (self .container_id , len (self .clients )))
154
132
self .stopped = True
155
133
await self ._close_clients_async ()
156
- await self ._close_connection_async ()
157
134
158
135
async def get_eventhub_info_async (self ):
159
136
"""
160
137
Get details on the specified EventHub async.
161
138
162
139
:rtype: dict
163
140
"""
164
- eh_name = self .address .path .lstrip ('/' )
165
- target = "amqps://{}/{}" .format (self .address .hostname , eh_name )
166
- async with AMQPClientAsync (target , auth = self .auth , debug = self .debug ) as mgmt_client :
167
- mgmt_msg = Message (application_properties = {'name' : eh_name })
141
+ alt_creds = {
142
+ "username" : self ._auth_config .get ("iot_username" ),
143
+ "password" :self ._auth_config .get ("iot_password" )}
144
+ try :
145
+ mgmt_auth = self ._create_auth (** alt_creds )
146
+ mgmt_client = AMQPClientAsync (self .mgmt_target , auth = mgmt_auth , debug = self .debug )
147
+ await mgmt_client .open_async ()
148
+ mgmt_msg = Message (application_properties = {'name' : self .eh_name })
168
149
response = await mgmt_client .mgmt_request_async (
169
150
mgmt_msg ,
170
151
constants .READ_OPERATION ,
@@ -180,6 +161,8 @@ async def get_eventhub_info_async(self):
180
161
output ['partition_count' ] = eh_info [b'partition_count' ]
181
162
output ['partition_ids' ] = [p .decode ('utf-8' ) for p in eh_info [b'partition_ids' ]]
182
163
return output
164
+ finally :
165
+ await mgmt_client .close_async ()
183
166
184
167
def add_async_receiver (self , consumer_group , partition , offset = None , prefetch = 300 , operation = None , loop = None ):
185
168
"""
@@ -201,10 +184,7 @@ def add_async_receiver(self, consumer_group, partition, offset=None, prefetch=30
201
184
path = self .address .path + operation if operation else self .address .path
202
185
source_url = "amqps://{}{}/ConsumerGroups/{}/Partitions/{}" .format (
203
186
self .address .hostname , path , consumer_group , partition )
204
- source = Source (source_url )
205
- if offset is not None :
206
- source .set_filter (offset .selector ())
207
- handler = AsyncReceiver (self , source , prefetch = prefetch , loop = loop )
187
+ handler = AsyncReceiver (self , source_url , offset = offset , prefetch = prefetch , loop = loop )
208
188
self .clients .append (handler )
209
189
return handler
210
190
0 commit comments