diff --git a/Network/Socket.hs b/Network/Socket.hs index fc551f24..adf41f03 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -139,6 +139,7 @@ module Network.Socket -- * Socket , Socket , socket + , socketFromName , withFdSocket , unsafeFdSocket , touchSocket @@ -158,8 +159,14 @@ module Network.Socket -- ** Protocol number , ProtocolNumber , defaultProtocol + -- * Basic socket name type + , SockName(..) + , readSockName + , showSockName + , sockNameToAddr -- * Basic socket address type , SockAddr(..) + , sockAddrFamily , isSupportedSockAddr , getPeerName , getSocketName diff --git a/Network/Socket/Info.hsc b/Network/Socket/Info.hsc index 11cd6edb..df0a9574 100644 --- a/Network/Socket/Info.hsc +++ b/Network/Socket/Info.hsc @@ -8,14 +8,17 @@ module Network.Socket.Info where +import Control.Exception (try, IOException) import Foreign.Marshal.Alloc (alloca, allocaBytes) import Foreign.Marshal.Utils (maybeWith, with) import GHC.IO (unsafePerformIO) import GHC.IO.Exception (IOErrorType(NoSuchThing)) import System.IO.Error (ioeSetErrorString, mkIOError) +import Text.Read (readEither) import Network.Socket.Imports import Network.Socket.Internal +import Network.Socket.Syscall (socket) import Network.Socket.Types ----------------------------------------------------------------------------- @@ -262,7 +265,9 @@ showDefaultHints AddrInfo{..} = concat [ -- -- >>> addr:_ <- getAddrInfo (Just hints) (Just "127.0.0.1") (Just "http") -- >>> addrAddress addr --- 127.0.0.1:80 +-- SockAddrInet 80 16777343 +-- >>> showSockAddr (addrAddress addr) +-- "127.0.0.1:80" getAddrInfo :: Maybe AddrInfo -- ^ preferred socket type or protocol @@ -433,24 +438,89 @@ unpackBits ((k,v):xs) r | otherwise = unpackBits xs r ----------------------------------------------------------------------------- --- SockAddr - -instance Show SockAddr where -#if defined(DOMAIN_SOCKET_SUPPORT) - showsPrec _ (SockAddrUnix str) = showString str -#else - showsPrec _ SockAddrUnix{} = error "showsPrec: not supported" -#endif - showsPrec _ addr@(SockAddrInet port _) +-- SockName, SockAddr + +-- | Read a string representing a socket name. +readSockName :: PortNumber -> String -> Either String SockName +readSockName defPort hostport = case hostport of + '/':_ -> Right $ SockAddr $ SockAddrUnix hostport + '[':tl -> case span ((/=) ']') tl of + (_, []) -> Left $ "unterminated IPv6 address: " <> hostport + (ipv6, _:port) -> case readAddr ipv6 of + Nothing -> Left $ "invalid IPv6 address: " <> ipv6 + Just addr -> SockAddr . sockAddrPort addr <$> readPort port + _ -> case span ((/=) ':') hostport of + (host, port) -> case readAddr host of + Nothing -> SockName host <$> readPort port + Just addr -> SockAddr . sockAddrPort addr <$> readPort port + where + readPort "" = Right defPort + readPort ":" = Right defPort + readPort (':':port) = case readEither port of + Right p -> Right p + Left _ -> Left $ "bad port: " <> port + readPort x = Left $ "bad port: " <> x + hints = Just $ defaultHints { addrFlags = [AI_NUMERICHOST] } + readAddr host = case unsafePerformIO (try (getAddrInfo hints (Just host) Nothing)) of + Left e -> Nothing where _ = e :: IOException + Right r -> Just (addrAddress (head r)) + sockAddrPort h p = case h of + SockAddrInet _ a -> SockAddrInet p a + SockAddrInet6 _ f a s -> SockAddrInet6 p f a s + x -> x + +showSockName :: SockName -> String +showSockName n = case n of + SockName h p -> h <> ":" <> show p + SockAddr a -> showSockAddr a + +-- | Read a string representing a socket address. +readSockAddr :: PortNumber -> String -> Either String SockAddr +readSockAddr defPort hostport = readSockName defPort hostport >>= \r -> case r of + SockName h _ -> Left $ "expected address but got hostname: " <> h + SockAddr a -> Right a + +showSockAddr :: SockAddr -> String +showSockAddr sa = showsAddr sa "" where + showsAddr (SockAddrUnix str) = showString str + showsAddr addr@(SockAddrInet port _) = showString (unsafePerformIO $ fst <$> getNameInfo [NI_NUMERICHOST] True False addr >>= maybe (fail "showsPrec: impossible internal error") return) . showString ":" . shows port - showsPrec _ addr@(SockAddrInet6 port _ _ _) + showsAddr addr@(SockAddrInet6 port _ _ _) = showChar '[' . showString (unsafePerformIO $ fst <$> getNameInfo [NI_NUMERICHOST] True False addr >>= maybe (fail "showsPrec: impossible internal error") return) . showString "]:" . shows port + +-- | Resolve a socket name into a list of socket addresses. +-- The result is always non-empty; Haskell throws an exception if name +-- resolution fails. +sockNameToAddr :: SockName -> IO [SockAddr] +sockNameToAddr name = case name of + SockAddr a -> pure [a] + SockName host port -> fmap addrAddress <$> getAddrInfo hints (Just host) (Just (show port)) + where + hints = Just $ defaultHints { addrSocketType = Stream } + -- prevents duplicates, otherwise getAddrInfo returns all socket types + +-- | Shortcut for creating a socket from a socket name. +-- +-- >>> import Network.Socket +-- >>> let Right sn = readSockName 0 "0.0.0.0:0" +-- >>> (s, a) <- socketFromName sn head Stream defaultProtocol +-- >>> bind s a +socketFromName + :: SockName + -> ([SockAddr] -> SockAddr) + -> SocketType + -> ProtocolNumber + -> IO (Socket, SockAddr) +socketFromName sname select stype protocol = do + a <- select <$> sockNameToAddr sname + s <- socket (sockAddrFamily a) stype protocol + pure (s, a) diff --git a/Network/Socket/SockAddr.hs b/Network/Socket/SockAddr.hs index a16b2e2b..3aae8eec 100644 --- a/Network/Socket/SockAddr.hs +++ b/Network/Socket/SockAddr.hs @@ -8,6 +8,10 @@ module Network.Socket.SockAddr ( , recvBufFrom ) where +import Control.Exception (try, throwIO, IOException) +import System.Directory (removeFile) +import System.IO.Error (isAlreadyInUseError, isDoesNotExistError) + import qualified Network.Socket.Buffer as G import qualified Network.Socket.Name as G import qualified Network.Socket.Syscall as G @@ -32,7 +36,27 @@ connect = G.connect -- 'defaultPort' is passed then the system assigns the next available -- use port. bind :: Socket -> SockAddr -> IO () -bind = G.bind +bind s a = case a of + SockAddrUnix p -> do + -- gracefully handle the fact that UNIX systems don't clean up closed UNIX + -- domain sockets, inspired by https://stackoverflow.com/a/13719866 + res <- try (G.bind s a) + case res of + Right () -> pure () + Left e -> if not (isAlreadyInUseError (e :: IOException)) + then throwIO e + else do + -- socket might be in use, try to connect + res2 <- try (G.connect s a) + case res2 of + Right () -> close s >> throwIO e + Left e2 -> if not (isDoesNotExistError (e2 :: IOException)) + then throwIO e + else do + -- socket not actually in use, remove it and retry bind + removeFile p + G.bind s a + _ -> G.bind s a -- | Accept a connection. The socket must be bound to an address and -- listening for connections. The return value is a pair @(conn, diff --git a/Network/Socket/Syscall.hs b/Network/Socket/Syscall.hs index 57e3e34a..31ef4837 100644 --- a/Network/Socket/Syscall.hs +++ b/Network/Socket/Syscall.hs @@ -68,7 +68,7 @@ import Network.Socket.Types -- >>> sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) -- >>> Network.Socket.bind sock (addrAddress addr) -- >>> getSocketName sock --- 127.0.0.1:5000 +-- SockAddrInet 5000 16777343 socket :: Family -- Family Name (usually AF_INET) -> SocketType -- Socket Type (usually Stream) -> ProtocolNumber -- Protocol Number (getProtocolByName to find value) diff --git a/Network/Socket/Types.hsc b/Network/Socket/Types.hsc index bcb42950..3a0d82c4 100644 --- a/Network/Socket/Types.hsc +++ b/Network/Socket/Types.hsc @@ -42,7 +42,9 @@ module Network.Socket.Types ( , withNewSocketAddress -- * Socket address type + , SockName(..) , SockAddr(..) + , sockAddrFamily , isSupportedSockAddr , HostAddress , hostAddressToTuple @@ -970,6 +972,17 @@ type FlowInfo = Word32 -- | Scope identifier. type ScopeID = Word32 +-- | Socket names. +-- A wrapper around socket addresses that also accommodates the +-- popular usage of specifying them by name, e.g. "example.com:80". +-- Note that we don't support service names here because they also +-- imply a particular socket type, which is outside of the scope of +-- what this data type represents. +data SockName + = SockName !String !PortNumber + | SockAddr !SockAddr + deriving (Eq, Ord, Typeable, Read, Show) + -- | Socket addresses. -- The existence of a constructor does not necessarily imply that -- that socket address type is supported on your system: see @@ -986,13 +999,19 @@ data SockAddr -- | The path must have fewer than 104 characters. All of these characters must have code points less than 256. | SockAddrUnix String -- sun_path - deriving (Eq, Ord, Typeable) + deriving (Eq, Ord, Typeable, Read, Show) instance NFData SockAddr where rnf (SockAddrInet _ _) = () rnf (SockAddrInet6 _ _ _ _) = () rnf (SockAddrUnix str) = rnf str +sockAddrFamily :: SockAddr -> Family +sockAddrFamily addr = case addr of + SockAddrInet _ _ -> AF_INET + SockAddrInet6 _ _ _ _ -> AF_INET6 + SockAddrUnix _ -> AF_UNIX + -- | Is the socket address type supported on this system? isSupportedSockAddr :: SockAddr -> Bool isSupportedSockAddr addr = case addr of diff --git a/network.cabal b/network.cabal index 713e7e92..89f4e364 100644 --- a/network.cabal +++ b/network.cabal @@ -83,7 +83,8 @@ library build-depends: base >= 4.7 && < 5, bytestring == 0.10.*, - deepseq + deepseq, + directory include-dirs: include includes: HsNet.h HsNetDef.h