diff --git a/tftpy/TftpClient.py b/tftpy/TftpClient.py index eb82c05..dbd5d4d 100644 --- a/tftpy/TftpClient.py +++ b/tftpy/TftpClient.py @@ -3,8 +3,7 @@ """This module implements the TFTP Client functionality. Instantiate an instance of the client, and then use its upload or download method. Logging is performed via a standard logging object set in TftpShared.""" - - +import socket import types import logging from .TftpShared import * @@ -18,7 +17,7 @@ class TftpClient(TftpSession): download can be initiated via the download() method, or an upload via the upload() method.""" - def __init__(self, host, port=69, options={}, localip = ""): + def __init__(self, host, port=69, options={}, localip = "", af_family=socket.AF_INET): TftpSession.__init__(self) self.context = None self.host = host @@ -26,6 +25,7 @@ def __init__(self, host, port=69, options={}, localip = ""): self.filename = None self.options = options self.localip = localip + self.af_family = af_family if 'blksize' in self.options: size = self.options['blksize'] tftpassert(int == type(size), "blksize must be an int") @@ -54,7 +54,8 @@ def download(self, filename, output, packethook=None, timeout=SOCK_TIMEOUT): self.options, packethook, timeout, - localip = self.localip) + localip=self.localip, + af_family=self.af_family) self.context.start() # Download happens here self.context.end() diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py index da85886..9ee627a 100644 --- a/tftpy/TftpContexts.py +++ b/tftpy/TftpContexts.py @@ -76,14 +76,15 @@ def add_dup(self, pkt): class TftpContext(object): """The base class of the contexts.""" - def __init__(self, host, port, timeout, localip = ""): + def __init__(self, host, port, timeout, localip = "", af_family=socket.AF_INET): """Constructor for the base context, setting shared instance variables.""" self.file_to_transfer = None self.fileobj = None self.options = None self.packethook = None - self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.af_family = af_family + self.sock = socket.socket(af_family, socket.SOCK_DGRAM) if localip != "": self.sock.bind((localip, 0)) self.sock.settimeout(timeout) @@ -147,7 +148,13 @@ def sethost(self, host): """Setter method that also sets the address property as a result of the host that is set.""" self.__host = host - self.address = socket.gethostbyname(host) + if self.af_family == socket.AF_INET: + self.address = socket.gethostbyname(host) + elif self.af_family == socket.AF_INET6: + self.address = socket.getaddrinfo(host, 0)[0][4][0] + else: + raise ValueError("AF Family is not supported") + host = property(gethost, sethost) @@ -166,7 +173,12 @@ def cycle(self): """Here we wait for a response from the server after sending it something, and dispatch appropriate action to that response.""" try: - (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) + if self.af_family == socket.AF_INET: + (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE) + elif self.af_family == socket.AF_INET6: + (buffer, (raddress, rport, _, _)) = self.sock.recvfrom(MAX_BLKSIZE) + else: + raise ValueError("Socket familiy is not supported") except socket.timeout: log.warning("Timeout waiting for traffic, retrying...") raise TftpTimeout("Timed-out waiting for traffic") @@ -212,11 +224,13 @@ def __init__(self, timeout, root, dyn_file_func=None, - upload_open=None): + upload_open=None, + af_family=socket.AF_INET): TftpContext.__init__(self, host, port, timeout, + af_family=af_family ) # At this point we have no idea if this is a download or an upload. We # need to let the start state determine that. @@ -346,12 +360,14 @@ def __init__(self, options, packethook, timeout, - localip = ""): + localip = "", + af_family=socket.AF_INET): TftpContext.__init__(self, host, port, timeout, - localip) + localip, + af_family=af_family) # FIXME: should we refactor setting of these params? self.file_to_transfer = filename self.options = options diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index 95ca70e..8be85b8 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -77,7 +77,7 @@ def __init__(self, raise TftpException("The tftproot does not exist.") def listen(self, listenip="", listenport=DEF_TFTP_PORT, - timeout=SOCK_TIMEOUT): + timeout=SOCK_TIMEOUT, af_family=socket.AF_INET): """Start a server listening on the supplied interface and port. This defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also supply a different socket timeout value, if desired.""" @@ -85,13 +85,21 @@ def listen(self, listenip="", listenport=DEF_TFTP_PORT, # Don't use new 2.5 ternary operator yet # listenip = listenip if listenip else '0.0.0.0' - if not listenip: listenip = '0.0.0.0' + if not listenip: + listenip = '0.0.0.0' log.info("Server requested on ip %s, port %s" % (listenip, listenport)) try: # FIXME - sockets should be non-blocking - self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.sock.bind((listenip, listenport)) - _, self.listenport = self.sock.getsockname() + self.sock = socket.socket(af_family, socket.SOCK_DGRAM) + if af_family == socket.AF_INET: + self.sock.bind((listenip, listenport)) + _, self.listenport = self.sock.getsockname() + elif af_family == socket.AF_INET6: + self.sock.bind((listenip, listenport)) + _, self.listenport, _, _ = self.sock.getsockname() + else: + log.error("Socket family %d is not supported", af_family) + raise ValueError("Socket family is not supported") except socket.error as err: # Reraise it for now. raise err @@ -144,7 +152,10 @@ def listen(self, listenip="", listenport=DEF_TFTP_PORT, # Is the traffic on the main server socket? ie. new session? if readysock == self.sock: log.debug("Data ready on our main socket") - buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE) + if self.af_family == socket.AF_INET: + buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE) + else: + buffer, (raddress, rport, _, _) = self.sock.recvfrom(MAX_BLKSIZE) log.debug("Read %d bytes", len(buffer)) @@ -165,7 +176,8 @@ def listen(self, listenip="", listenport=DEF_TFTP_PORT, timeout, self.root, self.dyn_file_func, - self.upload_open) + self.upload_open, + af_family=af_family) try: self.sessions[key].start(buffer) except TftpException as err: