/* $Id: hmac.c,v 1.4 2004/07/18 22:30:02 sra Exp $
 *
 * Hashed Message Authentication Code tool, based on the appendix to
 * RFC 2104.  This implementation uses the FreeBSD libmd functions,
 * which provide an (almost) identical interface to the MD5, SHA-1,
 * and RIPEMD160 cryptographic hash functions.
 */

#include <sys/types.h>

#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * HMAC_KEY_LENGTH is the number of key bytes that get XORed with the
 * inner and outer pad values.  hmac_key_t is one byte larger than that:
 * main() reads up to sizeof(hmac_key_t) bytes from the key file, and
 * getting more than HMAC_KEY_LENGTH bytes back is how it notices that
 * the key is too long and has to be hashed first.
 */
#define HMAC_KEY_LENGTH		64
typedef unsigned char hmac_key_t[HMAC_KEY_LENGTH + 1];


/*
 * Pick which hash algorithms we're going to implement.  Each INSTALL_*
 * macro is set to 1 (enabled) here unless it has already been defined
 * before this point, for instance on the compiler command line;
 * defining one as 0 leaves that algorithm out of the build.
 */

#ifndef INSTALL_MD5
#define INSTALL_MD5		1
#endif

#ifndef INSTALL_SHA1
#define INSTALL_SHA1		1
#endif

#ifndef INSTALL_RIPEMD160
#define INSTALL_RIPEMD160	1
#endif

/*
 * Include definitions for the hash algorithms we picked.
 */

#if INSTALL_MD5
#include <md5.h>
#endif

#if INSTALL_SHA1
#include <sha.h>
#endif

#if INSTALL_RIPEMD160
#include <ripemd.h>
#endif

/*
 * Define generic context and driver structures to hide the details
 * that depend on the specific choice of hash algorithm from the main
 * program.
 */

/*
 * Working state for a hash computation.  There is one member per
 * compiled-in algorithm; each driver only ever touches its own member,
 * so a single object of this type serves whichever driver is selected.
 */
typedef union {
#if INSTALL_MD5
  MD5_CTX md5_ctx;
#endif
#if INSTALL_SHA1
  SHA_CTX sha1_ctx;
#endif
#if INSTALL_RIPEMD160
  RIPEMD160_CTX ripemd160_ctx;
#endif
} hash_context_t;

/*
 * Operations and properties of one hash algorithm.  main() performs
 * the whole HMAC computation through a pointer to one of these.
 */
typedef struct {
  /* Start a new hash computation in the given context. */
  void (*init)(hash_context_t *);
  /* Add a run of bytes (pointer, byte count) to the computation. */
  void (*update)(hash_context_t *, const unsigned char *, const unsigned);
  /* Finish and store the raw digest; the last argument is the buffer capacity. */
  void (*finish_binary)(hash_context_t *, unsigned char *, const unsigned);
  /*
   * Finish and store the digest as a string that main() prints with
   * puts(); the last argument is the buffer capacity.
   */
  void (*finish_text)(hash_context_t *, unsigned char *, const unsigned);
  /* Size in bytes of the raw digest written by finish_binary. */
  unsigned digest_length;
  /* Algorithm name, shown in verbose mode. */
  const char *name;
} hash_driver_t;

/*
 * Driver for MD5 functions.
 */

#if INSTALL_MD5

/*
 * Start a new MD5 computation by calling MD5Init() on the MD5 member
 * of ctx.  ctx must not be NULL (checked by assert).
 */
static void md5_driver_init (hash_context_t *ctx)
{
  assert(ctx);
  MD5Init(&ctx->md5_ctx);
}

/*
 * Pass length bytes starting at text to MD5Update() for the MD5
 * computation held in ctx.  Neither pointer may be NULL (checked by
 * assert).
 */
static void md5_driver_update (hash_context_t *ctx, const unsigned char *text, const unsigned length)
{
  assert(ctx && text);
  MD5Update(&ctx->md5_ctx, text, length);
}

/*
 * Finish the MD5 computation held in ctx with MD5Final(), which stores
 * the result in digest.  length is the capacity of digest; it is only
 * examined by the assert, which requires at least 16 bytes (the
 * digest_length given in the MD5 driver table).
 */
static void md5_driver_finish_binary (hash_context_t *ctx, unsigned char *digest, const unsigned length)
{
  assert(ctx && digest && length >= 16);
  MD5Final(digest, &ctx->md5_ctx);
}

