#include <string.h>

#include "misc/misc.h"

#include "dns.h"

/* Randomly reorder the first `count` entries of `ips` in place.
 * Walks a cursor down from the end of the array, swapping the slot just
 * below the cursor with a slot chosen by random_scale().  Assuming
 * random_scale(n) yields an index in [0, n) (NOTE(review): confirm),
 * this is a Fisher-Yates shuffle.  Entries are moved 4 bytes at a time. */
static void shuffle_ips(ipv4addr* ips, unsigned count)
{
  unsigned last;
  for (last = count; last > 1; --last) {
    const unsigned pick = random_scale(last);
    ipv4addr saved;
    memcpy(&saved, &ips[pick], 4);
    memcpy(&ips[pick], &ips[last - 1], 4);
    memcpy(&ips[last - 1], &saved, 4);
  }
}

/* Extract IPv4 addresses from the A records of a DNS response.
 *
 * packet: raw DNS response, passed to dns_packet_get() with DNS_TYPE_A.
 * ips:    output array; receives at most maxips addresses.
 * maxips: capacity of ips.
 *
 * At most DNS_MAX_IPS records are collected and shuffled into random order
 * before being truncated to maxips.
 * Returns the number of addresses stored in ips, or 0 if dns_packet_get()
 * fails. */
static unsigned get_ips(const str* packet, ipv4addr* ips, unsigned maxips)
{
  str rrs = {0,0,0};
  ipv4addr iplist[DNS_MAX_IPS];
  unsigned count;
  const char* ptr;
  const char* end;
  unsigned len;
  if (!dns_packet_get(packet, &rrs, DNS_TYPE_A)) {
    str_free(&rrs);
    return 0;
  }
  /* Extract the IPs from the RR list into a temporary list.  Each entry
   * is a 2-byte length (read with uint16_get_msb) followed by that many
   * bytes of record data; only 4-byte entries are kept.
   * Bounds are checked by comparing remaining byte counts (end - ptr)
   * rather than forming ptr+2 / ptr+len.  len comes from the packet, so
   * those pointers could land more than one past the end of the buffer,
   * which is undefined behavior in C even if never dereferenced. */
  ptr = rrs.s;
  end = rrs.s + rrs.len;
  for (count = 0; count < DNS_MAX_IPS && end - ptr > 2; ptr += len) {
    len = uint16_get_msb(ptr);
    ptr += 2;
    if ((size_t)len > (size_t)(end - ptr)) break;  /* truncated entry */
    if (len == 4)
      memcpy(iplist + count++, ptr, 4);
  }
  str_free(&rrs);
  /* Shuffle the full list before truncating to maxips, so the subset
     returned is not always the first records in the packet. */
  shuffle_ips(iplist, count);
  /* Copy the result into the given array */
  if (count > maxips) count = maxips;
  memcpy(ips, iplist, 4 * count);
  return count;
}

/* Shared implementation of dns_ipv4() and dns_ipv4_fqdn().
 *
 * queryfn: performs the DNS query for `domain` with type DNS_TYPE_A and
 *          fills in the response packet.
 * Returns 0 if maxips is 0.  Returns 1 when ipv4_scan() consumes the whole
 * of `domain` (the address is then in ips[0]).  Otherwise returns -1 when
 * queryfn returns -1, 0 when it returns 0, and the number of addresses
 * extracted by get_ips() for any other result.
 * NOTE(review): ipv4_scan() may write ips[0] even when the scan does not
 * consume the whole string; confirm. */
static int ipv4_base(int (*queryfn)(const char*, int, str*),
		     const char* domain, ipv4addr* ips, unsigned maxips)
{
  str packet = {0,0,0};
  const char* tail;
  int status;
  int result;

  if (maxips == 0)
    return 0;

  /* A literal address needs no lookup. */
  tail = ipv4_scan(domain, ips);
  if (tail != 0 && *tail == 0)
    return 1;

  status = queryfn(domain, DNS_TYPE_A, &packet);
  if (status == -1)
    result = -1;
  else if (status == 0)
    result = 0;
  else
    result = get_ips(&packet, ips, maxips);

  str_free(&packet);
  return result;
}

/* Look up the IPv4 addresses of `domain`, querying through dns_query().
 * Stores up to maxips addresses in ips and returns the result of
 * ipv4_base(): -1 on query failure, otherwise the number of addresses. */
int dns_ipv4(const char* domain, ipv4addr* ips, unsigned maxips)
{
  int (*const resolver)(const char*, int, str*) = dns_query;
  return ipv4_base(resolver, domain, ips, maxips);
}

/* Same as dns_ipv4(), but queries through dns_query_fqdn().
 * Stores up to maxips addresses in ips and returns the result of
 * ipv4_base(): -1 on query failure, otherwise the number of addresses. */
int dns_ipv4_fqdn(const char* domain, ipv4addr* ips, unsigned maxips)
{
  int (*const resolver)(const char*, int, str*) = dns_query_fqdn;
  return ipv4_base(resolver, domain, ips, maxips);
}
