Thread Safety
This lessons discusses the concept of thread safety.
Thread Safety
The primary motivation behind using multiple threads is improving program performance that may be measured with metrics such as throughput, responsiveness, latency, etc. Whenever threads are introduced in a program, the shared state amongst the threads becomes vulnerable to corruption. If a class or a program has immutable state then the class is necessarily thread-safe. Similarly, the shared state in an application where the same thread mutates the state using an operation that translates into an atomic bytecode instruction can be safely read by multiple reader threads. In contrast, a sole writer thread mutating the shared state using several atomic bytecode instructions isn't a thread-safe scenario for reader threads. Most multi-threaded setups require caution when interacting with shared state. As a corollary, the composition of two thread-safe classes doesn't guarantee thread-safety.
Atomicity
Consider the below snippet:
count = 0
def increment():
global count
count += 1
The above code will work flawlessly when it is executed by a single thread. However, if there are two or more threads involved, things get tricky. The key to realize is that the statement count += 1
isn't atomic. A thread can't increment the variable atomically, i.e. there doesn't exist a single bytecode instruction that can increment the count variable. Let's examine the bytecode generated for our snippet above.
import discount = 0def increment():global countcount += 1# prints the bytecodedis.dis(increment)
Generated Byte Code
3 LOAD_CONST 1 (1)
6 INPLACE_ADD
7 STORE_GLOBAL 0 (count)
10 LOAD_CONST 0 (None)
13 RETURN_VALUE
The seemingly single line statement expands into multiple bytecode instructions. When two threads invoke the increment()
method it is possible that the first thread is switched out by the Python interpreter just before the third INPLACE_ADD instruction is executed. Now the second thread comes along and executes all the six bytecode instructions in one go. When the first thread is rescheduled by the interpreter, it executes the third line but the value the thread holds is stale causing it to incorrectly update the count
variable.
Programming languages provide constructs such as mutexes and locks to help developers guard sections of code that must be executed sequentially by multiple threads. Guarding shared data is one aspect of multi-threaded programs. The other aspect is coordination and cooperation amongst threads. Again, languages provide mechanisms to facilitate threads to work cooperatively towards a common goal. These include semaphores, barriers etc.
Thread Unsafe Class
Take a minute to go through the following program. It increments an object of class Counter
using 5 threads. Each thread increments the object a hundred thousand times. The final value of the counter should be half a million (500,000). If you run the program enough times, you'll sometimes get the correct summation, and at others, you'll get an incorrect value.
from threading import Threadimport sysclass Counter:def __init__(self):self.count = 0def increment(self):for _ in range(100000):self.count += 1if __name__ == "__main__":# Sets the thread switch intervalsys.setswitchinterval(0.005)numThreads = 5threads = [0] * numThreadscounter = Counter()for i in range(0, numThreads):threads[i] = Thread(target=counter.increment)for i in range(0, numThreads):threads[i].start()for i in range(0, numThreads):threads[i].join()if counter.count != 500000:print(" count = {0}".format(counter.count), flush=True)else:print(" count = 50,000 - Try re-running the program.")
Fixing Thread Unsafe Class
We fix the above example using the equivalent of a mutex in Python called a Lock
. For now, don't worry about how the example below works, but observe how the count always sums up to half a million.
from threading import Threadfrom threading import Lockimport sysclass Counter:def __init__(self):self.count = 0self.lock = Lock()def increment(self):for _ in range(100000):self.lock.acquire()self.count += 1self.lock.release()if __name__ == "__main__":# Sets the thread switch intervalsys.setswitchinterval(0.005)numThreads = 5threads = [0] * numThreadscounter = Counter()for i in range(0, numThreads):threads[i] = Thread(target=counter.increment)for i in range(0, numThreads):threads[i].start()for i in range(0, numThreads):threads[i].join()if counter.count != 500000:print(" If this line ever gets printed, " + \"the author is a complete idiot and " + \"you should return the course for a full refund!")else:print(" count = {0}".format(counter.count))