Skip to content

Implement socket names, in particular reading a string description of them #428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Network/Socket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ module Network.Socket
-- * Socket
, Socket
, socket
, socketFromName
, withFdSocket
, unsafeFdSocket
, touchSocket
Expand All @@ -158,8 +159,14 @@ module Network.Socket
-- ** Protocol number
, ProtocolNumber
, defaultProtocol
-- * Basic socket name type
, SockName(..)
, readSockName
, showSockName
, sockNameToAddr
-- * Basic socket address type
, SockAddr(..)
, sockAddrFamily
, isSupportedSockAddr
, getPeerName
, getSocketName
Expand Down
92 changes: 81 additions & 11 deletions Network/Socket/Info.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@

module Network.Socket.Info where

import Control.Exception (try, IOException)
import Foreign.Marshal.Alloc (alloca, allocaBytes)
import Foreign.Marshal.Utils (maybeWith, with)
import GHC.IO (unsafePerformIO)
import GHC.IO.Exception (IOErrorType(NoSuchThing))
import System.IO.Error (ioeSetErrorString, mkIOError)
import Text.Read (readEither)

import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Syscall (socket)
import Network.Socket.Types

-----------------------------------------------------------------------------
Expand Down Expand Up @@ -262,7 +265,9 @@ showDefaultHints AddrInfo{..} = concat [
--
-- >>> addr:_ <- getAddrInfo (Just hints) (Just "127.0.0.1") (Just "http")
-- >>> addrAddress addr
-- 127.0.0.1:80
-- SockAddrInet 80 16777343
-- >>> showSockAddr (addrAddress addr)
-- "127.0.0.1:80"

getAddrInfo
:: Maybe AddrInfo -- ^ preferred socket type or protocol
Expand Down Expand Up @@ -433,24 +438,89 @@ unpackBits ((k,v):xs) r
| otherwise = unpackBits xs r

-----------------------------------------------------------------------------
-- SockAddr

instance Show SockAddr where
#if defined(DOMAIN_SOCKET_SUPPORT)
showsPrec _ (SockAddrUnix str) = showString str
#else
showsPrec _ SockAddrUnix{} = error "showsPrec: not supported"
#endif
showsPrec _ addr@(SockAddrInet port _)
-- SockName, SockAddr

-- | Read a string representing a socket name.
readSockName :: PortNumber -> String -> Either String SockName
readSockName defPort hostport = case hostport of
'/':_ -> Right $ SockAddr $ SockAddrUnix hostport
'[':tl -> case span ((/=) ']') tl of
(_, []) -> Left $ "unterminated IPv6 address: " <> hostport
(ipv6, _:port) -> case readAddr ipv6 of
Nothing -> Left $ "invalid IPv6 address: " <> ipv6
Just addr -> SockAddr . sockAddrPort addr <$> readPort port
_ -> case span ((/=) ':') hostport of
(host, port) -> case readAddr host of
Nothing -> SockName host <$> readPort port
Just addr -> SockAddr . sockAddrPort addr <$> readPort port
where
readPort "" = Right defPort
readPort ":" = Right defPort
readPort (':':port) = case readEither port of
Right p -> Right p
Left _ -> Left $ "bad port: " <> port
readPort x = Left $ "bad port: " <> x
hints = Just $ defaultHints { addrFlags = [AI_NUMERICHOST] }
readAddr host = case unsafePerformIO (try (getAddrInfo hints (Just host) Nothing)) of
Left e -> Nothing where _ = e :: IOException
Right r -> Just (addrAddress (head r))
sockAddrPort h p = case h of
SockAddrInet _ a -> SockAddrInet p a
SockAddrInet6 _ f a s -> SockAddrInet6 p f a s
x -> x

showSockName :: SockName -> String
showSockName n = case n of
SockName h p -> h <> ":" <> show p
SockAddr a -> showSockAddr a

-- | Read a string representing a socket address.
readSockAddr :: PortNumber -> String -> Either String SockAddr
readSockAddr defPort hostport = readSockName defPort hostport >>= \r -> case r of
SockName h _ -> Left $ "expected address but got hostname: " <> h
SockAddr a -> Right a

showSockAddr :: SockAddr -> String
showSockAddr sa = showsAddr sa "" where
showsAddr (SockAddrUnix str) = showString str
showsAddr addr@(SockAddrInet port _)
= showString (unsafePerformIO $
fst <$> getNameInfo [NI_NUMERICHOST] True False addr >>=
maybe (fail "showsPrec: impossible internal error") return)
. showString ":"
. shows port
showsPrec _ addr@(SockAddrInet6 port _ _ _)
showsAddr addr@(SockAddrInet6 port _ _ _)
= showChar '['
. showString (unsafePerformIO $
fst <$> getNameInfo [NI_NUMERICHOST] True False addr >>=
maybe (fail "showsPrec: impossible internal error") return)
. showString "]:"
. shows port

-- | Resolve a socket name into a list of socket addresses.
-- The result is always non-empty; Haskell throws an exception if name
-- resolution fails.
sockNameToAddr :: SockName -> IO [SockAddr]
sockNameToAddr name = case name of
SockAddr a -> pure [a]
SockName host port -> fmap addrAddress <$> getAddrInfo hints (Just host) (Just (show port))
where
hints = Just $ defaultHints { addrSocketType = Stream }
-- prevents duplicates, otherwise getAddrInfo returns all socket types

-- | Shortcut for creating a socket from a socket name.
--
-- >>> import Network.Socket
-- >>> let Right sn = readSockName 0 "0.0.0.0:0"
-- >>> (s, a) <- socketFromName sn head Stream defaultProtocol
-- >>> bind s a
socketFromName
:: SockName
-> ([SockAddr] -> SockAddr)
-> SocketType
-> ProtocolNumber
-> IO (Socket, SockAddr)
socketFromName sname select stype protocol = do
a <- select <$> sockNameToAddr sname
s <- socket (sockAddrFamily a) stype protocol
pure (s, a)
26 changes: 25 additions & 1 deletion Network/Socket/SockAddr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ module Network.Socket.SockAddr (
, recvBufFrom
) where

import Control.Exception (try, throwIO, IOException)
import System.Directory (removeFile)
import System.IO.Error (isAlreadyInUseError, isDoesNotExistError)

import qualified Network.Socket.Buffer as G
import qualified Network.Socket.Name as G
import qualified Network.Socket.Syscall as G
Expand All @@ -32,7 +36,27 @@ connect = G.connect
-- 'defaultPort' is passed then the system assigns the next available
-- use port.
bind :: Socket -> SockAddr -> IO ()
bind = G.bind
bind s a = case a of
SockAddrUnix p -> do
-- gracefully handle the fact that UNIX systems don't clean up closed UNIX
-- domain sockets, inspired by https://stackoverflow.com/a/13719866
res <- try (G.bind s a)
case res of
Right () -> pure ()
Left e -> if not (isAlreadyInUseError (e :: IOException))
then throwIO e
else do
-- socket might be in use, try to connect
res2 <- try (G.connect s a)
case res2 of
Right () -> close s >> throwIO e
Left e2 -> if not (isDoesNotExistError (e2 :: IOException))
then throwIO e
else do
-- socket not actually in use, remove it and retry bind
removeFile p
G.bind s a
_ -> G.bind s a

-- | Accept a connection. The socket must be bound to an address and
-- listening for connections. The return value is a pair @(conn,
Expand Down
2 changes: 1 addition & 1 deletion Network/Socket/Syscall.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ import Network.Socket.Types
-- >>> sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
-- >>> Network.Socket.bind sock (addrAddress addr)
-- >>> getSocketName sock
-- 127.0.0.1:5000
-- SockAddrInet 5000 16777343
socket :: Family -- Family Name (usually AF_INET)
-> SocketType -- Socket Type (usually Stream)
-> ProtocolNumber -- Protocol Number (getProtocolByName to find value)
Expand Down
21 changes: 20 additions & 1 deletion Network/Socket/Types.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ module Network.Socket.Types (
, withNewSocketAddress

-- * Socket address type
, SockName(..)
, SockAddr(..)
, sockAddrFamily
, isSupportedSockAddr
, HostAddress
, hostAddressToTuple
Expand Down Expand Up @@ -970,6 +972,17 @@ type FlowInfo = Word32
-- | Scope identifier.
type ScopeID = Word32

-- | Socket names.
-- A wrapper around socket addresses that also accommodates the
-- popular usage of specifying them by name, e.g. "example.com:80".
-- Note that we don't support service names here because they also
-- imply a particular socket type, which is outside of the scope of
-- what this data type represents.
data SockName
= SockName !String !PortNumber
| SockAddr !SockAddr
deriving (Eq, Ord, Typeable, Read, Show)

-- | Socket addresses.
-- The existence of a constructor does not necessarily imply that
-- that socket address type is supported on your system: see
Expand All @@ -986,13 +999,19 @@ data SockAddr
-- | The path must have fewer than 104 characters. All of these characters must have code points less than 256.
| SockAddrUnix
String -- sun_path
deriving (Eq, Ord, Typeable)
deriving (Eq, Ord, Typeable, Read, Show)

instance NFData SockAddr where
rnf (SockAddrInet _ _) = ()
rnf (SockAddrInet6 _ _ _ _) = ()
rnf (SockAddrUnix str) = rnf str

sockAddrFamily :: SockAddr -> Family
sockAddrFamily addr = case addr of
SockAddrInet _ _ -> AF_INET
SockAddrInet6 _ _ _ _ -> AF_INET6
SockAddrUnix _ -> AF_UNIX

-- | Is the socket address type supported on this system?
isSupportedSockAddr :: SockAddr -> Bool
isSupportedSockAddr addr = case addr of
Expand Down
3 changes: 2 additions & 1 deletion network.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ library
build-depends:
base >= 4.7 && < 5,
bytestring == 0.10.*,
deepseq
deepseq,
directory

include-dirs: include
includes: HsNet.h HsNetDef.h
Expand Down