bring changes over from dev branch

old tests pass without any socket leakage
This commit is contained in:
Eric House 2017-12-10 08:21:22 -08:00
parent 2e71aedc02
commit fc4e577f1b
7 changed files with 170 additions and 77 deletions

View file

@ -20,13 +20,16 @@
*/
#include <assert.h>
#include <errno.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include "addrinfo.h"
#include "xwrelay_priv.h"
#include "tpool.h"
#include "udpager.h"
#include "mlock.h"
// static uint32_t s_prevCreated = 0L;
@ -68,7 +71,7 @@ AddrInfo::equals( const AddrInfo& other ) const
if ( isTCP() ) {
equal = m_socket == other.m_socket;
if ( equal && created() != other.created() ) {
logf( XW_LOGINFO, "%s: rejecting on time mismatch (%lx vs %lx)",
logf( XW_LOGINFO, "%s(): rejecting on time mismatch (%lx vs %lx)",
__func__, created(), other.created() );
equal = false;
}
@ -82,3 +85,40 @@ AddrInfo::equals( const AddrInfo& other ) const
return equal;
}
static pthread_mutex_t s_refMutex = PTHREAD_MUTEX_INITIALIZER;
static map<int, int > s_socketRefs;
void AddrInfo::ref() const
{
// logf( XW_LOGVERBOSE0, "%s(socket=%d)", __func__, m_socket );
MutexLock ml( &s_refMutex );
++s_socketRefs[m_socket];
printRefMap();
}
void
AddrInfo::unref() const
{
// logf( XW_LOGVERBOSE0, "%s(socket=%d)", __func__, m_socket );
MutexLock ml( &s_refMutex );
assert( s_socketRefs[m_socket] > 0 );
--s_socketRefs[m_socket];
if ( s_socketRefs[m_socket] == 0 ) {
XWThreadPool::GetTPool()->CloseSocket( this );
}
printRefMap();
}
/* private, and assumes have mutex */
void
AddrInfo::printRefMap() const
{
/* for ( map<int,int>::const_iterator iter = s_socketRefs.begin(); */
/* iter != s_socketRefs.end(); ++iter ) { */
/* int count = iter->second; */
/* if ( count > 0 ) { */
/* logf( XW_LOGVERBOSE0, "socket: %d; count: %d", iter->first, count ); */
/* } */
/* } */
}

View file

@ -81,12 +81,18 @@ class AddrInfo {
bool equals( const AddrInfo& other ) const;
/* refcount the underlying socket (doesn't modify instance) */
void ref() const;
void unref() const;
int getref() const;
private:
void construct( int sock, const AddrUnion* saddr, bool isTCP );
void init( int sock, ClientToken clientToken, const AddrUnion* saddr ) {
construct( sock, saddr, false );
m_clientToken = clientToken;
}
void printRefMap() const;
// AddrInfo& operator=(const AddrInfo&); // Prevent assignment
int m_socket;

View file

@ -84,12 +84,13 @@ RelayConfigs::GetValueFor( const char* key, time_t* value )
bool
RelayConfigs::GetValueFor( const char* key, char* buf, int len )
{
MutexLock ml( &m_values_mutex );
pthread_mutex_lock( &m_values_mutex );
map<const char*,const char*>::const_iterator iter = m_values.find(key);
bool found = iter != m_values.end();
if ( found ) {
snprintf( buf, len, "%s", iter->second );
}
pthread_mutex_unlock( &m_values_mutex );
return found;
}
@ -125,7 +126,7 @@ RelayConfigs::GetValueFor( const char* key, vector<int>& ints )
void
RelayConfigs::SetValueFor( const char* key, const char* value )
{
MutexLock ml( &m_values_mutex );
pthread_mutex_lock( &m_values_mutex );
/* Remove any entry already there */
map<const char*,const char*>::iterator iter = m_values.find(key);
@ -136,6 +137,7 @@ RelayConfigs::SetValueFor( const char* key, const char* value )
pair<map<const char*,const char*>::iterator,bool> result =
m_values.insert( pair<const char*,const char*>(strdup(key),strdup(value) ) );
assert( result.second );
pthread_mutex_unlock( &m_values_mutex );
}
ino_t

View file

@ -119,13 +119,17 @@ XWThreadPool::Stop()
void
XWThreadPool::AddSocket( SockType stype, QueueCallback proc, const AddrInfo* from )
{
from->ref();
int sock = from->getSocket();
logf( XW_LOGVERBOSE0, "%s(sock=%d, isTCP=%d)", __func__, sock, from->isTCP() );
SockInfo si = { .m_type = stype,
.m_proc = proc,
.m_addr = *from
};
{
int sock = from->getSocket();
RWWriteLock ml( &m_activeSocketsRWLock );
SockInfo si;
si.m_type = stype;
si.m_proc = proc;
si.m_addr = *from;
assert( m_activeSockets.find( sock ) == m_activeSockets.end() );
m_activeSockets.insert( pair<int, SockInfo>( sock, si ) );
}
interrupt_poll();
@ -158,13 +162,14 @@ XWThreadPool::RemoveSocket( const AddrInfo* addr )
size_t prevSize = m_activeSockets.size();
map<int, SockInfo>::iterator iter = m_activeSockets.find( addr->getSocket() );
int sock = addr->getSocket();
map<int, SockInfo>::iterator iter = m_activeSockets.find( sock );
if ( m_activeSockets.end() != iter && iter->second.m_addr.equals( *addr ) ) {
m_activeSockets.erase( iter );
found = true;
}
logf( XW_LOGINFO, "%s: AFTER: %d sockets active (was %d)", __func__,
m_activeSockets.size(), prevSize );
logf( XW_LOGINFO, "%s(): AFTER closing %d: %d sockets active (was %d)", __func__,
sock, m_activeSockets.size(), prevSize );
}
return found;
} /* RemoveSocket */
@ -184,8 +189,14 @@ XWThreadPool::CloseSocket( const AddrInfo* addr )
++iter;
}
}
logf( XW_LOGINFO, "CLOSING socket %d", addr->getSocket() );
close( addr->getSocket() );
int sock = addr->getSocket();
int err = close( sock );
if ( 0 != err ) {
logf( XW_LOGERROR, "%s(): close(socket=%d) => %d/%s", __func__,
sock, errno, strerror(errno) );
} else {
logf( XW_LOGINFO, "%s(): close(socket=%d) succeeded", __func__, sock );
}
/* We always need to interrupt the poll because the socket we're closing
will be in the list being listened to. That or we need to drop sockets
@ -198,7 +209,7 @@ XWThreadPool::CloseSocket( const AddrInfo* addr )
void
XWThreadPool::EnqueueKill( const AddrInfo* addr, const char* const why )
{
logf( XW_LOGINFO, "%s(%d) reason: %s", __func__, addr->getSocket(), why );
logf( XW_LOGINFO, "%s(socket = %d) reason: %s", __func__, addr->getSocket(), why );
if ( addr->isTCP() ) {
SockInfo si;
si.m_type = STYPE_UNKNOWN;
@ -265,7 +276,6 @@ XWThreadPool::real_tpool_main( ThreadInfo* tip )
if ( gotOne ) {
sock = pr.m_info.m_addr.getSocket();
logf( XW_LOGINFO, "worker thread got socket %d from queue", socket );
switch ( pr.m_act ) {
case Q_READ:
assert( 0 );
@ -275,8 +285,9 @@ XWThreadPool::real_tpool_main( ThreadInfo* tip )
// }
break;
case Q_KILL:
logf( XW_LOGINFO, "worker thread got socket %d from queue (to close it)", sock );
(*m_kFunc)( &pr.m_info.m_addr );
CloseSocket( &pr.m_info.m_addr );
pr.m_info.m_addr.unref();
break;
}
} else {
@ -392,35 +403,40 @@ XWThreadPool::real_listener()
curfd = 1;
int ii;
for ( ii = 0; ii < nSockets && nEvents > 0; ++ii ) {
for ( ii = 0; ii < nSockets && nEvents > 0; ++ii, ++curfd ) {
if ( fds[curfd].revents != 0 ) {
// int socket = fds[curfd].fd;
SockInfo* sinfo = &sinfos[curfd];
const AddrInfo* addr = &sinfo->m_addr;
assert( fds[curfd].fd == addr->getSocket() );
int sock = addr->getSocket();
assert( fds[curfd].fd == sock );
if ( !SocketFound( addr ) ) {
logf( XW_LOGINFO, "%s(): dropping socket %d: not found",
__func__, addr->getSocket() );
/* no further processing if it's been removed while
we've been sleeping in poll */
we've been sleeping in poll. BUT: shouldn't curfd
be incremented?? */
--nEvents;
continue;
}
if ( 0 != (fds[curfd].revents & (POLLIN | POLLPRI)) ) {
if ( !UdpQueue::get()->handle( addr, sinfo->m_proc ) ) {
// This is likely wrong!!! return of 0 means
// remote closed, not error.
RemoveSocket( addr );
EnqueueKill( addr, "bad packet" );
EnqueueKill( addr, "got EOF" );
}
} else {
logf( XW_LOGERROR, "odd revents: %x",
fds[curfd].revents );
logf( XW_LOGERROR, "%s(): odd revents: %x; bad socket %d",
__func__, fds[curfd].revents, sock );
RemoveSocket( addr );
EnqueueKill( addr, "error/hup in poll()" );
}
--nEvents;
}
++curfd;
}
assert( nEvents == 0 );
}

