栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

Celery 自定义消费器

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Celery 自定义消费器

Celery 自定义消息消费者实践

其他博文链接

celery文档翻译-自定义消息消费者

文章目录
  • Celery 自定义消息消费者实践
      • 前言
      • 1.继承Task
      • 2.自定义消息消费者
        • 2.1 celery文档描述
        • 2.2 TokenBucket实现
        • 2.3 定义新算法
        • 2.4 改动bucket_for_task

前言

celery的限流策略只针对单个worker,实际使用的时候worker是分布式启动,单worker不能满足需求,因此想在celery原先的限流策略上使用redis来实现分布式worker限流。

最开始的想法是自定义task去继承celery.Task,在调用__call__()函数的时候插入限流判断,也能实现功能。

后来无意间发现celery支持自定义消息消费者,竟然baidu、google不到有人发布相关的博客,不给我copy的机会,然后自己捣鼓了一下,也算实现了功能。这里会将两种方式的代码都贴出来。

1.继承Task

这一步的时候为了偷懒用了一个第三方库walrus实现redis限流功能,具体实现与使用请参考 walrus 使用 .

from celery import Task as _Task
from walrus import *

class baseTask(_Task):

    rate = None
    rate_limit_max_retry = 100000
    rate_limit_sleep_time = 0.1
    throttle_lock = None
    walrus_session = None
    RATE_KEY = 'RATE_LIMIT'


    def __call__(self, *args, **kwargs):
        """In celery task this function call the run method, here you can
        set some environment variable before the run of the task"""
        self.check_rate()
        return self.run(*args, **kwargs)

    def check_rate(self):
        """
        循环访问限流器,直到可以执行则返回.
        """
        if self.rate:
            # init walrus session
            if self.walrus_session is None:
                # broker_url指的是celery配置使用的中间件,未支持rabbitmq
                redis_conf = get_redis_by_url(Config.celery_conf['broker_url'])
                if redis_conf is None:
                    raise Exception('无redis中间件配置')
                self.walrus_session = Database(**redis_conf)
                self.lock = self.walrus_session.lock(f"{self.RATE_KEY}:{self.name}")

            # init walrus limiter
            limit, per = parse_rate(self.rate)
            raise_exe = True
            
            rate_limit = self.walrus_session.rate_limit(self.RATE_KEY, limit=limit, per=per)
            

            for i in range(self.rate_limit_max_retry):
                with self.lock:
                    f = rate_limit.limit(self.name)

                if f:
                    sleep_time = per + int(per/limit) * i
                    print(f"Task {self.name}[{self.request.id}] in sleep {sleep_time}s .", end="")
                    time.sleep(sleep_time)
                else:
                    raise_exe = False
                    break
            if raise_exe:
                raise Exception('超出最大重复次数')
            return 0

这种实现比较简单粗暴,不断循环访问redis,虽然也有控制sleep_time,但是精度控制不好。

2.自定义消息消费者

这里小声吐槽一下celery的作者,作者确实很牛逼,就是这源码对小白不太友好,各种动态、鸭子类型,需要看一些python进阶书才能稍微看懂一点。暴风哭泣。

2.1 celery文档描述
  • consumer.bucket_for_task(type, Bucket=TokenBucket)
    使用 task.rate_limit 属性为一个任务创建速率限制bucket。

celery官网的文档中介绍了这个函数,因此先写了一个demo

from celery import Celery

from celery.worker.consumer.tasks import Tasks
from celery import bootsteps
from collections import defaultdict
from celery.utils.time import rate
from kombu.utils.limits import TokenBucket

from celery.worker.consumer.consumer import Consumer

app = Celery()

# 从celeryconfig.py中读取celery配置
app.config_from_object('celeryconfig')

