# PA2 Discussion
## DSC 204a, Winter 2024

Notebook setup: this notebook is best run on the `ray-notebook` server on DataHub, where all of its dependencies (including Ray and PyTorch) are already installed.

## A Simple Introduction to Ray Actors

In [None]:
import ray
import time
import warnings
warnings.filterwarnings("ignore")
ray.init()

@ray.remote
class Counter:
    def __init__(self):
        self.value = 0

    def increment(self):
        # simulate longer execution time
        time.sleep(1)

        self.value += 1
        return self.value

    def get_counter(self):
        return self.value

# Create an actor instance: Ray starts the actor in its own worker process and returns a handle to it.
counter = Counter.remote()

In [None]:
counter.value  # fails: `counter` is an actor handle, so read actor state through a method such as get_counter

In [None]:
start = time.time()
for _ in range(15):
    # Each call is submitted as an actor task and returns an ObjectRef immediately.
    counter.increment.remote()
end = time.time()
print("Submission time:", end - start)

In [None]:
ray.get(counter.get_counter.remote())  # blocks until the queued increment calls have finished

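Wait, shouldn't the loop take 15 s? No: calling `.remote()` on an actor method only submits the call as a Ray (actor) task and immediately returns an `ObjectRef`, a future for the eventual result, so the loop finishes almost instantly while the actor works through the calls in the background.

By default, an actor executes the method calls it receives from a given caller one at a time (i.e., serially), in the order they were submitted. The `get_counter` call therefore waits behind all 15 earlier `increment` calls, which is why the `ray.get` in the previous cell takes noticeable time (up to about 15 s, depending on how soon after the loop you run it).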
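The behavior above can be mimicked without Ray. The sketch below is a stdlib-only analogy, not how Ray implements actors: a hypothetical `LocalCounterActor` queues every method call on a single worker thread (`ThreadPoolExecutor(max_workers=1)`), so each call returns a `Future` immediately, much as `.remote()` returns an `ObjectRef`, yet the calls run one at a time in submission order. The 0.05 s delay is an assumption standing in for `time.sleep(1)` to keep the example fast.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


class LocalCounterActor:
    """Hypothetical stdlib analogy of a Ray actor (not Ray's implementation).

    Every method call is queued on one worker thread, so calls return a
    Future immediately but execute one at a time, in submission order.
    """

    def __init__(self, delay: float = 0.05):
        self._executor = ThreadPoolExecutor(max_workers=1)  # one thread => serial execution
        self._delay = delay
        self._value = 0

    def _increment(self) -> int:
        time.sleep(self._delay)  # simulate a slow method
        self._value += 1
        return self._value

    def increment(self) -> Future:
        # Analogous to counter.increment.remote(): submit and return at once.
        return self._executor.submit(self._increment)

    def get_counter(self) -> Future:
        # Queued behind every earlier increment, like the actor version.
        return self._executor.submit(lambda: self._value)

    def shutdown(self) -> None:
        self._executor.shutdown(wait=True)


actor = LocalCounterActor(delay=0.05)

start = time.time()
futures = [actor.increment() for _ in range(15)]
submit_time = time.time() - start  # tiny: no caller has waited on the sleeps

start = time.time()
final_value = actor.get_counter().result()  # blocks until all 15 increments finish
wait_time = time.time() - start

print(f"submit: {submit_time:.3f} s, wait: {wait_time:.3f} s, value: {final_value}")
print([f.result() for f in futures])  # results appear in submission order
actor.shutdown()
```

Unlike a real actor, this object lives in the caller's process, so nothing stops you from reading `actor._value` directly; a Ray actor's state lives in a separate worker process, which is why `counter.value` fails.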

## Collective Communication with Ray

In [None]:
import ray, torch, os
import ray.util.collective as col

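os.environ["PYTHONWARNINGS"] = "ignore::DeprecationWarning"

@ray.remote
class Worker:
    def __init__(self, world_size, rank):
        # Join the collective group "dsc204a" as `rank`, using the CPU-based gloo backend.
        col.init_collective_group(world_size=world_size,
                                  rank=rank,
                                  group_name="dsc204a",
                                  backend="gloo")

    def set_msg(self, msg):
        # Tensor this worker will send.
        self.msg = msg

    def set_buf(self, shape, dtype):
        # Pre-allocated receive buffer: recv writes the incoming data into it in place,
        # so its shape and dtype should match the sender's tensor.
        self.buf = torch.zeros(shape, dtype=dtype)

    def do_send(self, target_rank):
        # Point-to-point send; pairs with a recv issued on target_rank.
        col.send(self.msg, target_rank, group_name="dsc204a")
        return self.msg

    def do_recv(self, src_rank):
        # Point-to-point receive from src_rank into self.buf.
        col.recv(self.buf, src_rank, group_name="dsc204a")
        return self.buf
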
world_size = 2

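# Start the two actors; each constructor joins "dsc204a" with its own rank.
A = Worker.remote(world_size=world_size, rank=0)
B = Worker.remote(world_size=world_size, rank=1)

# Declare the same group from the driver, listing its member actors and their ranks.
col.create_collective_group(actors=[A, B],
                            world_size=world_size,
                            ranks=[0, 1],
                            backend="gloo",
                            group_name="dsc204a")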

msg = torch.tensor([1,2,3,4,5])

In [None]:
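# Actor calls run in submission order, so A stores the message before do_send
# and B allocates its zero buffer before do_recv; no ray.get is needed here.
A.set_msg.remote(ray.put(msg))
B.set_buf.remote(msg.shape, msg.dtype)

# Submit the send and the recv together, then wait on both: a recv blocks until
# the matching send arrives, so waiting on one before submitting the other can hang.
src_msg, target_buf = ray.get([A.do_send.remote(target_rank=1), B.do_recv.remote(src_rank=0)])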
print(src_msg)
print(target_buf)
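The `set_buf` step reflects how these point-to-point receives work: the receiver supplies a pre-allocated buffer and `col.recv` fills it in place rather than returning a new object. The sketch below is a stdlib-only analogy, not Ray's or gloo's implementation: a `multiprocessing.Pipe` stands in for the link between rank 0 and rank 1, `array('q', ...)` (signed 64-bit ints) stands in for the `torch.int64` tensor, and `Connection.recv_bytes_into` writes the incoming bytes into the receiver's zero-filled buffer. The two sides run in separate threads, echoing the single `ray.get([...])` that waits on both calls together.

```python
import threading
from array import array
from multiprocessing import Pipe

# Stdlib analogy (hypothetical, not Ray/gloo internals): a duplex pipe plays the
# role of the link between "rank 0" (sender) and "rank 1" (receiver).
rank0_end, rank1_end = Pipe()

msg = array("q", [1, 2, 3, 4, 5])  # 'q' = signed 64-bit ints, like torch.int64
buf = array("q", [0] * len(msg))   # like torch.zeros(msg.shape, dtype=msg.dtype)
received_bytes = []                # filled in by the receiving thread


def do_send():
    # Send the raw bytes of the array (no pickling).
    rank0_end.send_bytes(msg)


def do_recv():
    # Write the incoming bytes into the pre-allocated buffer, in place.
    received_bytes.append(rank1_end.recv_bytes_into(buf))


# Run both sides at the same time, like submitting do_send and do_recv together.
threads = [threading.Thread(target=do_send), threading.Thread(target=do_recv)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(msg.tolist())
print(buf.tolist())
rank0_end.close()
rank1_end.close()
```

If the buffer is too small, `recv_bytes_into` raises `multiprocessing.BufferTooShort`, a rough counterpart to a size mismatch between the sender's tensor and the receiver's buffer.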

In [None]:
ray.kill(A) # explicitly kill actors
ray.kill(B)

In [None]:
ray.shutdown()