tesseract 4.1.1
intsimdmatrixavx2.cpp
// File: intsimdmatrixavx2.cpp
// Description: matrix-vector product for 8-bit data on avx2.
// Author: Ray Smith
// Created: Fri Aug 04 13:26:20 PST 2017
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if !defined(__AVX2__)
#error Implementation only for AVX2 capable architectures
#endif

#include "intsimdmatrix.h"

#include <immintrin.h>
#include <cstdint>
#include <algorithm>
#include <vector>

namespace tesseract {

// Number of outputs held in each register. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
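// (With the values above, kNumInputGroups = 32 / 4 = 8: each 32-byte input
// register is consumed as 8 consecutive groups of 4 inputs, and each group is
// broadcast across the 8 lanes of a result register.)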

// Functions to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.
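// As an illustration of that order: for N = 8 outputs and num_in = 8,
// consecutive reads of w yield
//   w[o0][i0..i3], w[o1][i0..i3], ..., w[o7][i0..i3],   // first input group
//   w[o0][i4..i7], w[o1][i4..i7], ..., w[o7][i4..i7],   // second input group
//   b[o0], ..., b[o7]                                   // the N bias weights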

// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
// Note that wi must previously have been re-organized with blocks of 4x8
// weights in contiguous memory.
// ones is a register of 16x16-bit values all equal to 1.
// Note: wi is incremented by the amount of data read.
// weights and reps are scratch registers.
// This function must be inlined with references in order for the compiler to
// correctly use the registers declared in the caller.
static inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
                                 const int8_t*& wi, __m256i& weights,
                                 __m256i& reps, __m256i& result) {
  // Load a 4x8 block of weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input, weights, so weights is always +ve.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);
  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
  // with adjacent pairs added.
  weights = _mm256_maddubs_epi16(weights, reps);
  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
  // with adjacent pairs added. What we really want is a horizontal add of
  // 16+16=32 bit result, but there is no such instruction, so multiply by
  // 16-bit ones instead. It is probably faster than all the sign-extending,
  // permuting and adding that would otherwise be required.
  weights = _mm256_madd_epi16(weights, ones);
  result = _mm256_add_epi32(result, weights);
}
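
// Why the sign normalization above is needed: _mm256_maddubs_epi16 treats its
// first operand as unsigned bytes and its second as signed bytes. Moving the
// weight signs onto the replicated inputs makes the first operand |w| >= 0,
// so each 8-bit product |w| * sign_adjusted(u) equals the desired w * u.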

// Extracts and converts 8x32-bit results from result, adding the bias from wi
// and scaling by scales, before storing in *v. Note that wi, scales and v are
// expected to contain 8 consecutive elements or num_out if less.
static inline void ExtractResults(__m256i& result, __m256i& shift_id,
                                  const int8_t*& wi, const double*& scales,
                                  int num_out, double*& v) {
  for (int out = 0; out < num_out; ++out) {
#ifndef _MSC_VER
    auto res = _mm256_extract_epi32(result, 0);
#else
    // Workaround for an MSVC internal compiler error (ICE):
    // _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
    auto res = ((int32_t*)&result)[0];
#endif
    *v++ = (static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
    // Rotate the results in int32_t units, so the next result is ready.
    result = _mm256_permutevar8x32_epi32(result, shift_id);
  }
}
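
// Each output is thus v[out] = (res / 127.0 + bias) * scale, where res is the
// 32-bit integer accumulator and bias is the int8_t bias weight read from wi;
// the division by INT8_MAX rescales the integer dot product (presumably to
// undo the 8-bit quantization of the inputs).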

// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provide (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
                                     const int8_t* u, int num_in, int num_out,
                                     double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones =
      _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input =
          _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v);
  num_out -= kNumOutputsPerRegister * 7;
  ExtractResults(result7, shift_id, wi, scales,
                 std::min(kNumOutputsPerRegister, num_out), v);
}
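
// N = 64 corresponds to the full register budget: kMaxOutputRegisters = 8
// accumulators of kNumOutputsPerRegister = 8 outputs each. The narrower
// variants below are identical in structure, using 4, 2 and 1 accumulators.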

// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
                                     const int8_t* u, int num_in, int num_out,
                                     double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones =
      _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input =
          _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
  num_out -= kNumOutputsPerRegister * 3;
  ExtractResults(result3, shift_id, wi, scales,
                 std::min(kNumOutputsPerRegister, num_out), v);
}

// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
                                     const int8_t* u, int num_in, int num_out,
                                     double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones =
      _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input =
          _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
  num_out -= kNumOutputsPerRegister;
  ExtractResults(result1, shift_id, wi, scales,
                 std::min(kNumOutputsPerRegister, num_out), v);
}

// Computes part of matrix.vector v = Wu. Computes N=8 results.
// For details see PartialMatrixDotVector64 with N=8.
static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
                                    const int8_t* u, int num_in, int num_out,
                                    double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones =
      _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0.
  __m256i result0 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input =
          _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
    }
  }
  ExtractResults(result0, shift_id, wi, scales, num_out, v);
}

static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
                            const double* scales, const int8_t* u, double* v) {
  const int num_out = dim1;
  const int num_in = dim2 - 1;
  // Each call to a partial_func_ produces group_size outputs, except the
  // last one, which can produce less.
  const int rounded_num_in =
      IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
  const int rounded_num_out =
      IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
  int output = 0;

  int w_step = (rounded_num_in + 1) * group_size;

  // Run with this group size, until it would produce too much output, then
  // switch to a smaller size.
  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
  group_size /= 2;
  w_step /= 2;

  for (; output + group_size <= rounded_num_out; output += group_size) {
    PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
    wi += w_step;
    scales += group_size;
    v += group_size;
  }
}
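
// Example: num_out = 100 rounds up to 104, so this runs one 64-wide call
// (outputs 0-63), one 32-wide call (outputs 64-95), skips the 16-wide size,
// and finishes with one 8-wide call that writes the 4 remaining valid outputs.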

const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
    // Function.
    matrixDotVector,
    // Number of 32 bit outputs held in each register.
    kNumOutputsPerRegister,
    // Maximum number of registers that we will use to hold outputs.
    kMaxOutputRegisters,
    // Number of 8 bit inputs in the inputs register.
    kNumInputsPerRegister,
    // Number of inputs in each weight group.
    kNumInputsPerGroup
};

} // namespace tesseract.