//#
//# Multicast Class based on Java's java.net.MulticastSocket
//# $Revision: 1.1 $
//# Copyright 2004 by Eric Y. Theriault
//# All Rights Reserved.
//# http://www.eyt.ca/CORBA
//#
#include "MulticastSocket.h"
#include <exception>
#include <string>
#include <errno.h>
#include <string.h>
#include <stdio.h>
#include <sys/types.h>
#ifdef _WIN32
#include <ws2tcpip.h>
#else
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/time.h>
#include <arpa/inet.h>
#include <unistd.h>
#endif

#ifdef _WIN32
#define snprintf _snprintf
#endif

namespace {
#ifndef _WIN32
    // POSIX socket() reports failure as -1; on Windows the Winsock headers
    // already provide INVALID_SOCKET for the SOCKET type.
    const int INVALID_SOCKET = -1;
#endif

   //
   // Exception thrown for every socket failure in this file. It only
   // carries a human-readable reason string, returned by what().
   //
   class SocketException : public std::exception {
   protected:
      // Text returned by what().
      std::string reason_;

   public:
      // Constructor
      // @param what  Reason text. It is copied, so the caller may pass a
      //              stack buffer that goes out of scope after the throw.
      explicit SocketException( const char * what ):
             reason_( what )
      {
      }

      // Destructor
      ~SocketException() throw ()
      {
      }

      // Reason
      virtual const char* what() const throw()
      {
         return reason_.c_str();
      }
   };

   //
   // Return the text for the most recent operating-system socket error:
   // errno on POSIX, WSAGetLastError() on Windows. Call this before anything
   // else (e.g. close()) can overwrite that error code.
   //
   std::string getErrorMessage()
   {
#ifndef _WIN32
      // strerror() already returns a NUL-terminated string; copy it
      // directly rather than through a fixed-size (truncating) buffer.
      return std::string( strerror( errno ) );
#else
      const int code = WSAGetLastError();
      char str[256];
      // Use the narrow-character variant explicitly so this still compiles
      // when UNICODE is defined. IGNORE_INSERTS is required for system
      // messages, which may contain %1-style placeholders.
      const DWORD length = FormatMessageA( FORMAT_MESSAGE_FROM_SYSTEM |
                                           FORMAT_MESSAGE_IGNORE_INSERTS,
                                           NULL, code, 0, str, sizeof( str ),
                                           NULL );
      if ( length == 0 ) {
         // FormatMessage failed; str holds nothing meaningful.
         char fallback[64];
         snprintf( fallback, sizeof( fallback ), "Unknown error %d", code );
         fallback[sizeof( fallback ) - 1] = '\0';
         return std::string( fallback );
      }
      std::string message( str, length );
      // System messages end with "\r\n"; strip it so callers can append
      // punctuation such as ": %s." cleanly.
      while ( !message.empty() &&
              ( message[message.size() - 1] == '\n' ||
                message[message.size() - 1] == '\r' ||
                message[message.size() - 1] == ' ' ) ) {
         message.erase( message.size() - 1 );
      }
      return message;
#endif
   }
}

//
// Constructor
//
MulticastSocket::MulticastSocket( int port, bool bind ):
	port_( port ),
	timeout_( 0 ),
	socket_( INVALID_SOCKET )
{
   // Create the UDP socket used for both sending and receiving.
   socket_ = socket( AF_INET, SOCK_DGRAM, 0 );
   if ( socket_ == INVALID_SOCKET ) {
      char str[255];
      snprintf( str, sizeof( str ), "Unable to create a socket: %s.",
                getErrorMessage().c_str() );
      throw SocketException( str );
   }

   // Binding is only required for the receiving (server) side.
   if ( !bind ) {
      return;
   }

   // The socket is now open. A constructor that throws never runs the
   // destructor, so every failure below must close it before throwing.
   const char *failure = 0;

   // Set SO_REUSEADDR so the port can be reused (exact semantics are
   // platform-dependent).
   int one = 1;
   if ( setsockopt( socket_, SOL_SOCKET, SO_REUSEADDR,
                    reinterpret_cast<const char *>( &one ), sizeof( one ) ) == -1 ) {
      failure = "Unable to set SO_REUSEADDR";
   }
   else {
      // Bind to the requested port on all local interfaces.
      struct sockaddr_in server;
      memset( &server, '\0', sizeof( server ) );
      server.sin_family = AF_INET;
      server.sin_addr.s_addr = htonl( INADDR_ANY );
      server.sin_port = htons( port );
      if ( ::bind( socket_, reinterpret_cast<const sockaddr*>( &server ),
                   sizeof( server ) ) != 0 ) {
         failure = "Unable to bind";
      }
   }

   if ( failure != 0 ) {
      // Format the OS error text first: closing the socket may overwrite it.
      char str[255];
      snprintf( str, sizeof( str ), "%s: %s.", failure,
                getErrorMessage().c_str() );
#ifdef _WIN32
      closesocket( socket_ );
#else
      close( socket_ );
#endif
      socket_ = INVALID_SOCKET;
      throw SocketException( str );
   }
}

