控制ASP.NET Web API 调用频率

控制ASP.NET Web API 调用频率


很多的api,例如GitHub’s API 都有流量控制的做法。使用速率限制,以防止在很短的时间量客户端向你的api发出太多的请求.例如,我们可以限制匿名API客户端每小时最多60个请求,而我们可以让更多的经过认证的客户端发出更多的请求。那么asp.net webapi如何实现这样的功能呢?在项目WebApiContrib 上面已经有了一个实现:https://github.com/WebApiContrib/WebAPIContrib/blob/master/src/WebApiContrib/MessageHandlers/ThrottlingHandler.cs ,具有良好的可扩展性。


最简单的方法是使用ThrottlingHandler注册使用简单的参数,例如控制每个用户每分钟,每小时,每天N个请求:当然也可以只做某一个限制
在Global.cs 的  protected void Application_Start()方法里加入如下代码
             //访问频率限制,分1分钟,1小时,1天三类,登录的账号*1.5
            var throttlingHandler1 = new ThrottlingHandler(
                new InMemoryThrottleStore(),
                id => 20,
                TimeSpan.FromMinutes(1)
                );
            var throttlingHandler60 = new ThrottlingHandler(
                new InMemoryThrottleStore(),
                id => 600,
                TimeSpan.FromHours(1)
                );
            //一天按8小时算,16小时休息,
            var throttlingHandler1440 = new ThrottlingHandler(
                new InMemoryThrottleStore(),
                id => 600*8,
                TimeSpan.FromDays(1)
                );
            throttlingHandler1.PrincipalProvider = principalProvider;
            throttlingHandler60.PrincipalProvider = principalProvider;
            throttlingHandler1440.PrincipalProvider = principalProvider;
            GlobalConfiguration.Configuration.MessageHandlers.Add(throttlingHandler1);
            GlobalConfiguration.Configuration.MessageHandlers.Add(throttlingHandler60);
            GlobalConfiguration.Configuration.MessageHandlers.Add(throttlingHandler1440);


其它各个类的代码如下:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Web;
using Hujie.WebApi.Caching;
using log4net;

ThrottlingHandler.cs 


--begin--
namespace Hujie.WebApi
{
    public class ThrottlingHandler
        : DelegatingHandler
    {
        private readonly IThrottleStore _store;
        private readonly Func<string, long> _maxRequestsForUserIdentifier;
        private readonly TimeSpan _period;
        private readonly string _message;
        public IProvidePrincipal PrincipalProvider { get; set; }
        private static readonly ILog Logger = LogManager.GetLogger(typeof(TimingActionFilter));


        public ThrottlingHandler(IThrottleStore store, Func<string, long> maxRequestsForUserIdentifier, TimeSpan period)
            : this(store, maxRequestsForUserIdentifier, period, "The allowed number of requests has been exceeded.")
        {
        }

        public ThrottlingHandler(IThrottleStore store, Func<string, long> maxRequestsForUserIdentifier, TimeSpan period, string message)
        {
            _store = store;
            _maxRequestsForUserIdentifier = maxRequestsForUserIdentifier;
            _period = period;
            _message = message;
        }

        protected virtual string GetUserIdentifier(HttpRequestMessage request)
        {
            return request.GetClientIpAddress();
        }

        protected override  async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            var identifier = GetUserIdentifier(request);

            var accountId = "";
            //取用户登录信息 该部分的实现没有列出
            AuthInfoProvider afp = new AuthInfoProvider();
            afp.PrincipalProvider = PrincipalProvider;


            var principal = await  afp.GetAuthInfo(request);
            if (principal != null)
            {
                accountId = principal.Identity.Name;
            }
            //登录的用户identifier加上账号
            if (accountId.Length > 1)
            {
                identifier = identifier + "-" + accountId.ToString();
            }


            if (string.IsNullOrEmpty(identifier))
            {
                return await CreateResponse(request, HttpStatusCode.Forbidden, "Could not identify client.");
            }


            var maxRequests = _maxRequestsForUserIdentifier(identifier);
            //登录的用户的访问次数*1.5
            if (accountId.Length > 1)
            {
                maxRequests =  Convert.ToInt32(maxRequests*1.5);
            }

            ThrottleEntry entry = null;
            if (_store.TryGetValue(identifier, out entry))
            {
                if (entry.PeriodStart + _period < DateTime.UtcNow)
                {
                    _store.Rollover(identifier);
                }
            }
            _store.IncrementRequests(identifier);
            if (!_store.TryGetValue(identifier, out entry))
            {
                return await CreateResponse(request, HttpStatusCode.Forbidden, "Could not identify client.");
            }


            Task<HttpResponseMessage> response = null;
            if (entry.Requests > maxRequests)
            {
                Logger.Info(identifier + " 访问次数超限 " + maxRequests.ToString());
                response = CreateResponse(request, HttpStatusCode.Conflict, _message);
            }
            else
            {
                response = base.SendAsync(request, cancellationToken);
            }

            return await response.ContinueWith(task =>
            {
                var remaining = maxRequests - entry.Requests;
                if (remaining < 0)
                {
                    remaining = 0;
                }

                var httpResponse = task.Result;
                httpResponse.Headers.Add("RateLimit-Limit", maxRequests.ToString());
                httpResponse.Headers.Add("RateLimit-Remaining", remaining.ToString());


                return httpResponse;
            });
        }