View file

@ -28,7 +28,7 @@ static UdpQueue* s_instance = NULL;
void
UdpThreadClosure::logStats()
PacketThreadClosure::logStats()
{
time_t now = time( NULL );
if ( 1 < now - m_created ) {
@ -48,6 +48,7 @@ PartialPacket::stillGood() const
bool
PartialPacket::readAtMost( int len )
{
assert( len > 0 );
bool success = false;
uint8_t tmp[len];
ssize_t nRead = recv( m_sock, tmp, len, 0 );
@ -57,10 +58,12 @@ PartialPacket::readAtMost( int len )
logf( XW_LOGERROR, "%s(len=%d, socket=%d): recv failed: %d (%s)", __func__,
len, m_sock, m_errno, strerror(m_errno) );
}
} else if ( 0 == nRead ) { // remote socket closed
logf( XW_LOGINFO, "%s: remote closed (socket=%d)", __func__, m_sock );
} else if ( 0 == nRead ) { // remote socket half-closed
logf( XW_LOGINFO, "%s(): remote closed (socket=%d)", __func__, m_sock );
m_errno = -1; // so stillGood will fail
} else {
// logf( XW_LOGVERBOSE0, "%s(): read %d bytes on socket %d", __func__,
// nRead, m_sock );
m_errno = 0;
success = len == nRead;
int curSize = m_buf.size();
@ -100,7 +103,11 @@ UdpQueue::get()
return s_instance;
}
// return false if socket should no longer be used
// If we're already assembling data from this socket, continue. Otherwise
// create a new parital packet and store data there. If we wind up with a
// complete packet, dispatch it and delete since the data's been delivered.
//
// Return false if socket should no longer be used.
bool
UdpQueue::handle( const AddrInfo* addr, QueueCallback cb )
{
@ -145,6 +152,7 @@ UdpQueue::handle( const AddrInfo* addr, QueueCallback cb )
}
success = success && (NULL == packet || packet->stillGood());
logf( XW_LOGVERBOSE0, "%s(sock=%d) => %d", __func__, sock, success );
return success;
}
@ -152,17 +160,21 @@ void
UdpQueue::handle( const AddrInfo* addr, const uint8_t* buf, int len,
QueueCallback cb )
{
UdpThreadClosure* utc = new UdpThreadClosure( addr, buf, len, cb );
// addr->ref();
PacketThreadClosure* ptc = new PacketThreadClosure( addr, buf, len, cb );
MutexLock ml( &m_queueMutex );
int id = ++m_nextID;
utc->setID( id );
logf( XW_LOGINFO, "%s: enqueuing packet %d (socket %d, len %d)",
ptc->setID( id );
logf( XW_LOGINFO, "%s(): enqueuing packet %d (socket %d, len %d)",
__func__, id, addr->getSocket(), len );
m_queue.push_back( utc );
m_queue.push_back( ptc );
pthread_cond_signal( &m_queueCondVar );
}
// Remove any PartialPacket record with the same socket/fd. This makes sense
// when the socket's being reused or when we have just dealt with a single
// packet and might be getting more.
void
UdpQueue::newSocket_locked( int sock )
{
@ -194,25 +206,26 @@ UdpQueue::thread_main()
while ( m_queue.size() == 0 ) {
pthread_cond_wait( &m_queueCondVar, &m_queueMutex );
}
UdpThreadClosure* utc = m_queue.front();
PacketThreadClosure* ptc = m_queue.front();
m_queue.pop_front();
pthread_mutex_unlock( &m_queueMutex );
utc->noteDequeued();
ptc->noteDequeued();
time_t age = utc->ageInSeconds();
time_t age = ptc->ageInSeconds();
if ( 30 > age ) {
logf( XW_LOGINFO, "%s: dispatching packet %d (socket %d); "
"%d seconds old", __func__, utc->getID(),
utc->addr()->getSocket(), age );
(*utc->cb())( utc );
utc->logStats();
"%d seconds old", __func__, ptc->getID(),
ptc->addr()->getSocket(), age );
(*ptc->cb())( ptc );
ptc->logStats();
} else {
logf( XW_LOGINFO, "%s: dropping packet %d; it's %d seconds old!",
__func__, age );
}
delete utc;
// ptc->addr()->unref();
delete ptc;
}
return NULL;
}

