控制ASP.NET Web API 调用频率
很多 API(例如 GitHub 的 API)都有流量控制的做法:使用速率限制,防止客户端在很短的时间内向你的 API 发出过多的请求。例如,我们可以限制匿名 API 客户端每小时最多 60 个请求,而允许经过认证的客户端发出更多的请求。那么 ASP.NET Web API 如何实现这样的功能呢?WebApiContrib 项目中已经有一个实现:https://github.com/WebApiContrib/WebAPIContrib/blob/master/src/WebApiContrib/MessageHandlers/ThrottlingHandler.cs ,具有良好的可扩展性。
最简单的用法是用几个简单的参数注册 ThrottlingHandler,例如限制每个用户每分钟、每小时、每天最多 N 个请求;当然也可以只启用其中某一种限制。
在 Global.asax.cs 的 protected void Application_Start() 方法里加入如下代码(其中 principalProvider 是一个 IProvidePrincipal 实例,其创建代码未列出):
// Request-rate limits in three windows: per minute, per hour and per day.
// Logged-in accounts get 1.5x each limit (the multiplier is applied inside ThrottlingHandler).
var throttlingHandler1 = new ThrottlingHandler(new InMemoryThrottleStore(), id => 20, TimeSpan.FromMinutes(1));
var throttlingHandler60 = new ThrottlingHandler(new InMemoryThrottleStore(), id => 600, TimeSpan.FromHours(1));
// A day is counted as 8 active hours (the other 16 are rest): 600 requests/hour * 8.
var throttlingHandler1440 = new ThrottlingHandler(new InMemoryThrottleStore(), id => 600 * 8, TimeSpan.FromDays(1));