        protected  Task<HttpResponseMessage> CreateResponse(HttpRequestMessage request, HttpStatusCode statusCode, string message)
        {
            var tsc = new TaskCompletionSource<HttpResponseMessage>();
            var response = request.CreateResponse(statusCode);
            response.ReasonPhrase = message;
            response.Content = new StringContent(message);
            tsc.SetResult(response);
            return tsc.Task;
        }

    }
}
--end--

HttpRequestMessageExtensions.cs


--begin--
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Web;

namespace Hujie.WebApi
{
    public static class HttpRequestMessageExtensions
    {
        private const string HttpContext = "MS_HttpContext";
        private const string RemoteEndpointMessage = "System.ServiceModel.Channels.RemoteEndpointMessageProperty";
        private const string OwinContext = "MS_OwinContext";


        public static bool IsLocal(this HttpRequestMessage request)
        {
            var localFlag = request.Properties["MS_IsLocal"] as Lazy<bool>;
            return localFlag != null && localFlag.Value;
        }

        public static string GetClientIpAddress(this HttpRequestMessage request)
        {
            //Web-hosting
            if (request.Properties.ContainsKey(HttpContext))
            {
                dynamic ctx = request.Properties[HttpContext];
                if (ctx != null)
                {
                    return ctx.Request.UserHostAddress;
                }
            }
            //Self-hosting
            if (request.Properties.ContainsKey(RemoteEndpointMessage))
            {
                dynamic remoteEndpoint = request.Properties[RemoteEndpointMessage];
                if (remoteEndpoint != null)
                {
                    return remoteEndpoint.Address;
                }
            }
            //Owin-hosting
            if (request.Properties.ContainsKey(OwinContext))
            {
                dynamic ctx = request.Properties[OwinContext];
                if (ctx != null)
                {
                    return ctx.Request.RemoteIpAddress;
                }
            }
            return null;
        }
    }
}
--end--

InMemoryThrottleStore.cs


--begin--
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Web;


namespace Hujie.WebApi.Caching
{
    public class InMemoryThrottleStore : IThrottleStore
    {
        private readonly ConcurrentDictionary<string, ThrottleEntry> _throttleStore = new ConcurrentDictionary<string, ThrottleEntry>();


        public bool TryGetValue(string key, out ThrottleEntry entry)
        {
            return _throttleStore.TryGetValue(key, out entry);
        }


        public void IncrementRequests(string key)
        {
            _throttleStore.AddOrUpdate(key,
                                       k =>
                                       {
                                           return new ThrottleEntry() { Requests = 1 };
                                       },
                                       (k, e) =>
                                       {
                                           e.Requests++;
                                           return e;
                                       });
        }