class Step(bootsteps.StartStopStep):
  requires = {'celery.worker.consumer.tasks:Tasks'}

  def start(self, c):
    c.bucket_for_task = self.bucket_for_task
    c.reset_rate_limits()
    if self.obj:
      return self.obj.start()

  def bucket_for_task(self, parent):
    print(parent)
    limit = rate(getattr(parent, 'rate_limit', None))
    return TokenBucket(limit, capacity=1) if limit else None

  def reset_rate_limits(self):
    self.task_buckets.update(
      (n, self.bucket_for_task(t)) for n, t in self.app.tasks.items()
    )
 
# 注册
app.steps['consumer'].add(Step)

# demo Task
@app.task(rate_limit='1/s')
def test_rate_limit():
    print('hello world') 」

通过日志可以知道,celery会为每一个task设置一个bucket,celery用了默认的TokenBucket。**TokenBucket**实现了 token bucket 算法,只要遵循相同接口并且定义了这两个方法的任何算法都可以被使用(鸭子类型)。因此不需要继承,只需要实现一个类拥有相同接口方法就可使用。

2.2 TokenBucket实现

通过2.1知道,我们只需要定义一个新的类,实现了TokenBucket的方法就可以替换TokenBucket ,所以这里先看看TokenBucket有些什么东西。

"""Token bucket implementation for rate limiting."""

from collections import deque
from time import monotonic

__all__ = ('TokenBucket',)


class TokenBucket:
    """Token Bucket Algorithm.

    See Also:
        https://en.wikipedia.org/wiki/Token_Bucket

        Most of this code was stolen from an entry in the ASPN Python Cookbook:
        https://code.activestate.com/recipes/511490/

    Warning:
        Thread Safety: This implementation is not thread safe.
        Access to a `TokenBucket` instance should occur within the critical
        section of any multithreaded code.
    """

    #: The rate in tokens/second that the bucket will be refilled.
    fill_rate = None

    #: Maximum number of tokens in the bucket.
    capacity = 1

    #: Timestamp of the last time a token was taken out of the bucket.
    timestamp = None

    def __init__(self, fill_rate, capacity=1):
        self.capacity = float(capacity)
        self._tokens = capacity
        self.fill_rate = float(fill_rate)
        self.timestamp = monotonic()
        self.contents = deque()

    def add(self, item):
        self.contents.append(item)

    def pop(self):
        return self.contents.popleft()

    def clear_pending(self):
        self.contents.clear()

    def can_consume(self, tokens=1):
        """Check if one or more tokens can be consumed.

        Returns:
            bool: true if the number of tokens can be consumed
                from the bucket.  If they can be consumed, a call will also
                consume the requested number of tokens from the bucket.
                Calls will only consume `tokens` (the number requested)
                or zero tokens -- it will never consume a partial number
                of tokens.
        """
        if tokens <= self._get_tokens():
            self._tokens -= tokens
            return True
        return False

    def expected_time(self, tokens=1):
        """Return estimated time of token availability.

        Returns:
            float: the time in seconds.
        """
        _tokens = self._get_tokens()
        tokens = max(tokens, _tokens)
        return (tokens - _tokens) / self.fill_rate

    def _get_tokens(self):
        if self._tokens < self.capacity:
            now = monotonic()
            delta = self.fill_rate * (now - self.timestamp)
            self._tokens = min(self.capacity, self._tokens + delta)
            self.timestamp = now
        return self._tokens

简单追踪一下日志可以发现,celery消费者在接收到一个任务之后会先调用add方法,再调用can_consume方法,去判断tokens个任务能否被消费:返回为True时,调用pop方法,拿到任务进行消费。返回False的时候,调用expected_time方法获取下一次访问需要等待的时间(单位/s)。

因此这里我不想改动add,pop逻辑,只改动can_consume、expected_time方法。

2.3 定义新算法

这里贴一部分代码。

