diff --git pdns/auth-main.cc pdns/auth-main.cc index 24ac5e509..fc5d1cfab 100644 --- pdns/auth-main.cc +++ pdns/auth-main.cc @@ -334,10 +334,11 @@ static void declareArguments() ::arg().set("default-catalog-zone", "Catalog zone to assign newly created primary zones (via the API) to") = ""; #ifdef ENABLE_GSS_TSIG ::arg().setSwitch("enable-gss-tsig", "Enable GSS TSIG processing") = "no"; + ::arg().set("gss-max-contexts", "The maximum number of simultaneous GSS contexts allowed") = "1000"; #endif ::arg().setDefaults(); } static time_t s_start = time(nullptr); @@ -710,10 +711,13 @@ static void mainthread() g_luaConsistentHashesCleanupInterval = ::arg().asNum("lua-consistent-hashes-cleanup-interval"); g_luaHealthChecksExpireDelay = ::arg().asNum("lua-health-checks-expire-delay"); #endif #ifdef ENABLE_GSS_TSIG g_doGssTSIG = ::arg().mustDo("enable-gss-tsig"); + if (g_doGssTSIG) { + GssContext::s_maxGssContexts = ::arg().asNum("gss-max-contexts"); + } #endif DNSPacket::s_udpTruncationThreshold = std::max(512, ::arg().asNum("udp-truncation-threshold")); DNSPacket::s_doEDNSSubnetProcessing = ::arg().mustDo("edns-subnet-processing"); PacketHandler::s_SVCAutohints = ::arg().mustDo("svc-autohints"); diff --git pdns/dnssecinfra.cc pdns/dnssecinfra.cc index 0c67004a0..77b7c6802 100644 --- pdns/dnssecinfra.cc +++ pdns/dnssecinfra.cc @@ -864,11 +864,10 @@ bool validateTSIG(const std::string& packet, size_t sigPos, const TSIGTriplet& t string tsigMsg; tsigMsg = makeTSIGMessageFromTSIGPacket(packet, sigPos, tt.name, trc, previousMAC, timersOnly, dnsHeaderOffset); if (algo == TSIG_GSS) { - GssContext gssctx(tt.name); if (!gss_verify_signature(tt.name, tsigMsg, theirMAC)) { throw std::runtime_error("Signature with TSIG key '"+tt.name.toLogString()+"' failed to validate"); } } else { string ourMac = calculateHMAC(tt.secret, tsigMsg, algo); diff --git pdns/gss_context.cc pdns/gss_context.cc index 1297356a0..ccc6e12af 100644 --- pdns/gss_context.cc +++ pdns/gss_context.cc @@ -52,10 +52,12 @@ GssContextError GssContext::getError() { return GSS_CONTEXT_UNSUPPORTED; } #include "lock.hh" #define TSIG_GSS_EXPIRE_INTERVAL 60 +unsigned int GssContext::s_maxGssContexts{1000}; + class GssCredential : boost::noncopyable { public: GssCredential(const std::string& name, const gss_cred_usage_t usage) : d_nameS(name), d_usage(usage) @@ -133,11 +135,11 @@ public: }; // GssCredential static LockGuarded>> s_gss_accept_creds; static LockGuarded>> s_gss_init_creds; -class GssSecContext : boost::noncopyable +class GssSecContext { public: GssSecContext(std::shared_ptr cred) { if (!cred->valid()) { @@ -170,11 +172,11 @@ public: GssStateComplete, GssStateError } d_state{GssStateInitial}; }; // GssSecContext -static LockGuarded>> s_gss_sec_context; +static LockGuarded>>> s_gss_sec_context; template static void doExpire(T& m, time_t now) { auto lock = m.lock(); @@ -186,21 +188,41 @@ static void doExpire(T& m, time_t now) ++i; } } } +// Same as above, for s_gss_sec_context +template +static void doExpireL(T& m, time_t now) +{ + auto lock = m.lock(); + for (auto i = lock->begin(); i != lock->end();) { + time_t expiretime{0}; + { + auto ctx = i->second->lock(); + expiretime = ctx->d_expires; + } + if (now > expiretime) { + i = lock->erase(i); + } + else { + ++i; + } + } +} + static void expire() { - static time_t s_last_expired; + static std::atomic s_last_expired; time_t now = time(nullptr); if (now - s_last_expired < TSIG_GSS_EXPIRE_INTERVAL) { return; } s_last_expired = now; doExpire(s_gss_init_creds, now); doExpire(s_gss_accept_creds, now); - doExpire(s_gss_sec_context, now); + doExpireL(s_gss_sec_context, now); } bool GssContext::supported() { return true; } void GssContext::initialize() @@ -235,22 +257,60 @@ void GssContext::setLabel(const DNSName& label) d_label = label; auto lock = s_gss_sec_context.lock(); auto it = lock->find(d_label); if (it != lock->end()) { d_secctx = it->second; - d_type = d_secctx->d_type; + auto ctx = d_secctx->lock(); + d_type = ctx->d_type; } } bool GssContext::expired() { - return (!d_secctx || (d_secctx->d_expires > -1 && d_secctx->d_expires < time(nullptr))); + if (!d_secctx) { + return true; + } + auto ctx = d_secctx->lock(); + return (ctx->d_expires > -1 && ctx->d_expires < time(nullptr)); } bool GssContext::valid() { - return (d_secctx && !expired() && d_secctx->d_state == GssSecContext::GssStateComplete); + if (expired()) { + return false; + } + auto ctx = d_secctx->lock(); + return ctx->d_state == GssSecContext::GssStateComplete; +} + +bool GssContext::createOrReuseContext(std::shared_ptr cred) +{ + // see if we can find a context in non-completed state + if (d_secctx) { + auto ctx = d_secctx->lock(); + if (ctx->d_state != GssSecContext::GssStateNegotiate) { + d_error = GSS_CONTEXT_INVALID; + return false; + } + } + else { + // make context + auto lock = s_gss_sec_context.lock(); + if (lock->size() == s_maxGssContexts) { + d_error = GSS_CONTEXT_LIMIT_REACHED; + d_gss_errors.push_back("Limit of concurrent GSS contexts reached"); + return false; + } + d_secctx = std::make_shared>(cred); + { + auto ctx = d_secctx->lock(); + ctx->d_state = GssSecContext::GssStateNegotiate; + ctx->d_type = d_type; + } + (*lock)[d_label] = d_secctx; + } + return true; } bool GssContext::init(const std::string& input, std::string& output) { expire(); @@ -275,57 +335,49 @@ bool GssContext::init(const std::string& input, std::string& output) it = lock->emplace(d_localPrincipal, std::make_shared(d_localPrincipal, GSS_C_INITIATE)).first; } cred = it->second; } - // see if we can find a context in non-completed state - if (d_secctx) { - if (d_secctx->d_state != GssSecContext::GssStateNegotiate) { - d_error = GSS_CONTEXT_INVALID; - return false; - } - } - else { - // make context - auto lock = s_gss_sec_context.lock(); - d_secctx = std::make_shared(cred); - d_secctx->d_state = GssSecContext::GssStateNegotiate; - d_secctx->d_type = d_type; - (*lock)[d_label] = d_secctx; + if (!createOrReuseContext(cred)) { + return false; } recv_tok.length = input.size(); recv_tok.value = const_cast(static_cast(input.c_str())); - if (!d_peerPrincipal.empty()) { - buffer.value = const_cast(static_cast(d_peerPrincipal.c_str())); - buffer.length = d_peerPrincipal.size(); - maj = gss_import_name(&min, &buffer, (gss_OID)GSS_KRB5_NT_PRINCIPAL_NAME, &(d_secctx->d_peer_name)); - if (maj != GSS_S_COMPLETE) { - processError("gss_import_name", maj, min); - return false; + { + auto ctx = d_secctx->lock(); + + if (!d_peerPrincipal.empty()) { + buffer.value = const_cast(static_cast(d_peerPrincipal.c_str())); + buffer.length = d_peerPrincipal.size(); + maj = gss_import_name(&min, &buffer, (gss_OID)GSS_KRB5_NT_PRINCIPAL_NAME, &(ctx->d_peer_name)); + if (maj != GSS_S_COMPLETE) { + processError("gss_import_name", maj, min); + return false; + } } - } - maj = gss_init_sec_context(&min, cred->d_cred, &d_secctx->d_ctx, d_secctx->d_peer_name, GSS_C_NO_OID, GSS_C_MUTUAL_FLAG | GSS_C_REPLAY_FLAG, GSS_C_INDEFINITE, GSS_C_NO_CHANNEL_BINDINGS, &recv_tok, nullptr, &send_tok, &flags, &expires); + maj = gss_init_sec_context(&min, cred->d_cred, &ctx->d_ctx, ctx->d_peer_name, GSS_C_NO_OID, GSS_C_MUTUAL_FLAG | GSS_C_REPLAY_FLAG, GSS_C_INDEFINITE, GSS_C_NO_CHANNEL_BINDINGS, &recv_tok, nullptr, &send_tok, &flags, &expires); - if (send_tok.length > 0) { - output.assign(static_cast(send_tok.value), send_tok.length); - tmp_maj = gss_release_buffer(&tmp_min, &send_tok); - } + if (send_tok.length > 0) { + output.assign(static_cast(send_tok.value), send_tok.length); + tmp_maj = gss_release_buffer(&tmp_min, &send_tok); + } - if (maj == GSS_S_COMPLETE) { - // We do not want forever - if (expires == GSS_C_INDEFINITE) { - expires = 60; + if (maj == GSS_S_COMPLETE) { + // We do not want forever + if (expires == GSS_C_INDEFINITE) { + expires = 60; + } + ctx->d_expires = time(nullptr) + expires; + ctx->d_state = GssSecContext::GssStateComplete; + return true; + } + else if (maj != GSS_S_CONTINUE_NEEDED) { + processError("gss_init_sec_context", maj, min); } - d_secctx->d_expires = time(nullptr) + expires; - d_secctx->d_state = GssSecContext::GssStateComplete; - return true; - } - else if (maj != GSS_S_CONTINUE_NEEDED) { - processError("gss_init_sec_context", maj, min); } return (maj == GSS_S_CONTINUE_NEEDED); } @@ -353,48 +405,40 @@ bool GssContext::accept(const std::string& input, std::string& output) it = lock->emplace(d_localPrincipal, std::make_shared(d_localPrincipal, GSS_C_ACCEPT)).first; } cred = it->second; } - // see if we can find a context in non-completed state - if (d_secctx) { - if (d_secctx->d_state != GssSecContext::GssStateNegotiate) { - d_error = GSS_CONTEXT_INVALID; - return false; - } - } - else { - // make context - auto lock = s_gss_sec_context.lock(); - d_secctx = std::make_shared(cred); - d_secctx->d_state = GssSecContext::GssStateNegotiate; - d_secctx->d_type = d_type; - (*lock)[d_label] = d_secctx; + if (!createOrReuseContext(cred)) { + return false; } recv_tok.length = input.size(); recv_tok.value = const_cast(static_cast(input.c_str())); - maj = gss_accept_sec_context(&min, &d_secctx->d_ctx, cred->d_cred, &recv_tok, GSS_C_NO_CHANNEL_BINDINGS, &d_secctx->d_peer_name, nullptr, &send_tok, &flags, &expires, nullptr); + { + auto ctx = d_secctx->lock(); + maj = gss_accept_sec_context(&min, &ctx->d_ctx, cred->d_cred, &recv_tok, GSS_C_NO_CHANNEL_BINDINGS, &ctx->d_peer_name, nullptr, &send_tok, &flags, &expires, nullptr); - if (send_tok.length > 0) { - output.assign(static_cast(send_tok.value), send_tok.length); - tmp_maj = gss_release_buffer(&tmp_min, &send_tok); - } + if (send_tok.length > 0) { + output.assign(static_cast(send_tok.value), send_tok.length); + tmp_maj = gss_release_buffer(&tmp_min, &send_tok); + } - if (maj == GSS_S_COMPLETE) { - // We do not want forever - if (expires == GSS_C_INDEFINITE) { - expires = 60; + if (maj == GSS_S_COMPLETE) { + // We do not want forever + if (expires == GSS_C_INDEFINITE) { + expires = 60; + } + ctx->d_expires = time(nullptr) + expires; + ctx->d_state = GssSecContext::GssStateComplete; + return true; + } + else if (maj != GSS_S_CONTINUE_NEEDED) { + processError("gss_accept_sec_context", maj, min); } - d_secctx->d_expires = time(nullptr) + expires; - d_secctx->d_state = GssSecContext::GssStateComplete; - return true; - } - else if (maj != GSS_S_CONTINUE_NEEDED) { - processError("gss_accept_sec_context", maj, min); } + return (maj == GSS_S_CONTINUE_NEEDED); }; bool GssContext::sign(const std::string& input, std::string& output) { @@ -405,11 +449,14 @@ bool GssContext::sign(const std::string& input, std::string& output) gss_buffer_desc send_tok = GSS_C_EMPTY_BUFFER; recv_tok.length = input.size(); recv_tok.value = const_cast(static_cast(input.c_str())); - maj = gss_get_mic(&min, d_secctx->d_ctx, GSS_C_QOP_DEFAULT, &recv_tok, &send_tok); + { + auto ctx = d_secctx->lock(); + maj = gss_get_mic(&min, ctx->d_ctx, GSS_C_QOP_DEFAULT, &recv_tok, &send_tok); + } if (send_tok.length > 0) { output.assign(static_cast(send_tok.value), send_tok.length); tmp_maj = gss_release_buffer(&tmp_min, &send_tok); } @@ -431,11 +478,14 @@ bool GssContext::verify(const std::string& input, const std::string& signature) recv_tok.length = input.size(); recv_tok.value = const_cast(static_cast(input.c_str())); sign_tok.length = signature.size(); sign_tok.value = const_cast(static_cast(signature.c_str())); - maj = gss_verify_mic(&min, d_secctx->d_ctx, &recv_tok, &sign_tok, nullptr); + { + auto ctx = d_secctx->lock(); + maj = gss_verify_mic(&min, ctx->d_ctx, &recv_tok, &sign_tok, nullptr); + } if (maj != GSS_S_COMPLETE) { processError("gss_get_mic", maj, min); } @@ -470,24 +520,27 @@ void GssContext::setPeerPrincipal(const std::string& name) bool GssContext::getPeerPrincipal(std::string& name) { gss_buffer_desc value; OM_uint32 maj, min; - if (d_secctx->d_peer_name != GSS_C_NO_NAME) { - maj = gss_display_name(&min, d_secctx->d_peer_name, &value, nullptr); - if (maj == GSS_S_COMPLETE && value.length > 0) { - name.assign(static_cast(value.value), value.length); - maj = gss_release_buffer(&min, &value); - return true; + { + auto ctx = d_secctx->lock(); + if (ctx->d_peer_name != GSS_C_NO_NAME) { + maj = gss_display_name(&min, ctx->d_peer_name, &value, nullptr); + if (maj == GSS_S_COMPLETE && value.length > 0) { + name.assign(static_cast(value.value), value.length); + maj = gss_release_buffer(&min, &value); + return true; + } + else { + return false; + } } else { return false; } } - else { - return false; - } } std::tuple GssContext::getCounts() { return {s_gss_init_creds.lock()->size(), s_gss_accept_creds.lock()->size(), s_gss_sec_context.lock()->size()}; diff --git pdns/gss_context.hh pdns/gss_context.hh index ba2e545e9..fc2aea280 100644 --- pdns/gss_context.hh +++ pdns/gss_context.hh @@ -26,10 +26,11 @@ #endif #include "namespaces.hh" #include "pdnsexception.hh" #include "dns.hh" +#include "lock.hh" #ifdef ENABLE_GSS_TSIG #include #include extern bool g_doGssTSIG; @@ -42,11 +43,12 @@ enum GssContextError GSS_CONTEXT_UNSUPPORTED, GSS_CONTEXT_NOT_FOUND, GSS_CONTEXT_NOT_INITIALIZED, GSS_CONTEXT_INVALID, GSS_CONTEXT_EXPIRED, - GSS_CONTEXT_ALREADY_INITIALIZED + GSS_CONTEXT_ALREADY_INITIALIZED, + GSS_CONTEXT_LIMIT_REACHED, }; //! GSS context types enum GssContextType { @@ -54,10 +56,11 @@ enum GssContextType GSS_CONTEXT_INIT, GSS_CONTEXT_ACCEPT }; class GssSecContext; +class GssCredential; /*! Class for representing GSS names, such as host/host.domain.com@REALM. */ class GssName { @@ -191,22 +194,26 @@ public: bool sign(const std::string& input, std::string& output); // getErrorStrings() { return d_gss_errors; } // cred); #endif DNSName d_label; // d_gss_errors; // d_secctx; //> d_secctx; //