blocxx

Win32SocketBaseImpl.cpp

Go to the documentation of this file.
00001 /*******************************************************************************
00002 * Copyright (C) 2005, Vintela, Inc. All rights reserved.
00003 * Copyright (C) 2006, Novell, Inc. All rights reserved.
00004 * 
00005 * Redistribution and use in source and binary forms, with or without
00006 * modification, are permitted provided that the following conditions are met:
00007 * 
00008 *     * Redistributions of source code must retain the above copyright notice,
00009 *       this list of conditions and the following disclaimer.
00010 *     * Redistributions in binary form must reproduce the above copyright
00011 *       notice, this list of conditions and the following disclaimer in the
00012 *       documentation and/or other materials provided with the distribution.
00013 *     * Neither the name of 
00014 *       Vintela, Inc., 
00015 *       nor Novell, Inc., 
00016 *       nor the names of its contributors or employees may be used to 
00017 *       endorse or promote products derived from this software without 
00018 *       specific prior written permission.
00019 * 
00020 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00021 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00022 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00023 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00024 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00025 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00026 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00027 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00028 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00029 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
00030 * POSSIBILITY OF SUCH DAMAGE.
00031 *******************************************************************************/
00032 
00033 
00040 #include "blocxx/BLOCXX_config.h"
00041 
00042 #if defined(BLOCXX_WIN32)
00043 
00044 #include "blocxx/SocketBaseImpl.hpp"
00045 #include "blocxx/SocketUtils.hpp"
00046 #include "blocxx/Format.hpp"
00047 #include "blocxx/Assertion.hpp"
00048 #include "blocxx/IOException.hpp"
00049 #include "blocxx/Mutex.hpp"
00050 #include "blocxx/MutexLock.hpp"
00051 #include "blocxx/Socket.hpp"
00052 #include "blocxx/Thread.hpp"
00053 #include "blocxx/System.hpp"
00054 #include "blocxx/TimeoutTimer.hpp"
00055 
00056 #include <cstdio>
00057 #include <cerrno>
00058 #include <fstream>
00059 #include <ws2tcpip.h>
00060 
00061 namespace
00062 {
00063 
00064 class SockInitializer
00065 {
00066 public:
00067    SockInitializer()
00068    {
00069       WSADATA wsaData;
00070       ::WSAStartup(MAKEWORD(2,2), &wsaData);
00071    }
00072 
00073    ~SockInitializer()
00074    {
00075       ::WSACleanup();
00076    }
00077 };
00078 
00079 // Force Winsock initialization on load
00080 SockInitializer _sockInitializer;
00081 
00083 void
00084 _closeSocket(SOCKET& sockfd)
00085 {
00086    if (sockfd != INVALID_SOCKET)
00087    {
00088       ::closesocket(sockfd);
00089       sockfd = INVALID_SOCKET;
00090    }
00091 }
00092 
00094 int
00095 getAddrFromIface(BLOCXX_NAMESPACE::InetSocketAddress_t& addr)
00096 {
00097    SOCKET sd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
00098    if (sd == SOCKET_ERROR)
00099    {
00100       return -1;
00101    }
00102 
00103    int cc = -1;
00104    INTERFACE_INFO interfaceList[20];
00105    unsigned long nBytesReturned;
00106    if (::WSAIoctl(sd, SIO_GET_INTERFACE_LIST, 0, 0, &interfaceList,
00107          sizeof(interfaceList), &nBytesReturned, 0, 0) != SOCKET_ERROR)
00108    {
00109       int nNumInterfaces = nBytesReturned / sizeof(INTERFACE_INFO);
00110       for (int i = 0; i < nNumInterfaces; ++i)
00111       {
00112          u_long nFlags = interfaceList[i].iiFlags;
00113          if (nFlags & IFF_UP)
00114          {
00115             cc = 0;
00116             ::memcpy(&addr, &(interfaceList[i].iiAddress), sizeof(addr));
00117             if (!(nFlags & IFF_LOOPBACK))
00118             {
00119                break;
00120             }
00121          }
00122       }
00123    }
00124 
00125    ::closesocket(sd);
00126    return 0;
00127 }
00128 
00129 }  // end of unnamed namespace
00130 
00131 namespace BLOCXX_NAMESPACE
00132 {
00133 
00134 using std::istream;
00135 using std::ostream;
00136 using std::iostream;
00137 using std::ifstream;
00138 using std::ofstream;
00139 using std::fstream;
00140 using std::ios;
00141 String SocketBaseImpl::m_traceFileOut;
00142 String SocketBaseImpl::m_traceFileIn;
00143 
00145 // static
00146 int
00147 SocketBaseImpl::waitForEvent(HANDLE eventArg, int secsToTimeout)
00148 {
00149    DWORD timeout = (secsToTimeout != -1)
00150       ? static_cast<DWORD>(secsToTimeout * 1000)
00151       : INFINITE;
00152    
00153    int cc;
00154    if(Socket::getShutDownMechanism() != NULL)
00155    {
00156       HANDLE events[2];
00157       events[0] = Socket::getShutDownMechanism();
00158       events[1] = eventArg;
00159 
00160       DWORD index = ::WaitForMultipleObjects(
00161          2,
00162          events,
00163          FALSE,
00164          timeout);
00165 
00166       switch (index)
00167       {
00168          case WAIT_FAILED:
00169             cc = -2;
00170             break;
00171          case WAIT_TIMEOUT:
00172             cc = -1;
00173             break;
00174          default:
00175             index -= WAIT_OBJECT_0;
00176             // If not shutdown event, then reset
00177             if (index != 0)
00178             {
00179                ::ResetEvent(eventArg);
00180             }
00181             cc = static_cast<int>(index);
00182             break;
00183       }
00184    }
00185    else
00186    {
00187       switch(::WaitForSingleObject(eventArg, timeout))
00188       {
00189          case WAIT_OBJECT_0:
00190             ::ResetEvent(eventArg);
00191             cc = 1;
00192             break;
00193          case WAIT_TIMEOUT:
00194             cc = -1;
00195             break;
00196          default:
00197             cc = -2;
00198             break;
00199       }
00200    }
00201       
00202    return cc;
00203 }
00204 
00205 #pragma warning (push)
00206 #pragma warning (disable: 4355)
00207 
00209 SocketBaseImpl::SocketBaseImpl()
00210    : SelectableIFC()
00211    , IOIFC()
00212    , m_isConnected(false)
00213    , m_sockfd(INVALID_SOCKET)
00214    , m_localAddress()
00215    , m_peerAddress()
00216    , m_event(NULL)
00217    , m_recvTimeoutExprd(false)
00218    , m_streamBuf(this)
00219    , m_in(&m_streamBuf)
00220    , m_out(&m_streamBuf)
00221    , m_inout(&m_streamBuf)
00222    , m_recvTimeout(Timeout::infinite)
00223    , m_sendTimeout(Timeout::infinite)
00224    , m_connectTimeout(Timeout::relative(0))
00225 {
00226    m_out.exceptions(std::ios::badbit);
00227    m_inout.exceptions(std::ios::badbit);
00228    m_event = ::CreateEvent(NULL, TRUE, FALSE, NULL);
00229    BLOCXX_ASSERT(m_event != NULL);
00230 }
00232 SocketBaseImpl::SocketBaseImpl(SocketHandle_t fd,
00233       SocketAddress::AddressType addrType)
00234    : SelectableIFC()
00235    , IOIFC()
00236    , m_isConnected(true)
00237    , m_sockfd(fd)
00238    , m_localAddress(SocketAddress::getAnyLocalHost())
00239    , m_peerAddress(SocketAddress::allocEmptyAddress(addrType))
00240    , m_event(NULL)
00241    , m_recvTimeoutExprd(false)
00242    , m_streamBuf(this)
00243    , m_in(&m_streamBuf)
00244    , m_out(&m_streamBuf)
00245    , m_inout(&m_streamBuf)
00246    , m_recvTimeout(Timeout::infinite)
00247    , m_sendTimeout(Timeout::infinite)
00248    , m_connectTimeout(Timeout::relative(0))
00249 {
00250    BLOCXX_ASSERT(addrType == SocketAddress::INET);
00251 
00252    m_out.exceptions(std::ios::badbit);
00253    m_inout.exceptions(std::ios::badbit);
00254    m_event = ::CreateEvent(NULL, TRUE, FALSE, NULL);
00255    BLOCXX_ASSERT(m_event != NULL);
00256    fillInetAddrParms();
00257 }
00259 SocketBaseImpl::SocketBaseImpl(const SocketAddress& addr)
00260    : SelectableIFC()
00261    , IOIFC()
00262    , m_isConnected(false)
00263    , m_sockfd(INVALID_SOCKET)
00264    , m_localAddress(SocketAddress::getAnyLocalHost())
00265    , m_peerAddress(addr)
00266    , m_event(NULL)
00267    , m_recvTimeoutExprd(false)
00268    , m_streamBuf(this)
00269    , m_in(&m_streamBuf)
00270    , m_out(&m_streamBuf)
00271    , m_inout(&m_streamBuf)
00272    , m_recvTimeout(Timeout::infinite)
00273    , m_sendTimeout(Timeout::infinite)
00274    , m_connectTimeout(Timeout::relative(0))
00275 {
00276    m_out.exceptions(std::ios::badbit);
00277    m_inout.exceptions(std::ios::badbit);
00278    m_event = ::CreateEvent(NULL, TRUE, FALSE, NULL);
00279    BLOCXX_ASSERT(m_event != NULL);
00280    connect(m_peerAddress);
00281 }
00282 
00283 #pragma warning (pop)
00284 
00286 SocketBaseImpl::~SocketBaseImpl()
00287 {
00288    try
00289    {
00290       disconnect();
00291    }
00292    catch (...)
00293    {
00294       // don't let exceptions escape
00295    }
00296    ::CloseHandle(m_event);
00297 }
00299 Select_t
00300 SocketBaseImpl::getSelectObj() const
00301 {
00302    Select_t st;
00303    st.event = m_event;
00304    st.sockfd = m_sockfd;
00305    st.isSocket = true;
00306    st.networkevents = FD_READ | FD_WRITE;
00307    st.doreset = true;
00308    return st;
00309 }
00311 void
00312 SocketBaseImpl::connect(const SocketAddress& addr)
00313 {
00314    if (m_isConnected)
00315    {
00316       disconnect();
00317    }
00318    m_streamBuf.reset();
00319    m_in.clear();
00320    m_out.clear();
00321    m_inout.clear();
00322    BLOCXX_ASSERT(addr.getType() == SocketAddress::INET);
00323 
00324    m_sockfd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
00325    if (m_sockfd == INVALID_SOCKET)
00326    {
00327       BLOCXX_THROW(SocketException, 
00328          Format("Failed to create a socket: %1",
00329          System::lastErrorMsg(true)).c_str());
00330    }
00331 
00332    int cc;
00333    WSANETWORKEVENTS networkEvents;
00334 
00335    // Connect non-blocking
00336    if(::WSAEventSelect(m_sockfd, m_event, FD_CONNECT) != 0)
00337    {
00338       BLOCXX_THROW(SocketException, 
00339          Format("WSAEventSelect Failed: %1",
00340          System::lastErrorMsg(true)).c_str());
00341    }
00342 
00343    if (::connect(m_sockfd, addr.getNativeForm(), addr.getNativeFormSize())
00344       == SOCKET_ERROR)
00345    {
00346       int lastError = ::WSAGetLastError();
00347       if (lastError != WSAEWOULDBLOCK && lastError != WSAEINPROGRESS)
00348       {
00349          _closeSocket(m_sockfd);
00350          BLOCXX_THROW(SocketException,
00351             Format("Failed to connect to: %1: %2(%3)", addr.toString(),
00352                lastError, System::lastErrorMsg(true)).c_str());
00353       }
00354 
00355       TimeoutTimer timer(m_connectTimeout);
00356       int tmoutval = timer.asDWORDMs();
00357 
00358       // Wait for connection event to come through
00359       while (true)
00360       {
00361          // Wait for the socket's event to get signaled
00362          if ((cc = waitForEvent(m_event, tmoutval)) < 1)
00363          {
00364             _closeSocket(m_sockfd);
00365             switch (cc)
00366             {
00367                case 0:     // Shutdown event
00368                   BLOCXX_THROW(SocketException,
00369                      "Sockets have been shutdown");
00370                case -1: // Timed out
00371                   BLOCXX_THROW(SocketException,
00372                      Format("Win32SocketBaseImpl connection"
00373                         " timed out. Timeout val = %1",
00374                         tmoutval).c_str());
00375                default: // Error on wait
00376                   BLOCXX_THROW(SocketException, Format("SocketBaseImpl::"
00377                      "connect() wait failed: %1(%2)",
00378                      ::WSAGetLastError(),
00379                      System::lastErrorMsg(true)).c_str());
00380             }
00381          }
00382 
00383          // Find out what network event took place
00384          if (::WSAEnumNetworkEvents(m_sockfd, m_event, &networkEvents)
00385             == SOCKET_ERROR)
00386          {
00387             _closeSocket(m_sockfd);
00388             BLOCXX_THROW(SocketException,
00389                Format("SocketBaseImpl::connect()"
00390                   " failed getting network events: %1(%2)",
00391                   ::WSAGetLastError(),
00392                   System::lastErrorMsg(true)).c_str());
00393          }
00394 
00395          // Was it a connect event?
00396          if (networkEvents.lNetworkEvents & FD_CONNECT)
00397          {
00398             // Did connect fail?
00399             if (networkEvents.iErrorCode[FD_CONNECT_BIT])
00400             {
00401                ::WSASetLastError(networkEvents.iErrorCode[FD_CONNECT_BIT]);
00402                _closeSocket(m_sockfd);
00403                BLOCXX_THROW(SocketException,
00404                   Format("SocketBaseImpl::connect() failed: %1(%2)",
00405                   ::WSAGetLastError(),
00406                   System::lastErrorMsg(true)).c_str());
00407             }
00408             break;
00409          }
00410       }  // while (true) - waiting for connection event
00411    }  // if SOCKET_ERROR on connect
00412 
00413    // Set socket back to blocking
00414    if(::WSAEventSelect(m_sockfd, m_event, 0) != 0)
00415    {
00416       _closeSocket(m_sockfd);
00417       BLOCXX_THROW(SocketException, 
00418          Format("Resetting socket with WSAEventSelect Failed: %1",
00419          System::lastErrorMsg(true)).c_str());
00420    }
00421    u_long ioctlarg = 0;
00422    ::ioctlsocket(m_sockfd, FIONBIO, &ioctlarg);
00423 
00424    m_isConnected = true;
00425 
00426    m_peerAddress = addr; // To get the hostname from addr
00427 
00428    BLOCXX_ASSERT(addr.getType() == SocketAddress::INET);
00429 
00430    fillInetAddrParms();
00431 }
00432 
00434 void
00435 SocketBaseImpl::disconnect()
00436 {
00437    if(m_in)
00438    {
00439       m_in.clear(ios::eofbit);
00440    }
00441    if(m_out)
00442    {
00443       m_out.clear(ios::eofbit);
00444    }
00445    if(!m_inout.fail())
00446    {
00447       m_inout.clear(ios::eofbit);
00448    }
00449 
00450    ::SetEvent(m_event);
00451    _closeSocket(m_sockfd);
00452    m_isConnected = false;
00453 }
00454 
00456 void
00457 SocketBaseImpl::fillInetAddrParms()
00458 {
00459    socklen_t len;
00460    InetSocketAddress_t addr;
00461    ::memset(&addr, 0, sizeof(addr));
00462    len = sizeof(addr);
00463    bool gotAddr = false;
00464 
00465    if (m_sockfd != INVALID_SOCKET)
00466    {
00467       len = sizeof(addr);
00468       if (::getsockname(m_sockfd,
00469          reinterpret_cast<struct sockaddr*>(&addr), &len) != SOCKET_ERROR)
00470       {
00471          m_localAddress.assignFromNativeForm(&addr, len);
00472       }
00473       else if (getAddrFromIface(addr) == 0)
00474       {
00475          len = sizeof(addr);
00476          m_localAddress.assignFromNativeForm(&addr, len);
00477       }
00478 
00479       len = sizeof(addr);
00480       if (::getpeername(m_sockfd, reinterpret_cast<struct sockaddr*>(&addr),
00481          &len) != SOCKET_ERROR)
00482       {
00483          m_peerAddress.assignFromNativeForm(&addr, len);
00484       }
00485    }
00486    else if (getAddrFromIface(addr) == 0)
00487    {
00488       m_localAddress.assignFromNativeForm(&addr, len);
00489    }
00490 }
00491 
00492 static Mutex guard;
00494 int
00495 SocketBaseImpl::write(const void* dataOut, int dataOutLen, ErrorAction errorAsException)
00496 {
00497    int rc = 0;
00498    bool isError = false;
00499    if (m_isConnected)
00500    {
00501       isError = waitForOutput(m_sendTimeout);
00502       if (isError)
00503       {
00504          rc = -1;
00505       }
00506       else
00507       {
00508          rc = writeAux(dataOut, dataOutLen);
00509          if (!m_traceFileOut.empty() && rc > 0)
00510          {
00511             MutexLock ml(guard);
00512             ofstream traceFile(m_traceFileOut.c_str(), std::ios::app);
00513             if (!traceFile)
00514             {
00515                BLOCXX_THROW(IOException, "Failed opening socket dump file");
00516             }
00517             if (!traceFile.write(static_cast<const char*>(dataOut), rc))
00518             {
00519                BLOCXX_THROW(IOException, "Failed writing to socket dump");
00520             }
00521 
00522             ofstream comboTraceFile(String(m_traceFileOut + "Combo").c_str(), std::ios::app);
00523             if (!comboTraceFile)
00524             {
00525                BLOCXX_THROW(IOException, "Failed opening socket dump file");
00526             }
00527             comboTraceFile << "\n--->Out " << rc << " bytes<---\n";
00528             if (!comboTraceFile.write(static_cast<const char*>(dataOut), rc))
00529             {
00530                BLOCXX_THROW(IOException, "Failed writing to socket dump");
00531             }
00532          }
00533       }
00534    }
00535    else
00536    {
00537       rc = -1;
00538    }
00539    if (rc < 0 && errorAsException)
00540    {
00541       BLOCXX_THROW(SocketException, "SocketBaseImpl::write");
00542    }
00543    return rc;
00544 }
00546 int
00547 SocketBaseImpl::read(void* dataIn, int dataInLen, ErrorAction errorAsException)  
00548 {
00549    int rc = 0;
00550    bool isError = false;
00551    if (m_isConnected)
00552    {
00553       isError = waitForInput(m_recvTimeout);
00554       if (isError)
00555       {
00556          rc = -1;
00557       }
00558       else
00559       {
00560          rc = readAux(dataIn, dataInLen);
00561          if (!m_traceFileIn.empty() && rc > 0)
00562          {
00563             MutexLock ml(guard);
00564             ofstream traceFile(m_traceFileIn.c_str(), std::ios::app);
00565             if (!traceFile)
00566             {
00567                BLOCXX_THROW(IOException, "Failed opening tracefile");
00568             }
00569             if (!traceFile.write(reinterpret_cast<const char*>(dataIn), rc))
00570             {
00571                BLOCXX_THROW(IOException, "Failed writing to socket dump");
00572             }
00573 
00574             ofstream comboTraceFile(String(m_traceFileOut + "Combo").c_str(), std::ios::app);
00575             if (!comboTraceFile)
00576             {
00577                BLOCXX_THROW(IOException, "Failed opening socket dump file");
00578             }
00579             comboTraceFile << "\n--->In " << rc << " bytes<---\n";
00580             if (!comboTraceFile.write(reinterpret_cast<const char*>(dataIn), rc))
00581             {
00582                BLOCXX_THROW(IOException, "Failed writing to socket dump");
00583             }
00584          }
00585       }
00586    }
00587    else
00588    {
00589       rc = -1;
00590    }
00591    if (rc < 0)
00592    {
00593       if (errorAsException)
00594          BLOCXX_THROW(SocketException, "SocketBaseImpl::read");
00595    }
00596    return rc;
00597 }
00599 bool
00600 SocketBaseImpl::waitForInput(const Timeout& timeOutSecs)
00601 {
00602    int rval = SocketUtils::waitForIO(m_sockfd, m_event, timeOutSecs, FD_READ);
00603    if (rval == ETIMEDOUT)
00604    {
00605       m_recvTimeoutExprd = true;
00606    }
00607    else
00608    {
00609       m_recvTimeoutExprd = false;
00610    }
00611    return (rval != 0);
00612 }
00614 bool
00615 SocketBaseImpl::waitForOutput(const Timeout& timeOutSecs)
00616 {
00617    return SocketUtils::waitForIO(m_sockfd, m_event, timeOutSecs,
00618       FD_WRITE) != 0;
00619 }
00621 istream&
00622 SocketBaseImpl::getInputStream()
00623 {
00624    return m_in;
00625 }
00627 ostream&
00628 SocketBaseImpl::getOutputStream()
00629 {
00630    return m_out;
00631 }
00633 iostream&
00634 SocketBaseImpl::getIOStream()
00635 {
00636    return m_inout;
00637 }
00639 // STATIC
00640 void
00641 SocketBaseImpl::setDumpFiles(const String& in, const String& out)
00642 {
00643    m_traceFileOut = out;
00644    m_traceFileIn = in;
00645 }
00646 
00647 } // end namespace BLOCXX_NAMESPACE
00648 
00649 #endif   // #if defined(BLOCXX_WIN32)