
Commit 05d1b26

KexinFeng and barry-jin authored
[FEATURE] Add feature of retain_grad (#20500)
* Replace "CloneGradient" with "ElemwiseGradUseNone"
* fix issue elemwise_add
* fix elemwise_add issue with `ElemwiseGradUseNone`
* reverse_to_CloneGradient
* add_retain_grad
* unit_test
* tidy_up
* tidy_up
* sanity
* const_reference
* const_ref
* merge_rg_to_ag
* sanity
* sanity
* add_drop_grad
* sanity_check
* sanity_check
* sanity_check
* build_err
* build_err
* skip_remark_variable
* repetitive_mark
* ReInit_in_dropgrad
* ReInit_in_dropgrad
* sanity_check
* add drop and tests to gluon
* sanity
* update exec_pass.h

Co-authored-by: Zhenghui Jin <[email protected]>
1 parent c1e06aa commit 05d1b26

File tree: 9 files changed, +211 −25 lines
  include/mxnet/c_api.h
  include/mxnet/imperative.h
  python/mxnet/ndarray/ndarray.py
  python/mxnet/numpy/multiarray.py
  src/c_api/c_api_ndarray.cc
  src/imperative/exec_pass.h
  src/imperative/imperative.cc
  src/nnvm/gradient.cc
  tests/python/unittest/test_autograd.py

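The user-facing behavior this commit adds, sketched from the new unit test in tests/python/unittest/test_autograd.py (array values are illustrative only):

    import mxnet as mx
    from mxnet import nd

    x = nd.array([1, 2, 3, 4])
    y = nd.array([5, 6, 7, 8])
    x.attach_grad()
    y.attach_grad()
    with mx.autograd.record():
        u = x * y              # intermediate (nonleaf) result
        z = u * x
    u.attach_grad()            # retain the gradient of the nonleaf u
    z.backward(retain_graph=True)
    print(u.grad)              # dz/du = x, so u.grad == [1, 2, 3, 4]
    u.drop_grad()              # release the retained gradient buffer

As the new tests show, drop_grad also works on leaf variables; after dropping, later backward passes no longer populate those gradients and .grad returns None.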

include/mxnet/c_api.h

Lines changed: 8 additions & 0 deletions
@@ -1274,6 +1274,14 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
                                       NDArrayHandle *var_handles,
                                       uint32_t *reqs_array,
                                       NDArrayHandle *grad_handles);
+/*!
+ * \brief unmark nonleaf NDArrays to free the memory
+ * \param num_var number of variable NDArrays
+ * \param var_handles variable NDArrays
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradDropGrads(uint32_t num_var,
+                                  NDArrayHandle *var_handles);
 /*!
  * \brief compute the gradient of outputs w.r.t variabels
  * \param num_output number of output NDArray
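For reference, a minimal ctypes-level sketch of calling this new entry point for several arrays at once; the drop_grads helper below is hypothetical and simply generalizes the single-handle call that ndarray.py makes further down in this diff:

    import ctypes
    from mxnet.base import _LIB, check_call, NDArrayHandle

    def drop_grads(arrays):
        # Hypothetical batch helper: drop the retained gradients of several NDArrays in one call.
        handles = (NDArrayHandle * len(arrays))(*[a.handle for a in arrays])
        check_call(_LIB.MXAutogradDropGrads(ctypes.c_uint(len(arrays)), handles))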

include/mxnet/imperative.h

Lines changed: 4 additions & 0 deletions
@@ -272,12 +272,16 @@ class Imperative {
   void MarkVariables(const std::vector<NDArray*>& variables,
                      const std::vector<uint32_t>& grad_reqs,
                      const std::vector<NDArray*>& gradients);
+  /*! \brief unmark nonleaf variables to free the memory. */
+  void DropGrads(const std::vector<NDArray*>& variables);
   /*! \brief compute the gradient of outputs w.r.t variables. */
   std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
                                  const std::vector<NDArray*>& ograds,
                                  const std::vector<NDArray*>& variables,
                                  bool is_train, bool retain_graph,
                                  bool create_graph);
+  /*! \brief Return the marked nonleaf nodes. */
+  std::vector<nnvm::ObjectPtr> ListNonleafVariables(const nnvm::Symbol& sym) const;
   /*! \return AutogradRuntime singleton */
   static Imperative* Get();
   /*! \brief Should op execution bulking be employed during inference. */

python/mxnet/ndarray/ndarray.py

Lines changed: 5 additions & 0 deletions
@@ -2885,6 +2885,11 @@ def attach_grad(self, grad_req='write', stype=None):
             ctypes.pointer(mx_uint(grad_req)),
             ctypes.pointer(grad.handle)))
 
+    def drop_grad(self):
+        """Free the memory of the marked ndarray."""
+        check_call(_LIB.MXAutogradDropGrads(
+            1, ctypes.pointer(self.handle)))
+
     @property
     def grad(self):
         """Returns gradient buffer attached to this NDArray."""

python/mxnet/numpy/multiarray.py

Lines changed: 5 additions & 0 deletions
@@ -1410,6 +1410,11 @@ def attach_grad(self, grad_req='write'):  # pylint: disable=arguments-differ
             ctypes.pointer(mx_uint(grad_req)),
             ctypes.pointer(grad.handle)))
 
+    def drop_grad(self):
+        """Free the memory of the marked ndarray."""
+        check_call(_LIB.MXAutogradDropGrads(
+            1, ctypes.pointer(self.handle)))
+
     @property
     def grad(self):
         """Returns gradient buffer attached to this ndarray."""

src/c_api/c_api_ndarray.cc

Lines changed: 12 additions & 0 deletions
@@ -335,6 +335,18 @@ int MXAutogradMarkVariables(uint32_t num_var,
   API_END();
 }
 
+int MXAutogradDropGrads(uint32_t num_var,
+                        NDArrayHandle *var_handles) {
+  API_BEGIN();
+  std::vector<NDArray*> variables;
+  variables.reserve(num_var);
+  for (uint32_t i = 0; i < num_var; ++i) {
+    variables.emplace_back(static_cast<NDArray*>(var_handles[i]));
+  }
+  Imperative::Get()->DropGrads(variables);
+  API_END();
+}
+
 int MXAutogradComputeGradient(uint32_t num_output, NDArrayHandle* output_handles) {
   return MXAutogradBackward(num_output, output_handles, nullptr, 0);
 }

src/imperative/exec_pass.h

Lines changed: 3 additions & 1 deletion
@@ -287,12 +287,14 @@ inline Graph MXGradient(
     std::vector<const Op*> zero_ops = std::vector<const Op*>(),
     std::string copy_op_str = std::string(),
     mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(),
-    DTypeVector in_arg_dtypes = DTypeVector()) {
+    DTypeVector in_arg_dtypes = DTypeVector(),
+    std::vector<NodeEntry> us = std::vector<NodeEntry>() ) {
   graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
   graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
   graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
   graph.attrs["in_arg_shapes"] = std::make_shared<any>(std::move(in_arg_shapes));
   graph.attrs["in_arg_dtypes"] = std::make_shared<any>(std::move(in_arg_dtypes));
+  graph.attrs["grad_us"] = std::make_shared<any>(std::move(us));
 
   if (aggregate_fun != nullptr) {
     graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);

src/imperative/imperative.cc

Lines changed: 83 additions & 18 deletions
@@ -142,29 +142,54 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
                                const std::vector<uint32_t>& grad_reqs,
                                const std::vector<NDArray*>& gradients) {
   for (uint32_t i = 0; i < variables.size(); ++i) {
-    std::string str_c(std::to_string(variable_count_++));
-
-    variables[i]->autograd_entry_ =
-        nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
-    AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
-    info.outputs.emplace_back(variables[i]->Detach());
-    info.out_grads.emplace_back(gradients[i]->Detach());
-    info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
-    info.ctx = variables[i]->ctx();
-
-    gradients[i]->autograd_entry_ =
-        nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
-    AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
-    grad_info.outputs.emplace_back(gradients[i]->Detach());
-    grad_info.ctx = gradients[i]->ctx();
+    // Unmarked leaf nodes have null autograd_entry_, while marked nonleaf nodes don't.
+    if (!variables[i]->autograd_entry_.node || variables[i]->autograd_entry_.node->is_variable()) {
+      std::string str_c(std::to_string(variable_count_++));
+      variables[i]->autograd_entry_ =
+          nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
+      AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
+      info.outputs.emplace_back(variables[i]->Detach());
+      info.out_grads.emplace_back(gradients[i]->Detach());
+      info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
+      info.ctx = variables[i]->ctx();
+
+      gradients[i]->autograd_entry_ =
+          nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
+      AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
+      grad_info.outputs.emplace_back(gradients[i]->Detach());
+      grad_info.ctx = gradients[i]->ctx();
+    } else {
+      AGInfo& info = AGInfo::Get(variables[i]->autograd_entry_.node);
+      CHECK_EQ(info.out_grads.size(), 0)
+          <<"The node has already been marked. Cannot mark it again.";
+      info.out_grads.emplace_back(gradients[i]->Detach());
+      info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
+      info.ctx = variables[i]->ctx();
+    }
+  }
+}
+
+// Unmark the variables to free the memory.
+void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
+  for (auto variable : variables) {
+    if (variable->autograd_entry_.node) {
+      AGInfo& info = AGInfo::Get(variable->autograd_entry_.node);
+      CHECK_NE(info.out_grads.size(), 0)
+          <<"The node has empty out_grads already. Cannot DropGrads again.";
+      for (auto grad : info.out_grads) {
+        grad.ReInit();
+      }
+      info.out_grads.clear();
+      info.grad_req = kNullOp;
+    }
   }
 }
 
 void Imperative::GetBackwardDependency(const nnvm::ObjectPtr& node,
                                        uint32_t num_inputs,
                                        uint32_t num_outputs,
-                                       std::vector<bool>* p_save_inputs,
-                                       std::vector<bool>* p_save_outputs) {
+                                       std::vector<bool> *p_save_inputs,
+                                       std::vector<bool> *p_save_outputs) {
   static auto& fgradient = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
   std::vector<bool>& save_inputs = *p_save_inputs;
   std::vector<bool>& save_outputs = *p_save_outputs;
@@ -488,6 +513,12 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
     }
     CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients.";
   }
+  std::vector<ObjectPtr> nleaf_vars = ListNonleafVariables(sym);
+  std::vector<NodeEntry> us;
+  us.reserve(nleaf_vars.size());
+  for (const auto& i : nleaf_vars) {
+    us.emplace_back(NodeEntry{i, 0, 0});
+  }
 
   Graph g_graph = pass::MXGradient(graph,
                                    graph.outputs,
@@ -496,7 +527,10 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
                                    mxnet::AggregateGradient,
                                    nullptr,
                                    zero_ops,
-                                   "_copy");
+                                   "_copy",
+                                   ShapeVector(),
+                                   DTypeVector(),
+                                   us);
   CHECK_EQ(g_graph.outputs.size(), xs.size());
   for (const auto& e : g_graph.outputs) {
     if (e.node->op() == nullptr) {
@@ -575,6 +609,20 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
     arrays[eid] = x_grads[i - num_forward_outputs];
     ref_count[eid] = 1;
   }
+  const std::vector<NodeEntry>& us_grads =
+      g_graph.GetAttr<std::vector<NodeEntry>>("nleaf_grads");
+  CHECK_EQ(us_grads.size(), us.size())
+      << "Size of queried nleaf_vars and size of their gradients don't match.";
+  for (size_t i = 0; i < us_grads.size(); i++) {
+    size_t eid = idx.entry_id(us_grads[i]);
+    AGInfo& info = AGInfo::Get(us[i].node);
+    if (arrays[eid]->dtype_ == -1) {
+      arrays[eid] = &info.out_grads[0];
+    } else {
+      info.out_grads[0] = *arrays[eid];
+    }
+    ref_count[eid] = 1;
+  }
 
   // Assign context
   auto vctx = PlaceDevice(idx);
@@ -627,6 +675,11 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
     size_t eid = idx.entry_id(idx.outputs()[i]);
     array_reqs[eid] = x_reqs[i - num_forward_outputs];
   }
+  for (size_t i = 0; i < us_grads.size(); i++) {
+    size_t eid = idx.entry_id(us_grads[i]);
+    AGInfo& info = AGInfo::Get(us[i].node);
+    array_reqs[eid] = info.grad_req;
+  }
 
   const auto& shapes = graph.GetAttr<mxnet::ShapeVector>("shape");
   const auto& dtypes = graph.GetAttr<DTypeVector>("dtype");
@@ -766,4 +819,16 @@ void Imperative::DCInfo::Compute(const NDArray& arr) {
   info.outputs_.clear();
 }
 
+std::vector<nnvm::ObjectPtr> Imperative::ListNonleafVariables(const nnvm::Symbol& sym) const {
+  using namespace nnvm;
+  std::vector<ObjectPtr> ret;
+  DFSVisit(sym.outputs, [&ret](const ObjectPtr& node) {
+    AGInfo& info = AGInfo::Get(node);
+    if (info.out_grads.size() > 0 && !node->is_variable()) {
+      ret.push_back(node);
+    }
+  });
+  return ret;
+}
+
 } // namespace mxnet
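Taken together, the MarkVariables and DropGrads changes above mean attach_grad() may now be called on an array that is already part of a recorded graph, and the retained buffer can later be released. A small sketch of the resulting semantics, assuming the same x and y setup as in the new tests:

    with mx.autograd.record():
        u = x * y          # u already has an autograd entry, i.e. it is a nonleaf node
    u.attach_grad()        # takes the new else branch: reuse the node, only store an out_grad
    # u.attach_grad()      # a second mark would trip the "already been marked" check
    u.drop_grad()          # ReInit()s and clears out_grads and resets grad_req to kNullOp,
                           # after which the node could be marked again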

src/nnvm/gradient.cc

Lines changed: 25 additions & 5 deletions
@@ -62,7 +62,8 @@ Graph BuildGradientGraph(const Graph& src,
                          const std::vector<ObjectPtr>& topo_order,
                          std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
                          std::function<int(const Node&)> mirror_fun,
-                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map);
+                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
+                         const std::vector<NodeEntry>& us = std::vector<NodeEntry>());
 
 /*!
  * \brief Auxiliary function that maps the forward node of the source graph to
@@ -88,6 +89,8 @@ Graph Gradient(Graph src) {
   const std::vector<NodeEntry>& ys_out_grad =
       src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
   CHECK_EQ(ys.size(), ys_out_grad.size());
+  const std::vector<NodeEntry>& us =
+      src.GetAttr<std::vector<NodeEntry> >("grad_us");
 
   // initialize a topological order of the graph nodes and `output_grads`
   // that maps every operator node to its gradient entries
@@ -120,7 +123,7 @@ Graph Gradient(Graph src) {
   std::unordered_map<const Node*, ObjectPtr> mirror_map;
 
   // complete the backward graph of the src, but without backward mirroring
-  nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map);
+  nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map, us);
   if (mirror_fun == nullptr) {
     return gsrc;  // Gradient pass without mirroring ends here.
   }
@@ -504,12 +507,14 @@ inline bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
   return true;
 }
 
+
 Graph BuildGradientGraph(const Graph& src,
                          const std::vector<NodeEntry>& xs,
                          const std::vector<ObjectPtr>& topo_order,
                          std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
                          std::function<int(const Node&)> mirror_fun,
-                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map) {
+                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
+                         const std::vector<NodeEntry>& us) {
   static auto& grad_fun_map = Op::GetAttr<nnvm::FGradient>("FGradient");
 
   // gradient aggregation function
@@ -617,7 +622,7 @@ Graph BuildGradientGraph(const Graph& src,
     CHECK(src_fwd_node->inputs.size() <= input_grads.size());
     for (auto input_iter = src_fwd_node->inputs.begin(); input_iter != src_fwd_node->inputs.end();
          ++input_iter, ++input_grad_iter) {
-      // propagate the input gradients to the output gradients of the input nodes
+      // propagate the input_grads to the corresponding GradEntries mapped by output_grads
       output_grads[input_iter->node.get()][input_iter->index].grads.emplace_back(
           std::move(*input_grad_iter));
     }
@@ -661,6 +666,20 @@ Graph BuildGradientGraph(const Graph& src,
       ret.outputs[kv.second.second] = kv.first;
     }
   }
+
+  // Take the us' grad NodeEntry and store them in graph.attrs
+  std::vector<NodeEntry> nleaf_grads;
+  nleaf_grads.reserve(us.size());
+  for (const NodeEntry& e : us) {
+    GradEntry& entry = output_grads[e.node.get()][e.index];
+    // aggregate sum if it hasn't been
+    if (entry.sum.node.get() == nullptr) {
+      entry.sum = agg_fun(std::move(entry.grads));
+    }
+    nleaf_grads.push_back(entry.sum);
+  }
+  ret.attrs["nleaf_grads"] = std::make_shared<any>(std::move(nleaf_grads));
+
   return ret;
 }
 
@@ -673,7 +692,8 @@ NNVM_REGISTER_PASS(MXGradient)
     .depend_graph_attr("grad_xs")
     .depend_graph_attr("in_arg_shapes")
     .depend_graph_attr("in_arg_dtypes")
-    .depend_graph_attr("grad_ys_out_grad");
+    .depend_graph_attr("grad_ys_out_grad")
+    .depend_graph_attr("grad_us");
 
 } // namespace

tests/python/unittest/test_autograd.py

Lines changed: 66 additions & 1 deletion
@@ -243,7 +243,7 @@ def test_detach_updated_grad():
     assert x._fresh_grad == False
 
 
-def test_retain_grad():
+def test_retain_graph():
     x = mx.nd.ones((2, 2))
     dx = mx.nd.zeros((2, 2))
     mark_variables([x], [dx], grad_reqs='add')
@@ -519,3 +519,68 @@ def test_gradient():
     dx.backward()
     assert abs(x.grad.asscalar() - 2.71828175) < 1e-7
 
+def test_retain_grad_drop_grad():
+    x = nd.array([1,2,3,4])
+    x.attach_grad()
+    y = nd.array([5,6,7,8])
+    y.attach_grad()
+
+    with mx.autograd.record():
+        u = x * y
+        z = u * x
+
+    u.attach_grad()
+    z.attach_grad()
+    out_grad = nd.array([10, 10, 10, 10])
+    z.backward(out_grad, retain_graph=True)
+
+    assert (u.grad == out_grad * x).asnumpy().all()
+    assert (z.grad == out_grad).asnumpy().all()
+    assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
+    assert (y.grad == out_grad * x*x).asnumpy().all()
+
+    u.drop_grad()
+    z.drop_grad()
+    y.drop_grad()
+    out_grad = nd.array([0.1, 0.1, 0.1, 0.1])
+    z.backward(out_grad)
+
+    assert u.grad is None and z.grad is None and y.grad is None
+    assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
+
+def test_retain_grad_drop_grad_gluon():
+    class CompBlock(mx.gluon.HybridBlock):
+        def __init__(self):
+            super().__init__()
+            self.marked_var = None
+        def forward(self, a, b):
+            out1 = a*b
+            out2 = out1 * a
+            self.marked_var = out1
+            return out2
+    x = mx.np.array([1,2,3,4])
+    y = mx.np.array([5,6,7,8])
+    x.attach_grad()
+    y.attach_grad()
+    block2 = CompBlock()
+    block2.initialize()
+    # block2.hybridize()
+    with mx.autograd.record():
+        z = block2(x, y)
+    u = block2.marked_var
+    u.attach_grad()
+    z.attach_grad()
+    z.backward(retain_graph=True)
+
+    assert (u.grad == x).all()
+    assert (z.grad == mx.np.array([1,1,1,1])).all()
+    assert (x.grad == 2 * x * y).all()
+    assert (y.grad == x*x).all()
+
+    u.drop_grad()
+    z.drop_grad()
+    y.drop_grad()
+    z.backward()
+
+    assert u.grad is None and z.grad is None and y.grad is None
+    assert (x.grad == 2 * x * y).all()

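The two new tests can be selected and run with pytest, for example (a usage sketch; any equivalent pytest invocation works):

    import pytest
    pytest.main(["tests/python/unittest/test_autograd.py", "-k", "drop_grad", "-v"])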