View file

@ -30,13 +30,13 @@
using namespace std;
class UdpThreadClosure;
class PacketThreadClosure;
typedef void (*QueueCallback)( UdpThreadClosure* closure );
typedef void (*QueueCallback)( PacketThreadClosure* closure );
class UdpThreadClosure {
class PacketThreadClosure {
public:
UdpThreadClosure( const AddrInfo* addr, const uint8_t* buf,
PacketThreadClosure( const AddrInfo* addr, const uint8_t* buf,
int len, QueueCallback cb )
: m_buf(new uint8_t[len])
, m_len(len)
@ -45,9 +45,13 @@ public:
, m_created(time( NULL ))
{
memcpy( m_buf, buf, len );
m_addr.ref();
}
~UdpThreadClosure() { delete[] m_buf; }
~PacketThreadClosure() {
m_addr.unref();
delete[] m_buf;
}
const uint8_t* buf() const { return m_buf; }
int len() const { return m_len; }
@ -109,8 +113,8 @@ class UdpQueue {
pthread_mutex_t m_partialsMutex;
pthread_mutex_t m_queueMutex;
pthread_cond_t m_queueCondVar;
deque<UdpThreadClosure*> m_queue;
// map<int, vector<UdpThreadClosure*> > m_bySocket;
deque<PacketThreadClosure*> m_queue;
// map<int, vector<PacketThreadClosure*> > m_bySocket;
int m_nextID;
map<int, PartialPacket*> m_partialPackets;
};

View file

@ -124,8 +124,6 @@ logf( XW_LogLevel level, const char* format, ... )
va_end(ap);
#else
FILE* where = NULL;
struct tm* timp;
struct timeval tv;
bool useFile;
char logFile[256];
@ -143,13 +141,14 @@ logf( XW_LogLevel level, const char* format, ... )
if ( !!where ) {
static int tm_yday = 0;
struct timeval tv;
gettimeofday( &tv, NULL );
struct tm result;
timp = localtime_r( &tv.tv_sec, &result );
struct tm* timp = localtime_r( &tv.tv_sec, &result );
char timeBuf[64];
sprintf( timeBuf, "%.2d:%.2d:%.2d", timp->tm_hour,
timp->tm_min, timp->tm_sec );
sprintf( timeBuf, "%.2d:%.2d:%.2d.%03ld", timp->tm_hour,
timp->tm_min, timp->tm_sec, tv.tv_usec / 1000 );
/* log the date once/day. This isn't threadsafe so may be
repeated but that's harmless. */
@ -1031,7 +1030,7 @@ processDisconnect( const uint8_t* bufp, int bufLen, const AddrInfo* addr )
} /* processDisconnect */
static void
killSocket( const AddrInfo* addr )
rmSocketRefs( const AddrInfo* addr )
{
logf( XW_LOGINFO, "%s(addr.socket=%d)", __func__, addr->getSocket() );
CRefMgr::Get()->RemoveSocketRefs( addr );
@ -1304,14 +1303,17 @@ handleMsgsMsg( const AddrInfo* addr, bool sendFull,
const uint8_t* bufp, const uint8_t* end )
{
unsigned short nameCount;
int ii;
if ( getNetShort( &bufp, end, &nameCount ) ) {
assert( nameCount == 1 ); // Don't commit this!!!
DBMgr* dbmgr = DBMgr::Get();
vector<uint8_t> out(4); /* space for len and n_msgs */
assert( out.size() == 4 );
vector<int> msgIDs;
for ( ii = 0; ii < nameCount && bufp < end; ++ii ) {
for ( int ii = 0; ii < nameCount; ++ii ) {
if ( bufp >= end ) {
logf( XW_LOGERROR, "%s(): ran off the end", __func__ );
break;
}
// See NetUtils.java for reply format
// message-length: 2
// nameCount: 2
@ -1329,6 +1331,7 @@ handleMsgsMsg( const AddrInfo* addr, bool sendFull,
break;
}
logf( XW_LOGVERBOSE0, "%s(): connName: %s", __func__, connName );
dbmgr->RecordAddress( connName, hid, addr );
/* For each relayID, write the number of messages and then
@ -1345,14 +1348,21 @@ handleMsgsMsg( const AddrInfo* addr, bool sendFull,
memcpy( &out[0], &tmp, sizeof(tmp) );
tmp = htons( nameCount );
memcpy( &out[2], &tmp, sizeof(tmp) );
ssize_t nwritten = write( addr->getSocket(), &out[0], out.size() );
logf( XW_LOGVERBOSE0, "%s: wrote %d bytes", __func__, nwritten );
if ( sendFull && nwritten >= 0 && (size_t)nwritten == out.size() ) {
int sock = addr->getSocket();
ssize_t nWritten = write( sock, &out[0], out.size() );
if ( nWritten < 0 ) {
logf( XW_LOGERROR, "%s(): write to socket %d failed: %d/%s", __func__,
sock, errno, strerror(errno) );
} else if ( sendFull && (size_t)nWritten == out.size() ) {
logf( XW_LOGVERBOSE0, "%s(): wrote %d bytes to socket %d", __func__,
nWritten, sock );
dbmgr->RecordSent( &msgIDs[0], msgIDs.size() );
// This is wrong: should be removed when ACK returns and not
// before. But for some reason if I make that change apps wind up
// stalling.
dbmgr->RemoveStoredMessages( msgIDs );
} else {
assert(0);
}
}
} // handleMsgsMsg
@ -1476,23 +1486,24 @@ handleProxyMsgs( int sock, const AddrInfo* addr, const uint8_t* bufp,
} // handleProxyMsgs
static void
game_thread_proc( UdpThreadClosure* utc )
game_thread_proc( PacketThreadClosure* ptc )
{
if ( !processMessage( utc->buf(), utc->len(), utc->addr(), 0 ) ) {
XWThreadPool::GetTPool()->CloseSocket( utc->addr() );
logf( XW_LOGVERBOSE0, "%s()", __func__ );
if ( !processMessage( ptc->buf(), ptc->len(), ptc->addr(), 0 ) ) {
// XWThreadPool::GetTPool()->CloseSocket( ptc->addr() );
}
}
static void
proxy_thread_proc( UdpThreadClosure* utc )
proxy_thread_proc( PacketThreadClosure* ptc )
{
const int len = utc->len();
const AddrInfo* addr = utc->addr();
const int len = ptc->len();
const AddrInfo* addr = ptc->addr();
if ( len > 0 ) {
assert( addr->isTCP() );
int sock = addr->getSocket();
const uint8_t* bufp = utc->buf();
const uint8_t* bufp = ptc->buf();
const uint8_t* end = bufp + len;
if ( (0 == *bufp++) ) { /* protocol */
XWPRXYCMD cmd = (XWPRXYCMD)*bufp++;
@ -1561,7 +1572,8 @@ proxy_thread_proc( UdpThreadClosure* utc )
}
}
}
XWThreadPool::GetTPool()->CloseSocket( addr );
// Should I remove this, or make it into more of an unref() call?
// XWThreadPool::GetTPool()->CloseSocket( addr );
} // proxy_thread_proc
static size_t
@ -1726,10 +1738,10 @@ ackPacketIf( const UDPHeader* header, const AddrInfo* addr )
}
static void
handle_udp_packet( UdpThreadClosure* utc )
handle_udp_packet( PacketThreadClosure* ptc )
{
const uint8_t* ptr = utc->buf();
const uint8_t* end = ptr + utc->len();
const uint8_t* ptr = ptc->buf();
const uint8_t* end = ptr + ptc->len();
UDPHeader header;
if ( getHeader( &ptr, end, &header ) ) {
@ -1752,7 +1764,7 @@ handle_udp_packet( UdpThreadClosure* utc )
if ( 3 >= clientVers ) {
checkAllAscii( model, "bad model" );
}
registerDevice( relayID, &devID, utc->addr(),
registerDevice( relayID, &devID, ptc->addr(),
clientVers, devDesc, model, osVers );
}
}
@ -1765,7 +1777,7 @@ handle_udp_packet( UdpThreadClosure* utc )
ptr += sizeof(clientToken);
clientToken = ntohl( clientToken );
if ( AddrInfo::NULL_TOKEN != clientToken ) {
AddrInfo addr( g_udpsock, clientToken, utc->saddr() );
AddrInfo addr( g_udpsock, clientToken, ptc->saddr() );
(void)processMessage( ptr, end - ptr, &addr, clientToken );
} else {
logf( XW_LOGERROR, "%s: dropping packet with token of 0",
@ -1786,7 +1798,7 @@ handle_udp_packet( UdpThreadClosure* utc )
}
SafeCref scr( connName, hid );
if ( scr.IsValid() ) {
AddrInfo addr( g_udpsock, clientToken, utc->saddr() );
AddrInfo addr( g_udpsock, clientToken, ptc->saddr() );
handlePutMessage( scr, hid, &addr, end - ptr, &ptr, end );
assert( ptr == end ); // DON'T CHECK THIS IN!!!
} else {
@ -1821,7 +1833,7 @@ handle_udp_packet( UdpThreadClosure* utc )
case XWPDEV_RQSTMSGS: {
DevID devID( ID_TYPE_RELAY );
if ( getVLIString( &ptr, end, devID.m_devIDString ) ) {
const AddrInfo* addr = utc->addr();
const AddrInfo* addr = ptc->addr();
DevMgr::Get()->rememberDevice( devID.asRelayID(), addr );
if ( XWPDEV_RQSTMSGS == header.cmd ) {
@ -1862,7 +1874,7 @@ handle_udp_packet( UdpThreadClosure* utc )
}
// Do this after the device and address are registered
ackPacketIf( &header, utc->addr() );
ackPacketIf( &header, ptc->addr() );
}
}
@ -2335,7 +2347,7 @@ main( int argc, char** argv )
(void)sigaction( SIGINT, &act, NULL );
XWThreadPool* tPool = XWThreadPool::GetTPool();
tPool->Setup( nWorkerThreads, killSocket );
tPool->Setup( nWorkerThreads, rmSocketRefs );
/* set up select call */
fd_set rfds;