/*
 * $Id: dnsbounce.c,v 1.6 2001/02/03 22:52:37 sra Exp $
 *
 * Bounce DNS queries from a border router/nat/whatever to an internal
 * server, using the DNS header "ID" field to match up the traffic.
 *
 * This is very simple-minded, we don't attempt to grok the DNS protocol
 * itself, or even the Q/R bit, we just assume that all traffic from
 * our server's <address,port> pair is a response, and all else a query.
 *
 * This program is hereby explicitly placed in the public domain as
 * Beer-Ware.  If we meet some day and you think this program is worth
 * it, you can buy me a beer.  Your mileage may vary.  We decline
 * responsibilities, all shapes, all sizes, all colors.  If this
 * program breaks, you get to keep both pieces.
 */

#include <stdio.h>
#include <time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <arpa/nameser.h>
#include <sys/time.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdarg.h>
#include <syslog.h>
#include <string.h>

/*
 * State block for an outstanding query.  New blocks are appended at the
 * tail of a singly linked list, so the list is in arrival order with the
 * oldest query at the head.
 */
typedef struct query {
  struct query *link;		/* next (newer) block in the list, or 0 */
  u_int16_t id, my_id;		/* client's original DNS ID; ID we sent upstream */
  struct sockaddr sa;		/* client address the reply is relayed back to */
  time_t tstamp;		/* arrival time; compared with Q_TIMEOUT for expiry */
} query_t;

/*
 * Positional command-line arguments, as indices into argv[].  arg_max is
 * one past the last index, i.e. the largest argc main() will accept.
 */
enum {
  arg_jane = 0,			/* program name (argv[0]) */
  arg_host = 1,			/* upstream server: host name or dotted quad */
  arg_remote_port = 2,		/* optional: server's UDP port */
  arg_local_port = 3,		/* optional: UDP port we listen on */
  arg_max = 4
};

/*
 * Compile-time tunables.  Every value below is only a default; each one
 * can be overridden with -DNAME=value on the compiler command line.
 */

/* Longest time (seconds) a query's state block is kept. */
#if !defined(Q_TIMEOUT)
#define	Q_TIMEOUT	60
#endif

/* Cap on the number of query state blocks held at once. */
#if !defined(Q_MAXLEN)
#define	Q_MAXLEN	500
#endif

/* UDP service name used when a port argument is omitted. */
#if !defined(DEFAULT_PORT)
#define	DEFAULT_PORT	"domain"
#endif

/* Pause (seconds) after recvfrom() reports an error. */
#if !defined(STALL_INTERVAL)
#define	STALL_INTERVAL	1
#endif

/* Option flags and facility handed to openlog(). */
#if !defined(SYSLOG_OPT)
#define	SYSLOG_OPT	(LOG_CONS | LOG_PID)
#endif
#if !defined(SYSLOG_FACILITY)
#define	SYSLOG_FACILITY	LOG_DAEMON
#endif

/*
 * Log a printf-style message via vsyslog() at the given priority (so %m
 * may be used for strerror(errno)), then terminate the process with exit
 * status "code".  Does not return.
 */
static void fatal (int code, int priority, const char *format, ...)
{
  va_list ap;

  va_start(ap, format);
  vsyslog(priority, format, ap);
  va_end(ap);
  exit(code);
}

/*
 * Translate a UDP service name (e.g. "domain") or a decimal port number
 * (e.g. "5353") into a port in network byte order, ready for sin_port.
 * The services database is consulted first.  Anything that is neither a
 * known UDP service nor a complete decimal number in 1..65535 is fatal
 * (exit status 1).
 */
static int get_udp_port(const char *name)
{
  struct servent *sp;
  char *end;
  long port;

  if ((sp = getservbyname(name, "udp")) != 0)
    return sp->s_port;

  /*
   * strtol() instead of atoi(): trailing junk ("53x"), negative values
   * and values above 65535 are rejected instead of being silently
   * accepted or truncated by htons().  On overflow strtol() returns
   * LONG_MAX/LONG_MIN, which the range check also rejects.
   */
  port = strtol(name, &end, 10);
  if (end != name && *end == '\0' && port > 0 && port <= 65535)
    return htons((u_int16_t) port);

  fatal(1, LOG_ERR, "Couldn't parse UDP port \"%s\"", name);
  return 0;			/* keep gcc -Wall happy */
}

