// averaging.proto
  1. syntax = "proto3";
  2. import "runtime.proto";
  // Runs alongside each trainer to perform gating function averaging every now
  // and then. Read more: client/averaging.py
  service DecentralizedAveraging {
    // Assemble a group for all-reduce: the caller sends a single JoinRequest
    // and receives a stream of MessageFromLeader updates in reply.
    rpc rpc_join_group(JoinRequest)
        returns (stream MessageFromLeader);

    // Send local part => get average part. Both directions are streams of
    // AveragingData messages.
    rpc rpc_aggregate_part(stream AveragingData)
        returns (stream AveragingData);

    // Download this peer's state as a stream of DownloadData chunks
    // (metadata plus tensor parts).
    // NOTE(review): who calls this and when (e.g. a peer catching up) is not
    // visible here; confirm in client/averaging.py.
    rpc rpc_download_state(DownloadRequest)
        returns (stream DownloadData);
  }
  // Status / message codes exchanged between peers; used as the `code` field of
  // both MessageFromLeader and AveragingData. The quoted text on each value
  // paraphrases what the sender is saying.
  // NOTE(review): values are not prefixed with MESSAGE_CODE_ as style guides
  // recommend; renaming them would break generated code and JSON, so they stay.
  enum MessageCode {
    // Default value that should not be used explicitly.
    NO_CODE = 0;
    // "Dear maybe leader, will you have me in your group as a follower?"
    REQUEST_JOIN = 1;
    // "I accept you in my group, you now commit to responding to me."
    ACCEPTED = 2;
    // "We can begin allreduce now. These are your peers."
    BEGIN_ALLREDUCE = 3;
    // "I am running allreduce with you, here's a part of my tensor that you
    // should aggregate."
    PART_FOR_AVERAGING = 4;
    // "I aggregated your part with others and here's the average for that
    // part."
    AVERAGED_PART = 5;
    // "I have not declared my group id yet, how the heck did you even find me?
    // Go away."
    NOT_DECLARED = 6;
    // "I am not a group leader. Go ask my leader instead."
    NOT_A_LEADER = 7;
    // "I will not accept you. I cannot guarantee that we begin before you
    // expire."
    BAD_EXPIRATION_TIME = 8;
    // "I will not accept you. I am not averaging the same type of tensors as
    // you."
    BAD_SCHEMA_HASH = 9;
    // "I will not accept your request, your group id does not match any of the
    // groups I'm in."
    BAD_GROUP_ID = 10;
    // "I will not accept you, I already have exactly the same endpoint in my
    // current group."
    DUPLICATE_ENDPOINT = 11;
    // "I will not accept you, my group already contains too many peers."
    GROUP_IS_FULL = 12;
    // "I'm not available at the moment. Please, get lost."
    NOT_LOOKING_FOR_GROUP = 13;
    // "You did something so unspeakable that I don't have a special code for
    // that."
    PROTOCOL_VIOLATION = 14;
    // "I messed up, we will have to stop allreduce because of that."
    INTERNAL_ERROR = 15;
    // [from peer during allreduce] "I no longer want to participate in
    // AllReduce."
    CANCELLED = 16;
    // [from leader] "The group is closed. Go find another group."
    GROUP_DISBANDED = 17;
  }
  // Request sent by a would-be follower via rpc_join_group, asking a potential
  // leader to accept it into an all-reduce group.
  message JoinRequest {
    // Address at which the follower accepts incoming all-reduce requests.
    string endpoint = 1;

    // Hash describing the follower's tensors (shapes, number of tensors, etc).
    // A leader may refuse a mismatching follower with BAD_SCHEMA_HASH.
    bytes schema_hash = 2;

    // Point in time by which the follower would like to **begin** all-reduce.
    // A leader may refuse with BAD_EXPIRATION_TIME if it cannot start in time.
    // NOTE(review): the clock and units are not stated in this file
    // (presumably seconds on a clock shared by peers); confirm in
    // client/averaging.py.
    double expiration = 3;

    // Optional opaque metadata gathered from all peers (e.g. batch size or
    // current loss); it is relayed back to the group in
    // MessageFromLeader.gathered.
    bytes gather = 4;

    // If true, the joining averager is a client with no capacity for
    // averaging.
    bool client_mode = 5;
  }
  // Message streamed back to a follower in response to rpc_join_group.
  message MessageFromLeader {
    // What the sender is reporting, e.g. ACCEPTED, BEGIN_ALLREDUCE, or a
    // rejection reason such as GROUP_IS_FULL or NOT_A_LEADER.
    MessageCode code = 1;

    // Unique identifier of this group; only valid until all-reduce has
    // finished or failed.
    bytes group_id = 2;

    // If the peer is already in a group, this carries the endpoint of that
    // group's leader.
    string suggested_leader = 3;

    // Sequence of peers, each responsible for one shard during averaging.
    repeated string ordered_group_endpoints = 4;

    // Metadata (JoinRequest.gather) from all groupmates, listed in the same
    // order as ordered_group_endpoints; the two lists are index-aligned.
    repeated bytes gathered = 5;
  }
  // One message of the rpc_aggregate_part exchange: on the request stream it
  // carries a peer's local tensor part, on the response stream the group
  // average of that part.
  message AveragingData {
    // Status of this message, e.g. PART_FOR_AVERAGING or AVERAGED_PART.
    // When something goes wrong it carries an error code instead, such as
    // PROTOCOL_VIOLATION, INTERNAL_ERROR or CANCELLED. This is an enum code,
    // not a human-readable error message.
    MessageCode code = 1;

    // Unique group identifier, same as MessageFromLeader.group_id.
    bytes group_id = 2;

    // Sender's RPC endpoint, used for coordination.
    string endpoint = 3;

    // Either the peer's local tensor part (rpc input) or the group average of
    // that part (rpc output).
    Tensor tensor_part = 4;

    // Opaque, user-extendable metadata; its contents are not defined by this
    // schema. (Not related to the protobuf `reserved` keyword.)
    bytes metadata = 5;
  }
  // Request for rpc_download_state. It currently carries no fields; a dedicated
  // message (rather than google.protobuf.Empty) lets fields be added later
  // without changing the RPC signature.
  message DownloadRequest {
  }
  // One element of the stream returned by rpc_download_state.
  // NOTE(review): how metadata and tensor parts are split across stream
  // elements is not specified here; see client/averaging.py.
  message DownloadData {
    // Opaque metadata accompanying the downloaded state.
    bytes metadata = 1;

    // A tensor that is part of the downloaded state.
    Tensor tensor_part = 2;
  }