sn-bb, sn-bb-merge: Add branch and bound algorithms for searching for sort and merge...
[sort-networks.git] / src / sn-bb.c
1 /**
2  * libsortnetwork - src/sn-bb-merge.c
3  * Copyright (C) 2008-2010  Florian octo Forster
4  *
5  * This program is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License as published by the
7  * Free Software Foundation; only version 2 of the License is applicable.
8  *
9  * This program is distributed in the hope that it will be useful, but
10  * WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License along
15  * with this program; if not, write to the Free Software Foundation, Inc.,
16  * 51 Franklin St, Fifth Floor, Boston, MA  02110-1301 USA
17  *
18  * Authors:
19  *   Florian octo Forster <ff at octo.it>
20  **/
21
22 #ifndef _ISOC99_SOURCE
23 # define _ISOC99_SOURCE
24 #endif
25 #ifndef _POSIX_C_SOURCE
26 # define _POSIX_C_SOURCE 200809L
27 #endif
28 #ifndef _XOPEN_SOURCE
29 # define _XOPEN_SOURCE 700
30 #endif
31
32 #include <stdlib.h>
33 #include <stdio.h>
34 #include <stdint.h>
35 #include <string.h>
36
37 #include <sys/types.h>
38 #include <sys/stat.h>
39 #include <unistd.h>
40 #include <assert.h>
41 #include <limits.h>
42
43 #include <math.h>
44
45 #include <pthread.h>
46
47 #include "sn_network.h"
48 #include "sn_random.h"
49
50 #if !defined(__GNUC__) || !__GNUC__
51 # define __attribute__(x) /**/
52 #endif
53
54 #if BUILD_MERGE
55 static int inputs_left = 0;
56 static int inputs_right = 0;
57 #else
58 static int inputs_num = 0;
59 #endif
60 static int comparators_num = -1;
61 static int max_depth = -1;
62 static int max_stages = INT_MAX;
63
64 static char *initial_input_file = NULL;
65
66 static void exit_usage (const char *name) /* {{{ */
67 {
68   printf ("Usage: %s [options]\n"
69       "\n"
70       "Valid options are:\n"
71 #if BUILD_MERGE
72       "  -i <inputs>[:<inputs>]    Number of inputs (left and right)\n"
73 #else
74       "  -i <inputs>               Number of inputs\n"
75 #endif
76       "  -c <comparators>          Number of comparators\n"
77       "  -s <stages>               Maximum number of stages\n"
78       "  -I <file>                 Initial input file\n"
79       "\n",
80       name);
81   exit (1);
82 } /* }}} void exit_usage */
83
84 int read_options (int argc, char **argv) /* {{{ */
85 {
86   int option;
87
88   while ((option = getopt (argc, argv, "i:c:I:o:p:P:s:t:h")) != -1)
89   {
90     switch (option)
91     {
92 #if BUILD_MERGE
93       case 'i':
94       {
95         int tmp_left;
96         int tmp_right;
97
98         char *tmp_str;
99
100         tmp_str = strchr (optarg, ':');
101         if (tmp_str != NULL)
102         {
103           *tmp_str = 0;
104           tmp_str++;
105           tmp_right = atoi (tmp_str);
106         }
107         else
108         {
109           tmp_right = 0;
110         }
111
112         tmp_left = atoi (optarg);
113
114         if (tmp_left <= 0)
115           exit_usage (argv[0]);
116
117         if (tmp_right <= 0)
118           tmp_right = tmp_left;
119
120         inputs_left = tmp_left;
121         inputs_right = tmp_right;
122
123         break;
124       }
125 #else
126       case 'i':
127       {
128         int tmp;
129         tmp = atoi (optarg);
130         if (tmp > 0)
131           inputs_num = tmp;
132         break;
133       }
134 #endif
135
136       case 'c':
137       {
138         int tmp;
139         tmp = atoi (optarg);
140         if (tmp > 0)
141           comparators_num = tmp;
142         break;
143       }
144
145       case 'I':
146       {
147         if (initial_input_file != NULL)
148           free (initial_input_file);
149         initial_input_file = strdup (optarg);
150         break;
151       }
152
153       case 's':
154       {
155         int tmp;
156         tmp = atoi (optarg);
157         if (tmp > 0)
158           max_stages = tmp;
159         break;
160       }
161
162       case 'h':
163       default:
164         exit_usage (argv[0]);
165     }
166   }
167
168   return (0);
169 } /* }}} int read_options */
170
171 #if BUILD_MERGE
172 static int rate_network (sn_network_t *n) /* {{{ */
173 {
174   int test_pattern[n->inputs_num];
175   int values[n->inputs_num];
176
177   int patterns_failed = 0;
178
179   int zeros_left;
180   int zeros_right;
181
182   int i;
183
184   assert (n->inputs_num == (inputs_left + inputs_right));
185
186   memset (test_pattern, 0, sizeof (test_pattern));
187   for (i = 0; i < inputs_left; i++)
188     test_pattern[i] = 1;
189
190   for (zeros_left = 0; zeros_left <= inputs_left; zeros_left++)
191   {
192     int status;
193     int previous;
194
195     if (zeros_left > 0)
196       test_pattern[zeros_left - 1] = 0;
197
198     for (i = 0; i < inputs_right; i++)
199       test_pattern[inputs_left + i] = 1;
200
201     for (zeros_right = 0; zeros_right <= inputs_right; zeros_right++)
202     {
203       if (zeros_right > 0)
204         test_pattern[inputs_left + zeros_right - 1] = 0;
205
206       /* Copy the current pattern and let the network sort it */
207       memcpy (values, test_pattern, sizeof (values));
208       status = sn_network_sort (n, values);
209       if (status != 0)
210         return (status);
211
212       /* Check if the array is now sorted. */
213       previous = values[0];
214       status = 0;
215       for (i = 1; i < n->inputs_num; i++)
216       {
217         if (previous > values[i])
218         {
219           patterns_failed++;
220           status = -1;
221           break;
222         }
223         previous = values[i];
224       }
225     } /* for (zeros_right) */
226   } /* for (zeros_left) */
227
228   return (patterns_failed);
229 } /* }}} int rate_network */
230 #else
231 static int rate_network (sn_network_t *n) /* {{{ */
232 {
233   int test_pattern[n->inputs_num];
234   int values[n->inputs_num];
235
236   int patterns_sorted = 0;
237   int patterns_failed = 0;
238
239   memset (test_pattern, 0, sizeof (test_pattern));
240   while (42)
241   {
242     int previous;
243     int overflow;
244     int status;
245     int i;
246
247     /* Copy the current pattern and let the network sort it */
248     memcpy (values, test_pattern, sizeof (values));
249     status = sn_network_sort (n, values);
250     if (status != 0)
251       return (status);
252
253     /* Check if the array is now sorted. */
254     previous = values[0];
255     status = 0;
256     for (i = 1; i < n->inputs_num; i++)
257     {
258       if (previous > values[i])
259       {
260         patterns_failed++;
261         status = -1;
262         break;
263       }
264       previous = values[i];
265     }
266
267     if (status == 0)
268       patterns_sorted++;
269
270     /* Generate the next test pattern */
271     overflow = 1;
272     for (i = 0; i < n->inputs_num; i++)
273     {
274       if (test_pattern[i] == 0)
275       {
276         test_pattern[i] = 1;
277         overflow = 0;
278         break;
279       }
280       else
281       {
282         test_pattern[i] = 0;
283         overflow = 1;
284       }
285     }
286
287     /* Break out of the while loop if we tested all possible patterns */
288     if (overflow == 1)
289       break;
290   } /* while (42) */
291
292   /* All tests successfull */
293   return (patterns_failed);
294 } /* }}} int rate_network */
295 #endif
296
297 static _Bool sn_bound (sn_network_t *n, int depth, int rating) /* {{{ */
298 {
299   static int least_failed = INT_MAX;
300
301   int lower_bound;
302
303   assert (depth <= max_depth);
304
305   if (SN_NETWORK_STAGE_NUM (n) > max_stages)
306     return (1);
307
308   /* Minimum number of comparisons requires */
309   lower_bound = (int) (ceil (log ((double) rating) / log (2.0)));
310
311   if (lower_bound > (max_depth - depth))
312     return (1);
313
314   if (least_failed >= rating)
315   {
316     printf ("New optimum: %i\n", rating);
317     sn_network_show (n);
318     printf ("===\n\n");
319     fflush (stdout);
320     least_failed = rating;
321
322     /* FIXME */
323     if (rating == 0)
324       exit (EXIT_SUCCESS);
325   }
326
327   return (0);
328 } /* }}} _Bool sn_bound */
329
330 static int sn_branch (sn_network_t *n, int depth, int rating) /* {{{ */
331 {
332   int left_num;
333   int left_rnd;
334   int i;
335
336   if (depth >= max_depth)
337     return (-1);
338
339   if (rating < 0)
340     rating = rate_network (n);
341
342   left_num = SN_NETWORK_INPUT_NUM (n) - 1;
343   left_rnd = sn_bounded_random (0, left_num - 1);
344
345   for (i = 0; i < left_num; i++)
346   {
347     int left_input;
348     int right_num;
349     int right_rnd;
350     int j;
351
352     left_input = left_rnd + i;
353     if (left_input >= left_num)
354       left_input -= left_num;
355     assert (left_input < left_num);
356     assert (left_input >= 0);
357
358     right_num = SN_NETWORK_INPUT_NUM (n) - (left_input + 1);
359     if (right_num <= 1)
360       right_rnd = 0;
361     else
362       right_rnd = sn_bounded_random (0, right_num - 1);
363
364     for (j = 0; j < right_num; j++)
365     {
366       sn_network_t *n_copy;
367       sn_comparator_t *c;
368       int n_copy_rating;
369       int right_input;
370
371       right_input = (left_input + 1) + right_rnd + j;
372       if (right_input >= SN_NETWORK_INPUT_NUM (n))
373         right_input -= right_num;
374       assert (right_input < SN_NETWORK_INPUT_NUM (n));
375       assert (right_input > left_input);
376
377       n_copy = sn_network_clone (n);
378       c = sn_comparator_create (left_input, right_input);
379
380       sn_network_comparator_add (n_copy, c);
381
382       /* Make sure the new comparator is improving the network */
383       n_copy_rating = rate_network (n_copy);
384       if (n_copy_rating >= rating)
385       {
386         sn_comparator_destroy (c);
387         sn_network_destroy (n_copy);
388         continue;
389       }
390
391       if (!sn_bound (n_copy, depth + 1, n_copy_rating))
392         sn_branch (n_copy, depth + 1, n_copy_rating);
393
394       sn_comparator_destroy (c);
395       sn_network_destroy (n_copy);
396     } /* for (right_input) */
397   } /* for (left_input) */
398
399   return (0);
400 } /* }}} int sn_branch */
401
402 int main (int argc, char **argv) /* {{{ */
403 {
404   sn_network_t *n;
405   int c_num;
406 #if BUILD_MERGE
407   int inputs_num;
408 #endif
409
410   read_options (argc, argv);
411
412 #if BUILD_MERGE
413   if ((inputs_left <= 0) || (inputs_right <= 0) || (comparators_num <= 0))
414     exit_usage (argv[0]);
415   inputs_num = inputs_left + inputs_right;
416 #else
417   if ((inputs_num <= 0) || (comparators_num <= 0))
418     exit_usage (argv[0]);
419 #endif
420
421   if (initial_input_file != NULL)
422   {
423     n = sn_network_read_file (initial_input_file);
424     if (n == NULL)
425     {
426       fprintf (stderr, "Cannot read network from `%s'.\n",
427           initial_input_file);
428       exit (EXIT_FAILURE);
429     }
430
431     if (n->inputs_num != inputs_num)
432     {
433       fprintf (stderr, "Network `%s' has %i inputs, but %i were configured "
434           "on the command line.\n",
435           initial_input_file, n->inputs_num, inputs_num);
436
437       exit (EXIT_FAILURE);
438     }
439
440     if (SN_NETWORK_STAGE_NUM (n) > max_stages)
441     {
442       fprintf (stderr, "The maximum number of stages allowed (%i) is smaller "
443           "than the number of stages of the input file (%i).\n",
444           max_stages, SN_NETWORK_STAGE_NUM (n));
445       exit (EXIT_FAILURE);
446     }
447   }
448   else /* if (initial_input_file == NULL) */
449   {
450     n = sn_network_create (inputs_num);
451   }
452
453   c_num = sn_network_get_comparator_num (n);
454   if (c_num >= comparators_num)
455   {
456     fprintf (stderr, "The initial network has already %i comparators, "
457         "which conflicts with the -c option (%i)!\n",
458         c_num, comparators_num);
459     exit (EXIT_FAILURE);
460   }
461   max_depth = comparators_num - c_num;
462
463   printf ("Current configuration:\n"
464       "  Number of inputs:      %3i\n"
465       "  Number of comparators: %3i\n"
466       "  Search depth:          %3i\n"
467       "=======================\n",
468       inputs_num, comparators_num, max_depth);
469
470   sn_branch (n, /* depth = */ 0, /* rating = */ -1);
471
472   return (0);
473 } /* }}} int main */
474
475 /* vim: set sw=2 sts=2 et fdm=marker : */