// Each handler keeps its own store; all of them share the same principal provider
// and are appended to the Web API pipeline in the order minute, hour, day.
foreach (var handler in new[] { throttlingHandler1, throttlingHandler60, throttlingHandler1440 })
{
    handler.PrincipalProvider = principalProvider;
    GlobalConfiguration.Configuration.MessageHandlers.Add(handler);
}
其它各个类的代码如下:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Web;
using Hujie.WebApi.Caching;
using log4net;
ThrottlingHandler.cs
--begin--
namespace Hujie.WebApi
{
    /// <summary>
    /// Web API message handler that limits how many requests one client may send within a
    /// fixed period. Clients are keyed by <see cref="GetUserIdentifier"/> (the client IP by
    /// default), suffixed with the account name when the request carries a logged-in principal;
    /// logged-in clients receive 1.5 times the configured limit.
    /// Every response gets "RateLimit-Limit" and "RateLimit-Remaining" headers; requests over
    /// the limit are answered with 409 Conflict without being passed to the inner handler.
    /// </summary>
    public class ThrottlingHandler
        : DelegatingHandler
    {
        // Multiplier applied to the limit of requests that carry a logged-in account.
        private const double AuthenticatedLimitFactor = 1.5;

        // Log under this handler's own type (previously copy-pasted as TimingActionFilter).
        private static readonly ILog Logger = LogManager.GetLogger(typeof(ThrottlingHandler));

        private readonly IThrottleStore _store;
        private readonly Func<string, long> _maxRequestsForUserIdentifier;
        private readonly TimeSpan _period;
        private readonly string _message;

        /// <summary>
        /// Principal provider handed to AuthInfoProvider to find the logged-in account.
        /// </summary>
        public IProvidePrincipal PrincipalProvider { get; set; }

        /// <summary>Creates a handler that uses the default "limit exceeded" message.</summary>
        /// <param name="store">Store holding the per-client request counters.</param>
        /// <param name="maxRequestsForUserIdentifier">Returns the request limit for a client identifier.</param>
        /// <param name="period">Length of one counting window.</param>
        public ThrottlingHandler(IThrottleStore store, Func<string, long> maxRequestsForUserIdentifier, TimeSpan period)
            : this(store, maxRequestsForUserIdentifier, period, "The allowed number of requests has been exceeded.")
        {
        }

        /// <summary>Creates a handler with a custom message for throttled responses.</summary>
        /// <param name="store">Store holding the per-client request counters.</param>
        /// <param name="maxRequestsForUserIdentifier">Returns the request limit for a client identifier.</param>
        /// <param name="period">Length of one counting window.</param>
        /// <param name="message">Reason phrase and body of throttled responses.</param>
        public ThrottlingHandler(IThrottleStore store, Func<string, long> maxRequestsForUserIdentifier, TimeSpan period, string message)
        {
            _store = store;
            _maxRequestsForUserIdentifier = maxRequestsForUserIdentifier;
            _period = period;
            _message = message;
        }

        /// <summary>
        /// Returns the base identifier used to count the caller's requests. Defaults to the
        /// client IP address, which may be null when it cannot be determined.
        /// </summary>
        protected virtual string GetUserIdentifier(HttpRequestMessage request)
        {
            return request.GetClientIpAddress();
        }

        /// <summary>
        /// Counts the request against the caller's window, then either forwards it to the inner
        /// handler or short-circuits with a throttled response, and adds the rate-limit headers.
        /// Returns 403 Forbidden when no client identifier can be determined.
        /// </summary>
        protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            var identifier = GetUserIdentifier(request);

            // Look up the logged-in account (AuthInfoProvider's implementation is not shown).
            var afp = new AuthInfoProvider();
            afp.PrincipalProvider = PrincipalProvider;
            var principal = await afp.GetAuthInfo(request);

            string accountId = null;
            if (principal != null && principal.Identity != null)
            {
                accountId = principal.Identity.Name;
            }

            // Any non-empty account name counts as logged in (previously names of length 1 were ignored
            // and a null Identity.Name threw a NullReferenceException).
            var isAuthenticated = !string.IsNullOrEmpty(accountId);

            // Logged-in users are counted separately per account: "<ip>-<account>".
            if (isAuthenticated)
            {
                identifier = identifier + "-" + accountId;
            }

            if (string.IsNullOrEmpty(identifier))
            {
                return await CreateResponse(request, HttpStatusCode.Forbidden, "Could not identify client.");
            }

            var maxRequests = _maxRequestsForUserIdentifier(identifier);
            if (isAuthenticated)
            {
                // Convert.ToInt64 keeps the original rounding but cannot overflow a long limit.
                maxRequests = Convert.ToInt64(maxRequests * AuthenticatedLimitFactor);
            }

            // Start a new window once the current one has expired.
            ThrottleEntry entry;
            if (_store.TryGetValue(identifier, out entry) && entry.PeriodStart + _period < DateTime.UtcNow)
            {
                _store.Rollover(identifier);
            }

            _store.IncrementRequests(identifier);
            if (!_store.TryGetValue(identifier, out entry))
            {
                return await CreateResponse(request, HttpStatusCode.Forbidden, "Could not identify client.");
            }

            // Snapshot the count now: the stored entry may be shared with concurrent requests,
            // so reading it again after the downstream call could report another request's count.
            var requestCount = entry.Requests;

            HttpResponseMessage httpResponse;
            if (requestCount > maxRequests)
            {
                Logger.Info(identifier + " 访问次数超限 " + maxRequests.ToString());
                // NOTE: 409 Conflict is kept for compatibility; 429 Too Many Requests (RFC 6585)
                // is the standard status for rate limiting.
                httpResponse = await CreateResponse(request, HttpStatusCode.Conflict, _message);
            }
            else
            {
                // Awaiting directly lets downstream exceptions propagate unwrapped
                // (ContinueWith + task.Result wrapped them in an AggregateException).
                httpResponse = await base.SendAsync(request, cancellationToken);
            }

            var remaining = Math.Max(0L, maxRequests - requestCount);
            httpResponse.Headers.Add("RateLimit-Limit", maxRequests.ToString());
            httpResponse.Headers.Add("RateLimit-Remaining", remaining.ToString());
            return httpResponse;
        }

        /// <summary>
        /// Builds an already-completed response with <paramref name="statusCode"/>, using
        /// <paramref name="message"/> as both the reason phrase and a plain-text body.
        /// </summary>
        protected Task<HttpResponseMessage> CreateResponse(HttpRequestMessage request, HttpStatusCode statusCode, string message)
        {
            var response = request.CreateResponse(statusCode);
            response.ReasonPhrase = message;
            response.Content = new StringContent(message);
            return Task.FromResult(response);
        }
    }
}
--end--
HttpRequestMessageExtensions.cs
--begin--
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Web;
namespace Hujie.WebApi
{
    /// <summary>
    /// Helpers that read host-specific values stored in <see cref="HttpRequestMessage.Properties"/>.
    /// </summary>
    public static class HttpRequestMessageExtensions
    {
        private const string HttpContext = "MS_HttpContext";
        private const string RemoteEndpointMessage = "System.ServiceModel.Channels.RemoteEndpointMessageProperty";
        private const string OwinContext = "MS_OwinContext";
        private const string IsLocalKey = "MS_IsLocal";

