443 {
445 back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_,
ni_);
446
447
448 NetworkScratch::FloatVec outputerr;
449 outputerr.Init(ns_, scratch);
450
451 NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
452 curr_stateerr.Init(ns_, scratch);
453 curr_sourceerr.Init(na_, scratch);
454 ZeroVector<double>(ns_, curr_stateerr);
455 ZeroVector<double>(na_, curr_sourceerr);
456
457 NetworkScratch::FloatVec gate_errors[
WT_COUNT];
458 for (auto & gate_error : gate_errors) gate_error.Init(ns_, scratch);
459
460
464 stateerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
465 sourceerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
466 for (int t = 0; t < buf_width; ++t) {
467 stateerr[t].Init(ns_, scratch);
468 sourceerr[t].Init(na_, scratch);
469 ZeroVector<double>(ns_, stateerr[t]);
470 ZeroVector<double>(na_, sourceerr[t]);
471 }
472 }
473
474 NetworkScratch::FloatVec sourceerr_temps[
WT_COUNT];
475 for (auto & sourceerr_temp : sourceerr_temps)
476 sourceerr_temp.Init(na_, scratch);
477 int width = input_width_;
478
479 NetworkScratch::GradientStore gate_errors_t[
WT_COUNT];
480 for (auto & w : gate_errors_t) {
481 w.Init(ns_, width, scratch);
482 }
483
484 NetworkScratch::FloatVec softmax_errors;
485 NetworkScratch::GradientStore softmax_errors_t;
486 if (softmax_ != nullptr) {
487 softmax_errors.Init(
no_, scratch);
488 softmax_errors_t.Init(
no_, width, scratch);
489 }
490 double state_clip =
Is2D() ? 9.0 : 4.0;
491#if DEBUG_DETAIL > 1
493 fwd_deltas.Print(10);
494#endif
495 StrideMap::Index dest_index(input_map_);
496 dest_index.InitToLast();
497
498 StrideMap::Index src_index(fwd_deltas.stride_map());
499 src_index.InitToLast();
500 do {
501 int t = dest_index.t();
502 bool at_last_x = dest_index.IsLast(
FD_WIDTH);
503
504
505 int up_pos = -1;
506 int down_pos = -1;
509 StrideMap::Index up_index(dest_index);
510 if (up_index.AddOffset(-1,
FD_HEIGHT)) up_pos = up_index.t();
511 }
513 StrideMap::Index down_index(dest_index);
514 if (down_index.AddOffset(1,
FD_HEIGHT)) down_pos = down_index.t();
515 }
516 }
517
518 int mod_t =
Modulo(t, buf_width);
519
520 if (at_last_x) {
521 ZeroVector<double>(na_, curr_sourceerr);
522 ZeroVector<double>(ns_, curr_stateerr);
523 }
524
527 fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
528 src_index.Decrement();
529 } else {
530 ZeroVector<double>(ns_, outputerr);
531 }
532 } else if (softmax_ == nullptr) {
533 fwd_deltas.ReadTimeStep(t, outputerr);
534 } else {
536 softmax_errors_t.get(), outputerr);
537 }
538 if (!at_last_x)
540 if (down_pos >= 0)
542
543 if (!at_last_x) {
544 const float* next_node_gf1 = node_values_[
GF1].
f(t + 1);
545 for (int i = 0; i < ns_; ++i) {
546 curr_stateerr[i] *= next_node_gf1[i];
547 }
548 }
549 if (
Is2D() && t + 1 < width) {
550 for (int i = 0; i < ns_; ++i) {
551 if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
552 }
553 if (down_pos >= 0) {
554 const float* right_node_gfs = node_values_[
GFS].
f(down_pos);
555 const double* right_stateerr = stateerr[mod_t];
556 for (int i = 0; i < ns_; ++i) {
557 if (which_fg_[down_pos][i] == 2) {
558 curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
559 }
560 }
561 }
562 }
564 curr_stateerr);
565
566 ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
567#if DEBUG_DETAIL > 1
568 if (t + 10 > width) {
570 for (int i = 0; i < ns_; ++i)
571 tprintf(
" %g,%g,%g", curr_stateerr[i], outputerr[i],
572 curr_sourceerr[
ni_ + nf_ + i]);
574 }
575#endif
576
578
579
580 node_values_[
CI].FuncMultiply3<GPrime>(t, node_values_[
GI], t,
581 curr_stateerr, gate_errors[
CI]);
584 gate_errors_t[
CI].get()->WriteStrided(t, gate_errors[
CI]);
585
587
588 node_values_[
GI].FuncMultiply3<FPrime>(t, node_values_[
CI], t,
589 curr_stateerr, gate_errors[
GI]);
592 gate_errors_t[
GI].get()->WriteStrided(t, gate_errors[
GI]);
593
595
596 if (t > 0) {
597 node_values_[
GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
601 sourceerr_temps[
GF1]);
602 } else {
603 memset(gate_errors[
GF1], 0, ns_ *
sizeof(gate_errors[
GF1][0]));
604 memset(sourceerr_temps[
GF1], 0, na_ *
sizeof(*sourceerr_temps[
GF1]));
605 }
606 gate_errors_t[
GF1].get()->WriteStrided(t, gate_errors[
GF1]);
607
608
609 if (up_pos >= 0) {
610 node_values_[
GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
614 sourceerr_temps[
GFS]);
615 } else {
616 memset(gate_errors[
GFS], 0, ns_ *
sizeof(gate_errors[
GFS][0]));
617 memset(sourceerr_temps[
GFS], 0, na_ *
sizeof(*sourceerr_temps[
GFS]));
618 }
619 if (
Is2D()) gate_errors_t[
GFS].get()->WriteStrided(t, gate_errors[
GFS]);
620
622
623 state_.Func2Multiply3<HFunc, FPrime>(node_values_[
GO], t, outputerr,
627 gate_errors_t[
GO].get()->WriteStrided(t, gate_errors[
GO]);
629
631 sourceerr_temps[
GF1], sourceerr_temps[
GO], sourceerr_temps[
GFS],
632 curr_sourceerr);
633 back_deltas->WriteTimeStep(t, curr_sourceerr);
634
636 CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
637 CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
638 }
639 } while (dest_index.Decrement());
640#if DEBUG_DETAIL > 2
641 for (
int w = 0; w <
WT_COUNT; ++w) {
643 gate_errors_t[w].get()->PrintUnTransposed(10);
644 }
645#endif
646
647 NetworkScratch::GradientStore source_t, state_t;
648 source_t.Init(na_, width, scratch);
650 state_t.Init(ns_, width, scratch);
651 state_.Transpose(state_t.get());
652#ifdef _OPENMP
653#pragma omp parallel for num_threads(GFS) if (!Is2D())
654#endif
655 for (
int w = 0; w <
WT_COUNT; ++w) {
656 if (w ==
GFS && !
Is2D())
continue;
658 }
659 if (softmax_ != nullptr) {
661 }
663}
// No-op versions of the OpenMP "parallel sections" helper macros.
//
// Each macro expands to nothing. Code written as
//
//   PARALLEL_IF_OPENMP(n)
//     ...section A...
//   SECTION_IF_OPENMP
//     ...section B...
//   END_PARALLEL_IF_OPENMP
//
// therefore compiles to plain sequential statements. Section A runs before
// section B, both on the calling thread. The thread-count argument is
// discarded without being evaluated, so it must not carry side effects the
// caller relies on.
//
// The parameter is named `num_threads` rather than `__num_threads`: any
// identifier containing a double underscore is reserved for the
// implementation in C++.
// NOTE(review): presumably these are the fallbacks for builds without
// _OPENMP, with OpenMP-enabled definitions elsewhere in the file. Confirm
// the two sets stay in sync.
#define PARALLEL_IF_OPENMP(num_threads)
#define SECTION_IF_OPENMP
#define END_PARALLEL_IF_OPENMP
void SumVectors(int n, const double *v1, const double *v2, const double *v3, const double *v4, const double *v5, double *sum)
void AccumulateVector(int n, const double *src, double *dest)
void CopyVector(int n, const double *src, double *dest)
void ClipVector(int n, T lower, T upper, T *vec)
void init_to_size(int size, const T &t)
const char * string() const
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void FinishBackward(const TransposedArray &errors_t)
void DisplayBackward(const NetworkIO &matrix)
void Transpose(TransposedArray *dest) const
void FuncMultiply3Add(const NetworkIO &v_io, int t, const double *w, double *product) const
int Size(FlexDimensions dimension) const
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
void VectorDotMatrix(const double *u, double *v) const