Actual source code: ex4.c
  1: static const char help[] = "Tests PetscDeviceContextFork/Join.\n\n";
  3: #include "petscdevicetestcommon.h"
  5: static PetscErrorCode DoFork(PetscDeviceContext parent, PetscInt n, PetscDeviceContext **sub)
  6: {
  7:   PetscDeviceType dtype;
  8:   PetscStreamType stype;
 10:   PetscFunctionBegin;
 11:   PetscCall(PetscDeviceContextGetDeviceType(parent, &dtype));
 12:   PetscCall(PetscDeviceContextGetStreamType(parent, &stype));
 13:   PetscCall(PetscDeviceContextFork(parent, n, sub));
 14:   if (n) PetscCheck(*sub, PETSC_COMM_SELF, PETSC_ERR_PLIB, "PetscDeviceContextFork() return NULL pointer for %" PetscInt_FMT " children", n);
 15:   for (PetscInt i = 0; i < n; ++i) {
 16:     PetscDeviceType sub_dtype;
 17:     PetscStreamType sub_stype;
 19:     PetscCall(AssertDeviceContextExists((*sub)[i]));
 20:     PetscCall(PetscDeviceContextGetStreamType((*sub)[i], &sub_stype));
 21:     PetscCall(AssertPetscStreamTypesValidAndEqual(sub_stype, stype, "Child stream type %s != parent stream type %s"));
 22:     PetscCall(PetscDeviceContextGetDeviceType((*sub)[i], &sub_dtype));
 23:     PetscCall(AssertPetscDeviceTypesValidAndEqual(sub_dtype, dtype, "Child device type %s != parent device type %s"));
 24:   }
 25:   PetscFunctionReturn(PETSC_SUCCESS);
 26: }
 28: static PetscErrorCode TestNestedPetscDeviceContextForkJoin(PetscDeviceContext parCtx, PetscDeviceContext *sub)
 29: {
 30:   const PetscInt      nsub = 4;
 31:   PetscDeviceContext *subsub;
 33:   PetscFunctionBegin;
 34:   PetscAssertPointer(sub, 2);
 35:   PetscCall(AssertPetscDeviceContextsValidAndEqual(parCtx, sub[0], "Current global context does not match expected global context"));
 36:   /* create some children from an active child */
 37:   PetscCall(DoFork(sub[1], nsub, &subsub));
 38:   /* join on a sibling to the parent */
 39:   PetscCall(PetscDeviceContextJoin(sub[2], nsub - 2, PETSC_DEVICE_CONTEXT_JOIN_SYNC, &subsub));
 40:   /* join on the grandparent */
 41:   PetscCall(PetscDeviceContextJoin(parCtx, nsub - 2, PETSC_DEVICE_CONTEXT_JOIN_NO_SYNC, &subsub));
 42:   PetscCall(PetscDeviceContextJoin(sub[1], nsub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subsub));
 43:   PetscFunctionReturn(PETSC_SUCCESS);
 44: }
 46: /* test fork-join */
 47: static PetscErrorCode TestPetscDeviceContextForkJoin(PetscDeviceContext dctx)
 48: {
 49:   PetscDeviceContext *sub;
 50:   const PetscInt      n = 10;
 52:   PetscFunctionBegin;
 53:   /* mostly for valgrind to catch errors */
 54:   PetscCall(DoFork(dctx, n, &sub));
 55:   PetscCall(PetscDeviceContextJoin(dctx, n, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
 56:   /* do it twice */
 57:   PetscCall(DoFork(dctx, n, &sub));
 58:   PetscCall(PetscDeviceContextJoin(dctx, n, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
 60:   /* create some children */
 61:   PetscCall(DoFork(dctx, n + 1, &sub));
 62:   /* test forking within nested function */
 63:   PetscCall(TestNestedPetscDeviceContextForkJoin(sub[0], sub));
 64:   /* join a subset */
 65:   PetscCall(PetscDeviceContextJoin(dctx, n - 1, PETSC_DEVICE_CONTEXT_JOIN_NO_SYNC, &sub));
 66:   /* back to the ether from whence they came */
 67:   PetscCall(PetscDeviceContextJoin(dctx, n + 1, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
 68:   PetscFunctionReturn(PETSC_SUCCESS);
 69: }
 71: int main(int argc, char *argv[])
 72: {
 73:   MPI_Comm           comm;
 74:   PetscDeviceContext dctx;
 76:   PetscFunctionBeginUser;
 77:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
 78:   comm = PETSC_COMM_WORLD;
 80:   PetscCall(PetscDeviceContextCreate(&dctx));
 81:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)dctx, "local_"));
 82:   PetscCall(PetscDeviceContextSetFromOptions(comm, dctx));
 83:   PetscCall(TestPetscDeviceContextForkJoin(dctx));
 84:   PetscCall(PetscDeviceContextDestroy(&dctx));
 86:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
 87:   PetscCall(TestPetscDeviceContextForkJoin(dctx));
 89:   PetscCall(TestPetscDeviceContextForkJoin(NULL));
 91:   PetscCall(PetscPrintf(comm, "EXIT_SUCCESS\n"));
 92:   PetscCall(PetscFinalize());
 93:   return 0;
 94: }
 96: /*TEST
 98:   testset:
 99:     requires: cxx
100:     output_file: ./output/ExitSuccess.out
101:     nsize: {{1 3}}
102:     args: -device_enable {{lazy eager}}
103:     args: -local_device_context_stream_type {{default nonblocking default_with_barrier nonblocking_with_barrier}}
104:     test:
105:       requires: !device
106:       suffix: host_no_device
107:     test:
108:       requires: device
109:       args: -root_device_context_device_type host
110:       suffix: host_with_device
111:     test:
112:       requires: cuda
113:       args: -root_device_context_device_type cuda
114:       suffix: cuda
115:     test:
116:       requires: hip
117:       args: -root_device_context_device_type hip
118:       suffix: hip
120: TEST*/