        /// <summary>
        /// Returns the value of the "MS_IsLocal" request property when it is a <see cref="Lazy{Boolean}"/>.
        /// Returns false when the property is missing or has another type (previously a missing
        /// property threw <see cref="System.Collections.Generic.KeyNotFoundException"/>).
        /// </summary>
        public static bool IsLocal(this HttpRequestMessage request)
        {
            object value;
            if (!request.Properties.TryGetValue(IsLocalKey, out value))
            {
                return false;
            }

            var localFlag = value as Lazy<bool>;
            return localFlag != null && localFlag.Value;
        }

        /// <summary>
        /// Returns the client address from whichever host context is attached to the request
        /// (web host, self host, then OWIN), or null when none of them is present.
        /// </summary>
        public static string GetClientIpAddress(this HttpRequestMessage request)
        {
            object property;

            //Web-hosting
            if (request.Properties.TryGetValue(HttpContext, out property) && property != null)
            {
                dynamic ctx = property;
                return ctx.Request.UserHostAddress;
            }

            //Self-hosting
            if (request.Properties.TryGetValue(RemoteEndpointMessage, out property) && property != null)
            {
                dynamic remoteEndpoint = property;
                return remoteEndpoint.Address;
            }

            //Owin-hosting
            if (request.Properties.TryGetValue(OwinContext, out property) && property != null)
            {
                dynamic ctx = property;
                return ctx.Request.RemoteIpAddress;
            }

            return null;
        }
    }
}
--end--
InMemoryThrottleStore.cs
--begin--
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Web;
namespace Hujie.WebApi.Caching
{
    /// <summary>
    /// <see cref="IThrottleStore"/> that keeps the per-client counters in a
    /// <see cref="ConcurrentDictionary{TKey,TValue}"/> owned by this instance, i.e. in process memory only.
    /// </summary>
    public class InMemoryThrottleStore : IThrottleStore
    {
        private readonly ConcurrentDictionary<string, ThrottleEntry> _throttleStore = new ConcurrentDictionary<string, ThrottleEntry>();

        /// <summary>Returns true and the current entry when <paramref name="key"/> has one.</summary>
        public bool TryGetValue(string key, out ThrottleEntry entry)
        {
            return _throttleStore.TryGetValue(key, out entry);
        }

        /// <summary>
        /// Adds one request to the counter for <paramref name="key"/>; a new counter starts at 1
        /// with its period beginning now.
        /// </summary>
        public void IncrementRequests(string key)
        {
            // ConcurrentDictionary runs the update delegate outside its lock and may invoke it more
            // than once under contention. Mutating the existing entry (e.Requests++) could lose or
            // double-count increments and expose the shared object to readers mid-update, so each
            // attempt builds a fresh entry and only the winning one is published.
            _throttleStore.AddOrUpdate(
                key,
                k => new ThrottleEntry { Requests = 1 },
                (k, existing) => new ThrottleEntry
                {
                    PeriodStart = existing.PeriodStart,
                    Requests = existing.Requests + 1
                });
        }

        /// <summary>Starts a new period for <paramref name="key"/> by discarding its entry.</summary>
        public void Rollover(string key)
        {
            ThrottleEntry dummy;
            _throttleStore.TryRemove(key, out dummy);
        }

        /// <summary>Discards every counter.</summary>
        public void Clear()
        {
            _throttleStore.Clear();
        }
    }
}
--end--
IThrottleStore.cs
--begin--
using System;
using System.Collections.Generic;
using System.Linq;
using System.Web;
namespace Hujie.WebApi.Caching
{
    /// <summary>
    /// Storage abstraction for the per-client request counters used by ThrottlingHandler.
    /// Implementations may keep the counters in memory, a distributed cache or a database.
    /// </summary>
    public interface IThrottleStore
    {
        /// <summary>Adds one request to the counter stored under <paramref name="key"/>.</summary>
        void IncrementRequests(string key);

        /// <summary>Looks up the counter for <paramref name="key"/>; returns false when there is none.</summary>
        bool TryGetValue(string key, out ThrottleEntry entry);

        /// <summary>Begins a new counting period for <paramref name="key"/> after the current one has expired.</summary>
        void Rollover(string key);