/*
 * Finish the MD5 computation held in ctx with MD5End(), which stores
 * the digest in digest as a NUL-terminated hexadecimal string.  length
 * is the capacity of digest; it is only examined by the assert, which
 * requires at least 33 bytes (32 hex digits plus the terminator).
 *
 * libmd declares the output buffer of MD5End() as char *, whereas the
 * driver interface uses unsigned char *.  Those pointer types are not
 * compatible in C, so the conversion is spelled out with a cast.
 */
static void md5_driver_finish_text (hash_context_t *ctx, unsigned char *digest, const unsigned length)
{
  assert(ctx && digest && length >= 33);
  MD5End(&ctx->md5_ctx, (char *) digest);
}

/*
 * Driver table for MD5: the four wrapper functions above, a 16 byte
 * raw digest, and the name reported in verbose mode.
 */
static hash_driver_t md5_driver = {
  md5_driver_init, md5_driver_update, md5_driver_finish_binary, md5_driver_finish_text, 16, "MD5"
};

#endif

/*
 * Driver for SHA1 functions.
 */

#if INSTALL_SHA1

/*
 * Start a new SHA1 computation by calling SHA1_Init() on the SHA1
 * member of ctx.  ctx must not be NULL (checked by assert).
 */
static void sha1_driver_init (hash_context_t *ctx)
{
  assert(ctx);
  SHA1_Init(&ctx->sha1_ctx);
}

/*
 * Pass length bytes starting at text to SHA1_Update() for the SHA1
 * computation held in ctx.  Neither pointer may be NULL (checked by
 * assert).
 */
static void sha1_driver_update (hash_context_t *ctx, const unsigned char *text, const unsigned length)
{
  assert(ctx && text);
  SHA1_Update(&ctx->sha1_ctx, text, length);
}

/*
 * Finish the SHA1 computation held in ctx with SHA1_Final(), which
 * stores the result in digest.  length is the capacity of digest; it
 * is only examined by the assert, which requires at least 20 bytes
 * (the digest_length given in the SHA1 driver table).
 */
static void sha1_driver_finish_binary (hash_context_t *ctx, unsigned char *digest, const unsigned length)
{
  assert(ctx && digest && length >= 20);
  SHA1_Final(digest, &ctx->sha1_ctx);
}

/*
 * Finish the SHA1 computation held in ctx with SHA1_End(), which
 * stores the digest in digest as a NUL-terminated hexadecimal string.
 * length is the capacity of digest; it is only examined by the assert,
 * which requires at least 41 bytes (40 hex digits plus the terminator).
 *
 * libmd declares the output buffer of SHA1_End() as char *, whereas
 * the driver interface uses unsigned char *.  Those pointer types are
 * not compatible in C, so the conversion is spelled out with a cast.
 */
static void sha1_driver_finish_text (hash_context_t *ctx, unsigned char *digest, const unsigned length)
{
  assert(ctx && digest && length >= 41);
  SHA1_End(&ctx->sha1_ctx, (char *) digest);
}

/*
 * Driver table for SHA1: the four wrapper functions above, a 20 byte
 * raw digest, and the name reported in verbose mode.
 */
static hash_driver_t sha1_driver = {
  sha1_driver_init, sha1_driver_update, sha1_driver_finish_binary, sha1_driver_finish_text, 20, "SHA1"
};

#endif

/*
 * Driver for RIPEMD160 functions.
 */

#if INSTALL_RIPEMD160

/*
 * Start a new RIPEMD160 computation by calling RIPEMD160_Init() on the
 * RIPEMD160 member of ctx.  ctx must not be NULL (checked by assert).
 */
static void ripemd160_driver_init (hash_context_t *ctx)
{
  assert(ctx);
  RIPEMD160_Init(&ctx->ripemd160_ctx);
}

/*
 * Pass length bytes starting at text to RIPEMD160_Update() for the
 * RIPEMD160 computation held in ctx.  Neither pointer may be NULL
 * (checked by assert).
 */
static void ripemd160_driver_update (hash_context_t *ctx, const unsigned char *text, const unsigned length)
{
  assert(ctx && text);
  RIPEMD160_Update(&ctx->ripemd160_ctx, text, length);
}

/*
 * Finish the RIPEMD160 computation held in ctx with RIPEMD160_Final(),
 * which stores the result in digest.  length is the capacity of
 * digest; it is only examined by the assert, which requires at least
 * 20 bytes (the digest_length given in the RIPEMD160 driver table).
 */
