94da524243ee4fbd69afc96ae48c276e09607c06
[umurmur.git] / client.c
1 /* Copyright (C) 2009, Martin Johansson <martin@fatbob.nu>
2    Copyright (C) 2005-2009, Thorvald Natvig <thorvald@natvig.com>
3
4    All rights reserved.
5
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9
10    - Redistributions of source code must retain the above copyright notice,
11      this list of conditions and the following disclaimer.
12    - Redistributions in binary form must reproduce the above copyright notice,
13      this list of conditions and the following disclaimer in the documentation
14      and/or other materials provided with the distribution.
15    - Neither the name of the Developers nor the names of its contributors may
16      be used to endorse or promote products derived from this software without
17      specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31 #include <sys/poll.h>
32 #include <sys/socket.h>
33 #include <errno.h>
34 #include "log.h"
35 #include "list.h"
36 #include "client.h"
37 #include "ssl.h"
38 #include "messages.h"
39 #include "messagehandler.h"
40 #include "pds.h"
41 #include "conf.h"
42 #include "channel.h"
43
44
45
46 static int Client_read(client_t *client);
47 static int Client_write(client_t *client);
48 static int Client_voiceMsg(client_t *client, pds_t *pds);
49 static int Client_send_udp(client_t *client, uint8_t *data, int len);
50 static void Client_voiceMsg_tunnel(client_t *client, message_t *msg);
51
52 declare_list(clients);
53 static int clientcount; /* = 0 */
54 static int session = 1;
55 static int maxBandwidth;
56
57 extern int udpsock;
58
59 void Client_init()
60 {
61         maxBandwidth = getIntConf(MAX_BANDWIDTH);
62 }
63
64 int Client_count()
65 {
66         return clientcount;
67 }
68
69 int Client_getfds(struct pollfd *pollfds)
70 {
71         struct dlist *itr;
72         int i = 0;
73         list_iterate(itr, &clients) {
74                 client_t *c;
75                 c = list_get_entry(itr, client_t, node);
76                 pollfds[i].fd = c->tcpfd;
77                 pollfds[i].events = POLLIN | POLLHUP | POLLERR;
78                 if (c->txsize > 0 || c->readBlockedOnWrite) /* Data waiting to be sent? */
79                         pollfds[i].events |= POLLOUT;
80                 i++;
81         }
82         return i;
83 }
84
85 void Client_janitor()
86 {
87         struct dlist *itr;
88         int bwTop = maxBandwidth + maxBandwidth / 4;
89         list_iterate(itr, &clients) {
90                 client_t *c;
91                 c = list_get_entry(itr, client_t, node);
92                 Log_debug("Client %s BW available %d", c->playerName, c->availableBandwidth);
93                 c->availableBandwidth += maxBandwidth;
94                 if (c->availableBandwidth > bwTop)
95                         c->availableBandwidth = bwTop;
96                 
97                 if (Timer_isElapsed(&c->lastActivity, 1000000LL * INACTICITY_TIMEOUT))
98                         Client_close(c);
99         }
100 }
101
102 int Client_add(int fd, struct sockaddr_in *remote)
103 {
104         client_t *newclient;
105
106         newclient = malloc(sizeof(client_t));
107         if (newclient == NULL)
108                 Log_fatal("Out of memory");
109         memset(newclient, 0, sizeof(client_t));
110
111         newclient->tcpfd = fd;
112         memcpy(&newclient->remote_tcp, remote, sizeof(struct sockaddr_in));
113         newclient->ssl = SSL_newconnection(newclient->tcpfd, &newclient->SSLready);
114         if (newclient->ssl == NULL) {
115                 Log_warn("SSL negotiation failed");
116                 free(newclient);
117                 return -1;
118         }
119         newclient->availableBandwidth = maxBandwidth;
120         Timer_init(&newclient->lastActivity);
121         newclient->sessionId = session++; /* XXX - more elaborate? */
122         
123         init_list_entry(&newclient->txMsgQueue);
124         init_list_entry(&newclient->chan_node);
125         init_list_entry(&newclient->node);
126         
127         list_add_tail(&newclient->node, &clients);
128         clientcount++;
129         return 0;
130 }
131
132 void Client_free(client_t *client)
133 {
134         struct dlist *itr, *save;
135         message_t *sendmsg;
136
137         Log_info("Disconnect client ID %d addr %s port %d", client->sessionId,
138                          inet_ntoa(client->remote_tcp.sin_addr),
139                          ntohs(client->remote_tcp.sin_port));
140
141         if (client->authenticated) {
142                 sendmsg = Msg_create(ServerLeave);
143                 sendmsg->sessionId = client->sessionId;
144                 Client_send_message_except(client, sendmsg);
145         }
146         list_iterate_safe(itr, save, &client->txMsgQueue) {
147                 list_del(&list_get_entry(itr, message_t, node)->node);
148                 Msg_free(list_get_entry(itr, message_t, node));
149         }
150                 
151         list_del(&client->node);
152         list_del(&client->chan_node);
153         if (client->ssl)
154                 SSL_free(client->ssl);
155         close(client->tcpfd);
156         clientcount--;
157         free(client);
158 }
159
160 void Client_close(client_t *client)
161 {
162         SSL_shutdown(client->ssl);
163         client->shutdown_wait = true;
164 }
165
166 void Client_disconnect_all()
167 {
168         struct dlist *itr, *save;
169         
170         list_iterate_safe(itr, save, &clients) {
171                 Client_free(list_get_entry(itr, client_t, node));
172         }
173 }
174
175 int Client_read_fd(int fd)
176 {
177         struct dlist *itr;
178         client_t *client = NULL;
179         
180         list_iterate(itr, &clients) {
181                 if(fd == list_get_entry(itr, client_t, node)->tcpfd) {
182                         client = list_get_entry(itr, client_t, node);
183                         break;
184                 }
185         }
186         if (client == NULL)
187                 Log_fatal("No client found for fd %d", fd);
188         
189         return Client_read(client);
190 }
191
192 int Client_read(client_t *client)
193 {
194         int rc;
195
196         Timer_restart(&client->lastActivity);
197         
198         if (client->writeBlockedOnRead) {
199                 client->writeBlockedOnRead = false;
200                 Log_debug("Client_read: writeBlockedOnRead == true");
201                 return Client_write(client);
202         }
203         
204         if (client->shutdown_wait) {
205                 Client_free(client);
206                 return 0;
207         }
208         if (!client->SSLready) {
209                 int rc;
210                 rc = SSL_nonblockaccept(client->ssl, &client->SSLready);
211                 if (rc < 0) {
212                         Client_free(client);
213                         return -1;
214                 }
215         }
216
217         do {
218                 errno = 0;
219                 if (!client->msgsize) 
220                         rc = SSL_read(client->ssl, client->rxbuf, 3 - client->rxcount);
221                 else if (client->drainleft > 0)
222                         rc = SSL_read(client->ssl, client->rxbuf, client->drainleft > BUFSIZE ? BUFSIZE : client->drainleft);
223                 else
224                         rc = SSL_read(client->ssl, &client->rxbuf[client->rxcount], client->msgsize);
225                 if (rc > 0) {
226                         message_t *msg;
227                         if (client->drainleft > 0)
228                                 client->drainleft -= rc;
229                         else {
230                                 client->rxcount += rc;
231                                 if (!client->msgsize && rc >= 3)
232                                         client->msgsize = ((client->rxbuf[0] & 0xff) << 16) |
233                                                 ((client->rxbuf[1] & 0xff) << 8) |
234                                                 (client->rxbuf[2] & 0xff);
235                                 if (client->msgsize > BUFSIZE - 3 && client->drainleft == 0) {
236                                         Log_warn("Too big message received (%d). Discarding.", client->msgsize);
237                                         client->rxcount = client->msgsize = 0;
238                                         client->drainleft = client->msgsize;
239                                 }
240                                 else if (client->rxcount == client->msgsize + 3) { /* Got all of the message */
241                                         msg = Msg_networkToMessage(&client->rxbuf[3], client->msgsize);
242                                         /* pass messsage to handler */
243                                         if (msg) {
244                                                 if (msg->messageType == Speex) /* Tunneled voice message */
245                                                         Client_voiceMsg_tunnel(client, msg);
246                                                 else 
247                                                         Mh_handle_message(client, msg);
248                                         }
249                                         client->rxcount = client->msgsize = 0;
250                                 }
251                         }
252                 } else /* rc <= 0 */ {
253                         if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_READ) {
254                                 return 0;
255                         }
256                         else if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_WRITE) {
257                                 client->readBlockedOnWrite = true;
258                                 return 0;
259                         }
260                         else if (SSL_get_error(client->ssl, rc) == SSL_ERROR_ZERO_RETURN) {
261                                 Log_warn("Error: Zero return - closing");
262                                 if (!client->shutdown_wait)
263                                         Client_close(client);
264                         }
265                         else {
266                                 if (SSL_get_error(client->ssl, rc) == SSL_ERROR_SYSCALL) {
267                                         /* Hmm. This is where we end up when the client closes its connection.
268                                          * Kind of strange...
269                                          */
270                                         Log_info("Connection closed by peer");
271                                 }
272                                 else {
273                                         Log_warn("SSL error: %d - Closing connection.", SSL_get_error(client->ssl, rc));
274                                 }
275                                 Client_free(client);
276                                 return -1;
277                         }
278                 }
279         } while (SSL_pending(client->ssl));
280         return 0;       
281 }
282
283 int Client_write_fd(int fd)
284 {
285         struct dlist *itr;
286         client_t *client = NULL;
287         
288         list_iterate(itr, &clients) {
289                 if(fd == list_get_entry(itr, client_t, node)->tcpfd) {
290                         client = list_get_entry(itr, client_t, node);
291                         break;
292                 }
293         }
294         if (client == NULL)
295                 Log_fatal("No client found for fd %d", fd);
296         Client_write(client);
297         return 0;
298 }
299
300 int Client_write(client_t *client)
301 {
302         int rc;
303         
304         if (client->readBlockedOnWrite) {
305                 client->readBlockedOnWrite = false;
306                 Log_debug("Client_write: readBlockedOnWrite == true");
307                 return Client_read(client);
308         }
309         rc = SSL_write(client->ssl, &client->txbuf[client->txcount], client->txsize - client->txcount);
310         if (rc > 0) {
311                 client->txcount += rc;
312                 if (client->txcount == client->txsize)
313                         client->txsize = client->txcount = 0;
314         }
315         else if (rc < 0) {
316                 if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_READ) {
317                         client->writeBlockedOnRead = true;
318                         return 0;
319                 }
320                 else if (SSL_get_error(client->ssl, rc) == SSL_ERROR_WANT_WRITE) {
321                         return 0;
322                 }
323                 else {
324                         if (SSL_get_error(client->ssl, rc) == SSL_ERROR_SYSCALL)
325                                 Log_warn("Client_write: Error: %s  - Closing connection", strerror(errno));
326                         else
327                                 Log_warn("Client_write: SSL error: %d - Closing connection.", SSL_get_error(client->ssl, rc));
328                         Client_free(client);
329                         return -1;
330                 }
331         }
332         if (client->txsize == 0 && !list_empty(&client->txMsgQueue)) {
333                 message_t *msg;
334                 msg = list_get_entry(list_get_first(&client->txMsgQueue), message_t, node);
335                 list_del(list_get_first(&client->txMsgQueue));
336                 client->txQueueCount--;
337                 Client_send_message(client, msg);
338         }
339         return 0;
340 }
341
342 int Client_send_message(client_t *client, message_t *msg)
343 {
344         if (!client->authenticated || !client->SSLready) {
345                 Msg_free(msg);
346                 return 0;
347         }
348         if (client->txsize != 0) {
349                 /* Queue message */
350                 if ((client->txQueueCount > 5 &&  msg->messageType == Speex) ||
351                         client->txQueueCount > 30) {
352                         Msg_free(msg);
353                         return -1;
354                 }
355                 client->txQueueCount++;
356                 list_add_tail(&msg->node, &client->txMsgQueue);
357         } else {
358                 int len;
359                 memset(client->txbuf, 0, BUFSIZE);
360                 len = Msg_messageToNetwork(msg, &client->txbuf[3], BUFSIZE - 3);
361                 doAssert(len < BUFSIZE - 3);
362
363                 client->txbuf[0] =  (len >> 16) & 0xff;
364                 client->txbuf[1] =  (len >> 8) & 0xff;
365                 client->txbuf[2] =  len & 0xff;
366                 client->txsize = len + 3;
367                 client->txcount = 0;
368                 Client_write(client);
369                 Msg_free(msg);
370         }
371         return 0;
372 }
373
374 client_t *Client_iterate(client_t **client_itr)
375 {
376         client_t *c = *client_itr;
377         
378         if (c == NULL && !list_empty(&clients)) {
379                 c = list_get_entry(list_get_first(&clients), client_t, node);
380         } else {
381                 if (list_get_next(&c->node) == &clients)
382                         c = NULL;
383                 else
384                         c = list_get_entry(list_get_next(&c->node), client_t, node);
385         }
386         *client_itr = c;
387         return c;
388 }
389
390
391 int Client_send_message_except(client_t *client, message_t *msg)
392 {
393         client_t *itr = NULL;
394         int count = 0;
395         
396         Msg_inc_ref(msg); /* Make sure a reference is held during the whole iteration. */
397         while (Client_iterate(&itr) != NULL) {
398                 if (itr != client) {
399                         if (count++ > 0)
400                                 Msg_inc_ref(msg); /* One extra reference for each new copy */
401                         Log_debug("Msg %d to %s refcount %d",  msg->messageType, itr->playerName, msg->refcount);
402                         Client_send_message(itr, msg);
403                 }
404         }
405         Msg_free(msg); /* Free our reference to the message */
406         
407         if (count == 0)
408                 Msg_free(msg); /* If only 1 client is connected then no message is passed
409                                                 * to Client_send_message(). Free it here. */
410                 
411         return 0;
412 }
413
414 static bool_t checkDecrypt(client_t *client, const uint8_t *encrypted, uint8_t *plain, unsigned int len)
415 {
416         if (CryptState_isValid(&client->cryptState) &&
417                 CryptState_decrypt(&client->cryptState, encrypted, plain, len))
418                 return true;
419
420         if (Timer_elapsed(&client->cryptState.tLastGood) > 5000000ULL) {
421                 if (Timer_elapsed(&client->cryptState.tLastRequest) > 5000000ULL) {
422                         message_t *sendmsg;
423                         Timer_restart(&client->cryptState.tLastRequest);
424                         
425                         sendmsg = Msg_create(CryptSync);
426                         sendmsg->sessionId = client->sessionId;
427                         sendmsg->payload.cryptSync.empty = true;
428                         Log_info("Requesting voice channel crypt resync");
429                         Client_send_message(client, sendmsg);
430                 }
431         }
432         return false;
433 }
434
435 int Client_read_udp()
436 {
437         int len;
438         struct sockaddr_in from;
439         socklen_t fromlen = sizeof(struct sockaddr_in);
440         uint64_t key;
441         client_t *itr;
442         int msgType = 0;
443         uint32_t sessionId = 0;
444         pds_t *pds;
445         
446 #if defined(__LP64__)
447         uint8_t encbuff[512 + 8];
448         uint8_t *encrypted = encbuff + 4;
449 #else
450         uint8_t encrypted[512];
451 #endif
452         uint8_t buffer[512];
453         
454         len = recvfrom(udpsock, encrypted, 512, MSG_TRUNC, (struct sockaddr *)&from, &fromlen);
455         if (len == 0) {
456                 return -1;
457         } else if (len < 0) {
458                 return -1;
459         } else if (len < 6) {
460                 // 4 bytes crypt header + type + session
461                 return 0;
462         } else if (len > 512) {
463                 return 0;
464         }
465         
466         key = (((uint64_t)from.sin_addr.s_addr) << 16) ^ from.sin_port;
467         pds = Pds_create(buffer, len - 4);
468         itr = NULL;
469         
470         while (Client_iterate(&itr) != NULL) {
471                 if (itr->key == key) {
472                         if (!checkDecrypt(itr, encrypted, buffer, len))
473                                 goto out;
474                         msgType = Pds_get_numval(pds);
475                         sessionId = Pds_get_numval(pds);
476                         if (itr->sessionId != sessionId)
477                                 goto out;
478                         break;
479                 }
480         }       
481         if (itr == NULL) { /* Unknown peer */
482                 while (Client_iterate(&itr) != NULL) {
483                         pds->offset = 0;
484                         if (itr->remote_tcp.sin_addr.s_addr == from.sin_addr.s_addr) {
485                                 if (checkDecrypt(itr, encrypted, buffer, len)) {
486                                         msgType = Pds_get_numval(pds);
487                                         sessionId = Pds_get_numval(pds);
488                                         if (itr->sessionId == sessionId) { /* Found matching client */
489                                                 itr->key = key;
490                                                 Log_info("New UDP connection from %s port %d sessionId %d", inet_ntoa(from.sin_addr), ntohs(from.sin_port), sessionId);
491                                                 memcpy(&itr->remote_udp, &from, sizeof(struct sockaddr_in));
492                                                 break;
493                                         }
494                                 }
495                                 else Log_warn("Bad cryptstate from peer");
496                         }
497                 } /* while */
498         }
499         if (itr == NULL) {
500                 goto out;
501         }
502         len -= 4;
503         if (msgType != Speex && msgType != Ping)
504                 goto out;
505         
506         if (msgType == Ping) {
507                 Client_send_udp(itr, buffer, len);
508         }
509         else {
510                 Client_voiceMsg(itr, pds);
511         }
512         
513 out:
514         Pds_free(pds);
515         return 0;
516 }
517
518 static void Client_voiceMsg_tunnel(client_t *client, message_t *msg)
519 {
520         uint8_t buf[512];
521         pds_t *pds = Pds_create(buf, 512);
522
523         Pds_add_numval(pds, msg->messageType);
524         Pds_add_numval(pds, msg->sessionId);
525         Pds_add_numval(pds, msg->payload.speex.seq);
526         Pds_append_data_nosize(pds, msg->payload.speex.data, msg->payload.speex.size);
527         if (!pds->bOk)
528                 Log_warn("Large Speex message from TCP"); /* XXX - pds resize? */
529         pds->maxsize = pds->offset;
530         Client_voiceMsg(client, pds);
531         Pds_free(pds);
532 }
533
534 static int Client_voiceMsg(client_t *client, pds_t *pds)
535 {
536         int seq, flags, msgType, sessionId, packetsize;
537         channel_t *ch = (channel_t *)client->channel;
538         struct dlist *itr;
539         
540         if (!client->authenticated || client->mute)
541                 return 0;
542
543         
544         pds->offset = 0;
545         msgType = Pds_get_numval(pds);
546         sessionId = Pds_get_numval(pds);
547         seq = Pds_get_numval(pds);
548         flags = Pds_get_numval(pds);
549
550         packetsize = 20 + 8 + 4 + pds->maxsize - pds->offset;
551         if (client->availableBandwidth - packetsize < 0)
552                 return 0; /* Discard */
553         
554         client->availableBandwidth -= packetsize;
555         
556         pds->offset = 0;
557         
558         if (flags & LoopBack) {
559                 Client_send_udp(client, pds->data, pds->maxsize);
560                 return 0;
561         }
562         if (ch == NULL)
563                 return 0;
564         
565         list_iterate(itr, &ch->clients) {
566                 client_t *c;
567                 c = list_get_entry(itr, client_t, chan_node);
568                 if (c != client && !c->deaf) {
569                         Client_send_udp(c, pds->data, pds->maxsize);
570                 }
571         }
572         return 0;
573 }
574
575
576 static int Client_send_udp(client_t *client, uint8_t *data, int len)
577 {
578         uint8_t *buf, *mbuf;
579         message_t *sendmsg;
580
581         if (client->remote_udp.sin_port != 0 && CryptState_isValid(&client->cryptState)) {
582 #if defined(__LP64__)
583                 buf = mbuf = malloc(len + 4 + 16);
584                 buf += 4;
585 #else
586                 mbuf = buf = malloc(len + 4);
587 #endif
588                 if (mbuf == NULL)
589                         Log_fatal("Out of memory");
590                 
591                 CryptState_encrypt(&client->cryptState, data, buf, len);
592                 
593                 sendto(udpsock, buf, len + 4, 0, (struct sockaddr *)&client->remote_udp, sizeof(struct sockaddr_in));
594                 
595                 free(mbuf);
596         } else {
597                 pds_t *pds = Pds_create(data, len);
598                 
599                 sendmsg = Msg_create(Pds_get_numval(pds));
600                 sendmsg->sessionId = Pds_get_numval(pds);
601                 
602                 if (sendmsg->messageType == Speex || sendmsg->messageType == Ping) {
603                         if (sendmsg->messageType == Speex) {
604                                 sendmsg->payload.speex.seq = Pds_get_numval(pds);
605                                 sendmsg->payload.speex.size = pds->maxsize - pds->offset;
606                                 doAssert(pds->maxsize - pds->offset <= SPEEX_DATA_SIZE);
607                                 memcpy(sendmsg->payload.speex.data, data + pds->offset, pds->maxsize - pds->offset);
608                         } else { /* Ping */
609                                 sendmsg->payload.ping.timestamp = Pds_get_numval(pds);
610                         }
611                         Client_send_message(client, sendmsg);
612                 } else {
613                         Log_warn("TCP fallback: Unsupported message type %d", sendmsg->messageType);
614                         Msg_free(sendmsg);
615                 }
616                 Pds_free(pds);
617         }
618         return 0;
619 }