        /// <summary>Removes all counters.</summary>
        void Clear();
    }
}
--end--
ThrottleEntry.cs
--begin--
using System;
using System.Collections.Generic;
using System.Linq;
using System.Web;
namespace Hujie.WebApi.Caching
{
    /// <summary>
    /// Request counter for one client identifier: when its current period began (UTC)
    /// and how many requests have been counted since then.
    /// </summary>
    public class ThrottleEntry
    {
        // A new entry's period starts at construction time with no requests counted yet.
        private DateTime _periodStart = DateTime.UtcNow;
        private long _requests = 0;

        /// <summary>Creates an entry whose period starts now (UTC) with a zero request count.</summary>
        public ThrottleEntry()
        {
        }

        /// <summary>UTC start of the current counting period.</summary>
        public DateTime PeriodStart
        {
            get { return _periodStart; }
            set { _periodStart = value; }
        }

        /// <summary>Number of requests counted in the current period.</summary>
        public long Requests
        {
            get { return _requests; }
            set { _requests = value; }
        }
    }
}
--end--
测试类(取自 WebApiContrib 原项目,引用的是 WebApiContrib 命名空间下的 ThrottlingHandler,而不是上面的 Hujie.WebApi 版本)
ThrottlingHandlerTests.cs
--begin--
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using NUnit.Framework;
using Should;
using WebApiContrib.Caching;
using WebApiContrib.MessageHandlers;
namespace WebApiContribTests.MessageHandlers
{
    /// <summary>
    /// Checks the status code and rate-limit headers produced by ThrottlingHandler.
    /// The handler under test always reports the same client identifier.
    /// </summary>
    [TestFixture]
    public class ThrottlingHandlerTests : MessageHandlerTester
    {
        [Test]
        public void Should_inject_ratelimit_headers_when_limit_not_reached()
        {
            var sut = GetHandler(100, TimeSpan.FromMinutes(1));

            var result = ExecuteRequest(sut, NewGetRequest());

            result.StatusCode.ShouldEqual(HttpStatusCode.OK);
            AssertFirstHeaderValue(result.Headers, "RateLimit-Limit", "100");
            AssertFirstHeaderValue(result.Headers, "RateLimit-Remaining", "99");
        }

        [Test]
        public void Should_throttle_when_limit_reached()
        {
            var sut = GetHandler(0, TimeSpan.FromMinutes(1));

            var result = ExecuteRequest(sut, NewGetRequest());

            result.StatusCode.ShouldEqual(HttpStatusCode.Conflict);
            AssertFirstHeaderValue(result.Headers, "RateLimit-Limit", "0");
            AssertFirstHeaderValue(result.Headers, "RateLimit-Remaining", "0");
        }

        // Every test sends the same GET request.
        private static HttpRequestMessage NewGetRequest()
        {
            return new HttpRequestMessage(HttpMethod.Get, "foo/bar");
        }

        // Asserts that the header is present and that its first value equals the expected text.
        private static void AssertFirstHeaderValue(HttpHeaders headers, string headerName, string expected)
        {
            IEnumerable<string> found;
            Assert.True(headers.TryGetValues(headerName, out found));
            Assert.AreEqual(expected, found.First());
        }

        private ThrottlingHandler GetHandler(long maxRequests, TimeSpan period)
        {
            Func<string, long> limit = identifier => maxRequests;
            return new ThrottlingHandlerWithFixedIdentifier(new InMemoryThrottleStore(), limit, period);
        }

        // Pins the client identifier so the tests do not depend on a host context.
        private class ThrottlingHandlerWithFixedIdentifier : ThrottlingHandler
        {
            public ThrottlingHandlerWithFixedIdentifier(IThrottleStore store, Func<string, long> maxRequestsForUserIdentifier, TimeSpan period)
                : base(store, maxRequestsForUserIdentifier, period)
            {
            }

            protected override string GetUserIdentifier(HttpRequestMessage request)
            {
                return "10.0.0.1";
            }
        }
    }
}
--end--
IThrottleStore 接口按客户端标识(ID)保存当前周期内的请求数量。InMemoryThrottleStore 只是一个进程内的内存存储,但你可以很容易地把它扩展为基于分布式缓存或数据库的实现。ThrottlingHandler 的行为同样可以方便地定制,例如重写 GetUserIdentifier,从而针对 IP 地址进行更精细的控制。
Throttling ASP.NET Web API calls
Introducing ASP.NET Web API Throttling handler
Throttling Suite for Web API