static void ripemd160_driver_finish_binary (hash_context_t *ctx, unsigned char *digest, const unsigned length)
{
  assert(ctx && digest && length >= 20);
  RIPEMD160_Final(digest, &ctx->ripemd160_ctx);
}

/*
 * Finish the RIPEMD160 computation held in ctx with RIPEMD160_End(),
 * which stores the digest in digest as a NUL-terminated hexadecimal
 * string.  length is the capacity of digest; it is only examined by
 * the assert, which requires at least 41 bytes (40 hex digits plus the
 * terminator).
 *
 * libmd declares the output buffer of RIPEMD160_End() as char *,
 * whereas the driver interface uses unsigned char *.  Those pointer
 * types are not compatible in C, so the conversion is spelled out with
 * a cast.
 */
static void ripemd160_driver_finish_text (hash_context_t *ctx, unsigned char *digest, const unsigned length)
{
  assert(ctx && digest && length >= 41);
  RIPEMD160_End(&ctx->ripemd160_ctx, (char *) digest);
}

/*
 * Driver table for RIPEMD160: the four wrapper functions above, a 20
 * byte raw digest, and the name reported in verbose mode.
 */
static hash_driver_t ripemd160_driver = {
  ripemd160_driver_init, ripemd160_driver_update, ripemd160_driver_finish_binary, ripemd160_driver_finish_text, 20, "RIPEMD160"
};

#endif

/*
 * Figure out what to use for our default hash driver: the first
 * algorithm that is compiled in, trying MD5, then SHA1, then
 * RIPEMD160.  The build is stopped with an error if none of them is
 * enabled.
 */

#if INSTALL_MD5
#define HASH_DRIVER_DEFAULT	(&md5_driver)
#elif INSTALL_SHA1
#define HASH_DRIVER_DEFAULT	(&sha1_driver)
#elif INSTALL_RIPEMD160
#define HASH_DRIVER_DEFAULT	(&ripemd160_driver)
#else
#error Gotta have *some* hash driver installed, tovarishch
#endif

/*
 * Error reporting: helpers that print a diagnostic on stderr and exit.
 */

/*
 * Report a failed library call and terminate with exit status 1.
 *
 * jane is the program name used as a prefix, msg names the operation
 * that failed.  The line written to stderr has the form
 * "jane: msg: <system error text>", where the error text is what
 * perror() derives from errno.
 *
 * errno is captured on entry and put back just before perror(): the
 * intervening fputs() calls are allowed to modify errno, which would
 * otherwise make perror() describe the wrong error.
 */
static void lose (const char *jane, const char *msg)
{
  const int saved_errno = errno;

  fputs(jane, stderr);
  fputs(": ", stderr);
  errno = saved_errno;
  perror(msg);
  exit(1);
}

/*
 * Print a one-line usage summary on stderr, using jane as the program
 * name, and terminate with exit status 1.  Only the switches of the
 * hash algorithms that were compiled in are listed, followed by -v.
 */
static void usage (const char *jane)
{
  /* Optional switches, assembled at compile time. */
  static const char switches[] =
#if INSTALL_MD5
    "[-m] "
#endif
#if INSTALL_SHA1
    "[-s] "
#endif
#if INSTALL_RIPEMD160
    "[-r] "
#endif
    "[-v] ";

  fprintf(stderr, "usage: %s %skey_file < message_file\n", jane, switches);
  exit(1);
}

/*
 * Main program.
 *
 * Reads a key from the file named on the command line and a message
 * from stdin, computes the RFC 2104 HMAC of the message with the
 * selected hash algorithm, and prints the result on stdout in the text
 * form produced by the driver's finish_text routine.
 *
 * Switches may be combined in one argument (e.g. "-sv"); when several
 * algorithm switches are given the last one wins:
 *   -m, -s, -r   select MD5, SHA1 or RIPEMD160 (each only when compiled in)
 *   -v           report the key file and algorithm on stderr
 * Without an algorithm switch HASH_DRIVER_DEFAULT is used.
 *
 * Returns 0 on success.  Usage errors and I/O failures terminate the
 * program with status 1 through usage() and lose().
 */