expected_time方法我是照搬了一个开源的代码:https://github.com/freelawproject/courtlistener/blob/main/cl/lib/celery_utils.py

  • can_consume:throttle_key记录执行次数,在第一个任务执行的时候设置过期时间=duration;如果能被执行,throttle_key自增tokens,并且返回True,不能被执行返回False.
  • expected_time:schedule_key是None 或者是过去的时间:schedule_key=now+ttl(throttle_key),等待时间=ttl(throttle_key);schedule_key > now :滑动窗口,schedule_key=schedule_key+(duration / task_count),等待时间=new_schedule_key - now;
class ThrottleLimit:

    def can_consume(self, tokens=1):
        """Check if one or more tokens can be consumed.

        Returns:
            bool: true if the number of tokens can be consumed
                from the bucket.  If they can be consumed, a call will also
                consume the requested number of tokens from the bucket.
                Calls will only consume `tokens` (the number requested)
                or zero tokens -- it will never consume a partial number
                of tokens.
        """
        with self.lock:
            # Check the count in redis
            actual_task_count = self.client.get(self.throttle_key)
            if actual_task_count is None and tokens <= self.task_count:
                # No key. Set the value to 1 and set the ttl of the key.
                self.client.set(self.throttle_key, tokens, ex=self.duration)
                return True

            actual_task_count = int(actual_task_count)

            # Key found. Check if we should throttle.
            if (tokens + actual_task_count) <= self.task_count:
                # We're OK to run the task. Increment our counter, and say things are
                # OK by returning 0.
                new_count = self.client.incr(self.throttle_key, tokens)
                if new_count == tokens:
                    # Safety check. If the count is tokens after incrementing, that means we
                    # created the key via the incr command. This can happen when it
                    # expires between when we `get` its value up above and when we
                    # increment it here. If that happens, it lacks a ttl! Set one.
                    #
                    # N.B. There's no need to worry about a race condition between our
                    # incr above, and the `expire` line here b/c without a ttl on this
                    # key, it can't expire between these two commands.
                    self.client.expire(self.throttle_key, self.duration)
                return True

        return False

    def _set_for_next_window(self, n) -> float:
        """Set the schedule for the next window to start as soon as the current one
        runs out.
        """
        ttl = self.client.ttl(self.throttle_key)
        if ttl < 0:
            # Race condition. The key expired (-2) or doesn't have a
            # TTL (-1). Don't delay; run the task.
            return 0
        self.client.set(self.schedule_key, str(n + timedelta(seconds=ttl)))
        return ttl

    def expected_time(self, tokens=1):
        """Return estimated time of token availability.

        Returns:
            float: the time in seconds.
        """
        with self.schedule_lock:
            n = datetime.now()
            delay = self.client.get(self.schedule_key)
            if delay is None:
                return self._set_for_next_window(n)

            # # We have a delay, so use it if it's in the future
            delay = parser.parse(delay)
            if delay < n:
                # Delay is in the past. Run the task when the current throttle expires.
                return self._set_for_next_window(n)

            # # Delay is in the future; use it and supplement it
            new_time = delay + timedelta(seconds=self.duration / self.task_count)
            self.client.set(self.schedule_key, str(new_time))
            return (new_time - n).total_seconds()

这里简单贴一张我画的流程图,描述的是算法原作者的代码,画的简陋勿喷。

2.4 改动bucket_for_task

ThrottleLimit是需要使用redis中间件,实际情况也只有自定义的任务才需要借助redis限流,而celery自己的任务就不需要去使用redis限流了,因此我又改动了一下bucket_for_task方法的逻辑。

    def bucket_for_task(self, parent):
        fill_rate = rate(getattr(parent, 'rate_limit', None))
        broker_url = parent._get_app().conf['broker_url']

        # celery自身的任务走 TokenBucket
        # 中间件不是redis的走 TokenBucket
        if parent.name.startswith('celery.') or 'redis' not in broker_url:
            return TokenBucket(fill_rate, capacity=1) if fill_rate else None
        limit = parse_rate(getattr(parent, 'rate_limit', None))
        return ThrottleLimit(parent, *limit) if limit else None
        

parse_rate是解析形如rate_limit字符串的方法,这里可以自由发挥。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/351832.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号