X-Git-Url: https://git.octo.it/?a=blobdiff_plain;f=src%2Fsn_stage.c;h=ae61578099057eee7e7a4de7da3057f5c8ea015f;hb=807591e0b3c3f9efe3914da0efa688781b0abab8;hp=d0467a4dc6cd2364c598496a239fd3049cbede25;hpb=8764b3122abba9e60cacb591f16a5e71abb5155f;p=sort-networks.git diff --git a/src/sn_stage.c b/src/sn_stage.c index d0467a4..ae61578 100644 --- a/src/sn_stage.c +++ b/src/sn_stage.c @@ -29,6 +29,7 @@ #include #include #include +#include #include "sn_comparator.h" #include "sn_stage.h" @@ -82,10 +83,17 @@ int sn_stage_comparator_add (sn_stage_t *s, const sn_comparator_t *c) sn_comparator_t *temp; int i; + if ((s == NULL) || (c == NULL)) + return (EINVAL); + + i = sn_stage_comparator_check_conflict (s, c); + if (i != 0) + return (i); + temp = (sn_comparator_t *) realloc (s->comparators, (s->comparators_num + 1) * sizeof (sn_comparator_t)); if (temp == NULL) - return (-1); + return (ENOMEM); s->comparators = temp; temp = NULL; @@ -112,6 +120,9 @@ int sn_stage_comparator_remove (sn_stage_t *s, int c_num) int nmemb = s->comparators_num - (c_num + 1); sn_comparator_t *temp; + if ((s == NULL) || (s->comparators_num <= c_num)) + return (EINVAL); + assert (c_num < s->comparators_num); assert (c_num >= 0); @@ -141,6 +152,7 @@ int sn_stage_comparator_remove (sn_stage_t *s, int c_num) sn_stage_t *sn_stage_clone (const sn_stage_t *s) { sn_stage_t *s_copy; + int i; s_copy = sn_stage_create (s->depth); if (s_copy == NULL) @@ -154,8 +166,13 @@ sn_stage_t *sn_stage_clone (const sn_stage_t *s) return (NULL); } - memcpy (s_copy->comparators, s->comparators, - s->comparators_num * sizeof (sn_comparator_t)); + for (i = 0; i < s->comparators_num; i++) + { + SN_COMP_MIN (s_copy->comparators + i) = SN_COMP_MIN (s->comparators + i); + SN_COMP_MAX (s_copy->comparators + i) = SN_COMP_MAX (s->comparators + i); + SN_COMP_USER_DATA (s_copy->comparators + i) = NULL; + SN_COMP_FREE_FUNC (s_copy->comparators + i) = NULL; + } s_copy->comparators_num = s->comparators_num; return (s_copy); @@ -276,6 +293,9 @@ int sn_stage_invert (sn_stage_t *s) { int i; + if (s == NULL) + return (EINVAL); + for (i = 0; i < s->comparators_num; i++) sn_comparator_invert (s->comparators + i); @@ -286,6 +306,13 @@ int sn_stage_shift (sn_stage_t *s, int sw, int inputs_num) { int i; + if ((s == NULL) || (inputs_num < 2)) + return (EINVAL); + + sw %= inputs_num; + if (sw == 0) + return (0); + for (i = 0; i < s->comparators_num; i++) sn_comparator_shift (s->comparators + i, sw, inputs_num); @@ -296,6 +323,9 @@ int sn_stage_swap (sn_stage_t *s, int con0, int con1) { int i; + if (s == NULL) + return (EINVAL); + for (i = 0; i < s->comparators_num; i++) sn_comparator_swap (s->comparators + i, con0, con1); @@ -307,6 +337,9 @@ int sn_stage_cut_at (sn_stage_t *s, int input, enum sn_network_cut_dir_e dir) int new_position = input; int i; + if ((s == NULL) || (input < 0)) + return (-EINVAL); + for (i = 0; i < s->comparators_num; i++) { sn_comparator_t *c = s->comparators + i; @@ -333,6 +366,47 @@ int sn_stage_cut_at (sn_stage_t *s, int input, enum sn_network_cut_dir_e dir) return (new_position); } /* int sn_stage_cut_at */ +int sn_stage_cut (sn_stage_t *s, int *mask, /* {{{ */ + sn_stage_t **prev) +{ + int i; + + if ((s == NULL) || (mask == NULL) || (prev == NULL)) + return (EINVAL); + + for (i = 0; i < s->comparators_num; i++) + { + sn_comparator_t *c = s->comparators + i; + int left = SN_COMP_LEFT (c); + int right = SN_COMP_RIGHT (c); + + if ((mask[left] == 0) + && (mask[right] == 0)) + continue; + + /* Check if we need to update the cut position */ + if ((mask[left] != mask[right]) + && ((mask[left] > 0) || (mask[right] < 0))) + { + int tmp; + int j; + + tmp = mask[right]; + mask[right] = mask[left]; + mask[left] = tmp; + + for (j = s->depth - 1; j >= 0; j--) + sn_stage_swap (prev[j], + left, right); + } + + sn_stage_comparator_remove (s, i); + i--; + } /* for (i = 0; i < s->comparators_num; i++) */ + + return (0); +} /* }}} int sn_stage_cut */ + int sn_stage_remove_input (sn_stage_t *s, int input) { int i; @@ -431,7 +505,7 @@ int sn_stage_serialize (sn_stage_t *s, #define SNPRINTF_OR_FAIL(...) \ status = snprintf (buffer, buffer_size, __VA_ARGS__); \ - if ((status < 1) || (status >= buffer_size)) \ + if ((status < 1) || (((size_t) status) >= buffer_size)) \ return (-1); \ buffer += status; \ buffer_size -= status;