//
// Destructor
//
MulticastSocket::~MulticastSocket()
{
   // Nothing to release if no valid descriptor is held.
   if ( socket_ == INVALID_SOCKET ) {
      return;
   }

   // Release the descriptor with the platform's socket-close call.
#ifdef _WIN32
   closesocket( socket_ );
#else
   close( socket_ );
#endif
}

//
// Join
//
void MulticastSocket::joinGroup(
                                const char *address,
                                const char *iface
                               )
{
   // Membership request: the group address (dotted-quad string) and the
   // local interface to join on (any interface when iface is null).
   struct ip_mreq membership;
   membership.imr_multiaddr.s_addr = inet_addr( address );
   membership.imr_interface.s_addr =
      ( iface == 0 ) ? htonl( INADDR_ANY ) : inet_addr( iface );

   const int status = setsockopt( socket_, IPPROTO_IP, IP_ADD_MEMBERSHIP,
                                  reinterpret_cast<const char *>( &membership ),
                                  sizeof( membership ) );
   if ( status == 0 ) {
      return;
   }

   // Report the failure together with the OS error text.
   char str[255];
   snprintf( str, sizeof( str ), "Unable to join group %s: %s.",
             address, getErrorMessage().c_str() );
   throw SocketException( str );
}

//
// Send a packet
//
void MulticastSocket::send(
                           const char *ipAddress,
                           const char *packet
                          )
{
   // Destination = the given dotted-quad address plus the port this socket
   // was constructed with; the actual transmission is delegated.
   struct sockaddr_in destination;
   memset( &destination, 0, sizeof( destination ) );
   destination.sin_port = htons( port_ );
   destination.sin_addr.s_addr = inet_addr( ipAddress );
   destination.sin_family = AF_INET;

   send( destination, packet );
}

//
// Send
//
void MulticastSocket::send(
                           const struct sockaddr_in &address,
                           const char *packet
                          )
{
   // The packet is a NUL-terminated string; the terminator is not sent.
   const int length = static_cast<int>( strlen( packet ) );
   const struct sockaddr *target =
      reinterpret_cast<const struct sockaddr *>( &address );

   if ( sendto( socket_, packet, length, 0, target,
                sizeof( struct sockaddr_in ) ) != -1 ) {
      return;
   }

   char str[255];
   snprintf( str, sizeof( str ), "Error sending packet: %s.",
             getErrorMessage().c_str() );
   throw SocketException( str );
}

//
// Timeout
//
void MulticastSocket::setSoTimeout( int usec )
{
   timeout_ = usec;
}

//
// Receive a packet
//
std::string MulticastSocket::receive( struct sockaddr_in &address )
{
   // Watch only this socket for readability.
   fd_set fd;
   FD_ZERO( &fd );
   FD_SET( socket_, &fd );

   // timeout_ is in microseconds. select() requires tv_usec < 1000000
   // (EINVAL otherwise), so split it into whole seconds plus a remainder.
   struct timeval timeout;
   memset( &timeout, '\0', sizeof( timeout ) );
   timeout.tv_sec = timeout_ / 1000000;
   timeout.tv_usec = timeout_ % 1000000;

   // select. A zero timeout_ means block until data arrives.
   int retval = select( socket_ + 1, &fd, 0, 0, timeout_ ? &timeout : 0 );
   if ( retval == -1 ) {
      char str[255];
      snprintf( str, sizeof( str ), "Error receiving packet: %s.",
                getErrorMessage().c_str() );
      throw SocketException( str );
   }
   else if ( retval == 0 ) {
      throw SocketException( "Timeout" );
   }

   // invariant: Some data on socket.
   char buffer[1000];
   socklen_t len = sizeof( struct sockaddr_in );
   // recvfrom returns -1 on error; keep the result signed so the error
   // is detected instead of being used as a huge index.
   const int n = static_cast<int>( recvfrom( socket_, &buffer[0], sizeof( buffer ), 0,
                                             reinterpret_cast<struct sockaddr*>( &address ),
                                             &len ) );
   if ( n < 0 ) {
      char str[255];
      snprintf( str, sizeof( str ), "Error receiving packet: %s.",
                getErrorMessage().c_str() );
      throw SocketException( str );
   }

   // Build the result from the received byte count. This avoids writing a
   // terminator past the end when the datagram fills the whole buffer.
   return std::string( buffer, static_cast<std::string::size_type>( n ) );
}