/*
 * Return non-zero if "from" (a source address filled in by recvfrom(),
 * "fromlen" bytes long) is the <address,port> pair of our server.  Only
 * family, address and port are compared, so padding (sin_zero) and the
 * BSD sin_len field cannot cause false mismatches the way a memcmp() of
 * the whole structure can.  Assumes, as the rest of this program does,
 * that a struct sockaddr is large enough to hold a struct sockaddr_in.
 */
static int is_from_server(const struct sockaddr *from, socklen_t fromlen,
			  const struct sockaddr_in *server)
{
  struct sockaddr_in peer;

  if (fromlen < sizeof(peer) || from->sa_family != AF_INET)
    return 0;
  memcpy(&peer, from, sizeof(peer));
  return (peer.sin_port == server->sin_port &&
	  peer.sin_addr.s_addr == server->sin_addr.s_addr);
}

/*
 * Fetch the DNS ID (first two octets of the message header) from msg.
 * memcpy() avoids alignment/aliasing trouble with a byte buffer.  The
 * value is never byte-swapped: IDs are opaque tokens we only copy and
 * compare.
 */
static u_int16_t get_dns_id(const unsigned char *msg)
{
  u_int16_t v;

  memcpy(&v, msg, sizeof(v));
  return v;
}

/* Overwrite the DNS ID of msg with v (same raw form as get_dns_id()). */
static void set_dns_id(unsigned char *msg, u_int16_t v)
{
  memcpy(msg, &v, sizeof(v));
}

/*
 * usage: dnsbounce host [remote_port [local_port]]
 *
 * Bind UDP local_port (default DEFAULT_PORT).  Every datagram that does
 * not come from <host, remote_port> is treated as a query: its DNS ID is
 * replaced with one of ours and it is forwarded to the server.  Every
 * datagram from the server is treated as a reply: it is matched to the
 * pending query by ID, the client's original ID is restored, and it is
 * sent back to that client.  Loops forever; setup errors exit via fatal().
 */