        public void Rollover(string key)
        {
            ThrottleEntry dummy;
            _throttleStore.TryRemove(key, out dummy);
        }


        public void Clear()
        {
            _throttleStore.Clear();
        }
    }
}
--end--


IThrottleStore.cs


--begin--
using System;
using System.Collections.Generic;
using System.Linq;
using System.Web;


namespace Hujie.WebApi.Caching
{
    public interface IThrottleStore
    {
        bool TryGetValue(string key, out ThrottleEntry entry);
        void IncrementRequests(string key);
        void Rollover(string key);
        void Clear();
    }
}
--end--

ThrottleEntry.cs


--begin--
using System;
using System.Collections.Generic;
using System.Linq;
using System.Web;


namespace Hujie.WebApi.Caching
{
    public class ThrottleEntry
    {
        public DateTime PeriodStart { get; set; }
        public long Requests { get; set; }


        public ThrottleEntry()
        {
            PeriodStart = DateTime.UtcNow;
            Requests = 0;
        }
    }
}
--end--

测试类

ThrottlingHandlerTests.cs


--begin--
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using NUnit.Framework;
using Should;
using WebApiContrib.Caching;
using WebApiContrib.MessageHandlers;


namespace WebApiContribTests.MessageHandlers
{
    [TestFixture]
    public class ThrottlingHandlerTests : MessageHandlerTester
    {
        [Test]
        public void Should_inject_ratelimit_headers_when_limit_not_reached()
        {
            var handler = GetHandler(100, TimeSpan.FromMinutes(1));


            var requestMessage = new HttpRequestMessage(HttpMethod.Get, "foo/bar");
            var response = ExecuteRequest(handler, requestMessage);


            response.StatusCode.ShouldEqual(HttpStatusCode.OK);


            IEnumerable<string> values;
            Assert.True(response.Headers.TryGetValues("RateLimit-Limit", out values));
            Assert.AreEqual("100", values.First());
            Assert.True(response.Headers.TryGetValues("RateLimit-Remaining", out values));
            Assert.AreEqual("99", values.First());
        }


        [Test]
        public void Should_throttle_when_limit_reached()
        {
            var handler = GetHandler(0, TimeSpan.FromMinutes(1));


            var requestMessage = new HttpRequestMessage(HttpMethod.Get, "foo/bar");
            var response = ExecuteRequest(handler, requestMessage);


            response.StatusCode.ShouldEqual(HttpStatusCode.Conflict);


            IEnumerable<string> values;
            Assert.True(response.Headers.TryGetValues("RateLimit-Limit", out values));
            Assert.AreEqual("0", values.First());
            Assert.True(response.Headers.TryGetValues("RateLimit-Remaining", out values));
            Assert.AreEqual("0", values.First());
        }

        private ThrottlingHandler GetHandler(long maxRequests, TimeSpan period)
        {
            return new ThrottlingHandlerWithFixedIdentifier(new InMemoryThrottleStore(), identifier => maxRequests, period);
        }

        private class ThrottlingHandlerWithFixedIdentifier: ThrottlingHandler
        {
            public ThrottlingHandlerWithFixedIdentifier(IThrottleStore store, Func<string, long> maxRequestsForUserIdentifier, TimeSpan period) : base(store, maxRequestsForUserIdentifier, period)
            {
            }

            protected override string GetUserIdentifier(HttpRequestMessage request)
            {
                return "10.0.0.1";
            }
        }
    }
}
--end--


IThrottleStore接口 使用ID +当前的请求数量。InMemoryThrottleStore 只有一个内存中存储,但你可以轻松地扩展实现为分布式缓存或数据库。还可以轻松地自定义ThrottlingHandler的行为,例如我们针对一个ip地址可以更好的进行控制。

Throttling ASP.NET Web API calls
Introducing ASP.NET Web API Throttling handler
Throttling Suite for Web API

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值