int main (int argc, char *argv[])
{
  hash_context_t ctx;
  hash_driver_t *h = HASH_DRIVER_DEFAULT;
  hmac_key_t key;		/* authentication key */
  size_t key_len;		/* length of authentication key */
  hmac_key_t k_ipad;		/* inner padding - key XORd with ipad */
  hmac_key_t k_opad;		/* outer padding - key XORd with opad */
  unsigned char buffer[4096];	/* message chunks, then the digests */
  size_t n;
  FILE *f;
  int verbose = 0;
  int argi;

  /* Parse leading switch arguments; the first non-switch is the key file. */
  for (argi = 1; argi < argc && *argv[argi] == '-'; argi++) {
    char *a = argv[argi];
    while (*++a) {
      switch (*a) {
#if INSTALL_MD5
      case 'm': h = &md5_driver; break;
#endif
#if INSTALL_SHA1
      case 's': h = &sha1_driver; break;
#endif
#if INSTALL_RIPEMD160
      case 'r': h = &ripemd160_driver; break;
#endif
      case 'v': verbose = 1; break;
      default: usage(argv[0]);	/* does not return */
      }
    }
  }

  if (argi >= argc)
    usage(argv[0]);

  /* The key buffer must be able to hold a full key block or a digest. */
  assert(h && sizeof(hmac_key_t) > HMAC_KEY_LENGTH && sizeof(hmac_key_t) > h->digest_length);

  if (verbose)
    fprintf(stderr, "%s: key file: %s\n%s: hash algorithm: %s\n",
	    argv[0], argv[argi],
	    argv[0], h->name);

  /*
   * Read (binary) key from a file.  If it's longer than the HMAC key
   * size, run it through the hash algorithm and use the hash instead
   * of the key itself.  The buffer is one byte larger than the key
   * size, so a first read that returns more than HMAC_KEY_LENGTH
   * bytes means the key is too long.
   */
  if (!(f = fopen(argv[argi], "rb")))
    lose(argv[0], "fopen(keyfile)");
  memset(key, 0, sizeof(key));
  if ((n = fread(key, 1, sizeof(key), f)) > HMAC_KEY_LENGTH) {
    h->init(&ctx);
    do {
      h->update(&ctx, key, (unsigned) n);
    } while ((n = fread(key, 1, sizeof(key), f)) > 0);
    memset(key, 0, sizeof(key));
    h->finish_binary(&ctx, key, sizeof(key));
    n = h->digest_length;
  }
  if (ferror(f))
    lose(argv[0], "fread(keyfile)");
  fclose(f);
  key_len = n;

  /*
   * The HMAC transform looks like:
   *
   * H(K XOR opad, H(K XOR ipad, text))
   *
   * where H is the hash algorithm
   * K is an n byte key
   * ipad is the byte 0x36 repeated 64 times
   * opad is the byte 0x5c repeated 64 times
   * and text is the data being protected
   */

  /* Start out by storing key in pads; the rest of each pad stays zero */
  memset(k_ipad, 0, sizeof(k_ipad));
  memset(k_opad, 0, sizeof(k_opad));
  memcpy(k_ipad, key, key_len);
  memcpy(k_opad, key, key_len);

  /* XOR key with ipad and opad values */
  for (n = 0; n < HMAC_KEY_LENGTH; n++) {
    k_ipad[n] ^= 0x36;
    k_opad[n] ^= 0x5c;
  }

  /*
   * perform inner hash
   */
  h->init(&ctx);			/* init context for 1st pass */
  h->update(&ctx, k_ipad, HMAC_KEY_LENGTH); /* start with inner pad */
  while ((n = fread(buffer, 1, sizeof(buffer), stdin)) > 0)
    h->update(&ctx, buffer, (unsigned) n); /* fold in the text */
  if (ferror(stdin))
    lose(argv[0], "fread(stdin)");
  h->finish_binary(&ctx, buffer, sizeof(buffer)); /* finish up 1st pass */

  /*
   * perform outer hash
   */
  h->init(&ctx);			/* init context for 2nd pass */
  h->update(&ctx, k_opad, HMAC_KEY_LENGTH); /* start with outer pad */
  h->update(&ctx, buffer, h->digest_length); /* then results of 1st hash */
  h->finish_text(&ctx, buffer, sizeof(buffer));	/* finish up 2nd pass */

  /*
   * Print out the HMAC.  The text digest lives in an unsigned char
   * buffer, so it is cast for puts().  A failed write is reported
   * instead of exiting with status 0 and no output.
   */
  if (puts((const char *) buffer) == EOF || fflush(stdout) == EOF)
    lose(argv[0], "puts(stdout)");

  return 0;
}