int main (int argc, char *argv[])
{
  struct sockaddr_in sin;	/* local bind address, then our server */
  struct sockaddr sa;		/* source of the datagram just received */
  socklen_t salen;
  struct hostent *hp;
  query_t *q_head = 0, **q_tail = &q_head;
  query_t *q, **qp;
  unsigned q_length = 0;
  unsigned char buffer[PACKETSZ];
  ssize_t buflen;
  u_int16_t id = 0, pkt_id;
  const char *jane, *slash;
  time_t now;
  int s;

  /* Program name for syslog, without directory; argv[0] may be absent. */
  jane = (argc > arg_jane && argv[arg_jane] != 0) ? argv[arg_jane] : "dnsbounce";
  if ((slash = strrchr(jane, '/')) != 0)
    jane = slash + 1;

  openlog(jane, SYSLOG_OPT, SYSLOG_FACILITY);

#if 0
  if (daemon(0, 0) < 0)
    syslog(LOG_WARNING, "daemon() call failed: %m");
#endif

  if (argc <= arg_host || argc > arg_max)
    fatal(2, LOG_ERR, "usage: %s host [remote_port [local_port]]", jane);

  if ((s = socket(PF_INET, SOCK_DGRAM, 0)) < 0)
    fatal(3, LOG_ERR, "socket() call failed: %m");

  /*
   * Bind local side of our socket.
   */

  memset(&sin, 0, sizeof(sin));
#ifdef SIN6_LEN
  /*
   * SIN6_LEN is defined (RFC 3493) on 4.4BSD-derived stacks, which also
   * have sin_len.  Elsewhere the field does not exist; BSD kernels fill it
   * in from the length argument anyway, so leaving it zero is harmless.
   */
  sin.sin_len = sizeof(sin);
#endif
  sin.sin_family = AF_INET;
  sin.sin_addr.s_addr = htonl(INADDR_ANY);
  sin.sin_port = get_udp_port((argc > arg_local_port)
			      ? argv[arg_local_port]
			      : DEFAULT_PORT);

  if (bind(s, (struct sockaddr *) &sin, sizeof(sin)) < 0)
    fatal(4, LOG_ERR, "bind() call failed: %m");

  /*
   * Set up sin to point to our server.  Copy exactly one IPv4 address;
   * copying sizeof(char *) bytes would over-read the resolver's buffer
   * and scribble on sin_zero.
   */

  if ((hp = gethostbyname(argv[arg_host])) != 0 &&
      hp->h_addrtype == AF_INET &&
      hp->h_length == (int) sizeof(sin.sin_addr))
    memcpy(&sin.sin_addr, hp->h_addr_list[0], sizeof(sin.sin_addr));
  else if (!inet_aton(argv[arg_host], &sin.sin_addr))
    fatal(5, LOG_ERR, "Couldn't find host \"%s\"", argv[arg_host]);

  sin.sin_port = get_udp_port((argc > arg_remote_port)
			      ? argv[arg_remote_port]
			      : DEFAULT_PORT);

  /*
   * Loop forever processing packets.
   */

  for (;;) {
    /*
     * Try to get a packet.  This is the only place where we expect to
     * block.  Datagrams longer than PACKETSZ are truncated by recvfrom().
     * Log failures right away, before time()/free() can clobber errno
     * (which %m reports).
     */
    salen = sizeof(sa);
    buflen = recvfrom(s, buffer, sizeof(buffer), 0, &sa, &salen);
    if (buflen < 0)
      syslog(LOG_WARNING, "recvfrom() failed: %m");

    /*
     * Get rid of any state blocks that have expired since the last
     * time we were awake.  The list is in arrival order, so we can stop
     * at the first block that is still fresh.  If you're running this on
     * a small machine, you might want to make Q_TIMEOUT or Q_MAXLEN
     * smaller.
     */
    time(&now);
    while (q_head && (now - q_head->tstamp) > Q_TIMEOUT) {
      q = q_head;
      if ((q_head = q->link) == 0)
	q_tail = &q_head;
      free(q);
      q_length--;
    }

    /*
     * If recvfrom() failed, stall briefly before trying again.
     */
    if (buflen < 0) {
      sleep(STALL_INTERVAL);
      continue;
    }

    /* Too short to hold a DNS header (and thus an ID): ignore it. */
    if ((size_t) buflen < HFIXEDSZ)
      continue;

    pkt_id = get_dns_id(buffer);

    if (!is_from_server(&sa, salen, &sin)) {
      /*
       * Packet is not from our server, ie, it's a new request.
       * Allocate a state block and forward packet to our server.
       */
      if (q_length >= Q_MAXLEN) {
	syslog(LOG_WARNING, "too many queries outstanding, dropping packet");
	continue;
      }
      if ((q = malloc(sizeof(*q))) == 0) {
	syslog(LOG_WARNING, "malloc() failed, dropping packet");
	continue;
      }
      memset(q, 0, sizeof(*q));
      q->id = pkt_id;
      q->my_id = ++id;
      q->sa = sa;
      q->tstamp = now;
      q->link = 0;
      *q_tail = q;
      q_tail = &q->link;
      q_length++;
      set_dns_id(buffer, q->my_id);
      if (sendto(s, buffer, (size_t) buflen, 0,
		 (struct sockaddr *) &sin, sizeof(sin)) < 0)
	syslog(LOG_WARNING, "sendto() failed (query): %m");
    } else {
      /*
       * Packet is from our server, ie, it's a reply.  Find the state
       * block (qp ends up pointing at the link that references it),
       * forward the reply, then unlink and free the block so it stops
       * counting against Q_MAXLEN immediately.
       */
      for (qp = &q_head; *qp && (*qp)->my_id != pkt_id; qp = &(*qp)->link)
	;
      if ((q = *qp) == 0)
	continue;		/* unknown or already-answered ID */
      set_dns_id(buffer, q->id);
      if (sendto(s, buffer, (size_t) buflen, 0, &q->sa, sizeof(q->sa)) < 0)
	syslog(LOG_WARNING, "sendto() failed (response): %m");
      *qp = q->link;
      if (q_tail == &q->link)	/* removed the last block */
	q_tail = qp;
      free(q);
      q_length--;
    }
  }
}
