summaryrefslogtreecommitdiff
path: root/util/thread_timeout.py
blob: 9a2637c845fd1885d386474612f849c3faddaddb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from threading import Thread
import inspect
import ctypes
from functools import wraps


def _async_raise(tid, exctype):
    """raises the exception, performs cleanup if needed"""
    tid = ctypes.c_long(tid)
    if not inspect.isclass(exctype):
        exctype = type(exctype)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("invalid thread id")
    elif res != 1:
        # """if it returns a number greater than one, you're in trouble,
        # and you should call it again with exc=NULL to revert the effect"""
        ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
        raise SystemError("PyThreadState_SetAsyncExc failed")


def stop_thread(thread):
    """Request termination of ``thread`` by injecting ``SystemExit`` into it.

    Args:
        thread: A started ``threading.Thread``; its ``ident`` selects the
            target passed to ``_async_raise``.
    """
    target_id = thread.ident
    _async_raise(target_id, SystemExit)


class TimeoutException(Exception):
    """Raised by ``time_limited`` wrappers when a call exceeds its time limit."""


# Alias for ``stop_thread``; ``time_limited`` stops a timed-out worker thread
# through this name.
ThreadStop = stop_thread


def time_limited(timeout):
    """Decorator factory that bounds how long a call may run.

    Every call of the decorated function runs in a fresh worker thread. The
    caller waits up to ``timeout`` seconds (forwarded to ``Thread.join``, so
    ``None`` waits indefinitely). If the worker is still alive afterwards, it
    is asked to stop via ``ThreadStop`` and ``TimeoutException`` is raised.

    Args:
        timeout: Maximum number of seconds to wait for each call.

    Returns:
        A decorator; the wrapped function keeps the original's metadata
        (via ``functools.wraps``) and returns its result.

    Raises (when calling the wrapped function):
        TimeoutException: the call did not finish within ``timeout``.
        Exception: any ``Exception`` raised by the decorated function is
            re-raised in the caller's thread.
    """
    def decorator(function):
        @wraps(function)
        def wrapped_function(*args, **kwargs):
            class TimeLimited(Thread):
                def __init__(self):
                    # daemon=True: a worker that does not react to the stop
                    # request must not keep the interpreter alive at exit.
                    super().__init__(daemon=True)
                    self.error = None   # exception raised by ``function``, if any
                    self.result = None  # return value of ``function``

                def run(self):
                    # Capture the exception so the caller can re-raise it;
                    # previously it escaped the thread and the wrapper
                    # silently returned None. SystemExit (what ThreadStop
                    # injects) is not an Exception subclass, so a stopped
                    # worker simply terminates.
                    try:
                        self.result = function(*args, **kwargs)
                    except Exception as exc:
                        self.error = exc

                def stop(self):
                    if self.is_alive():
                        ThreadStop(self)

            worker = TimeLimited()
            worker.start()
            worker.join(timeout)
            if worker.is_alive():
                worker.stop()
                raise TimeoutException(f"timeout for {function!r}")
            if worker.error is not None:
                raise worker.error
            return worker.result

        return wrapped_function

    return decorator