]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libsat/satadd.c
libsat, forp: fix va_list hack on amd64
[plan9front.git] / sys / src / libsat / satadd.c
1 #include <u.h>
2 #include <libc.h>
3 #include <sat.h>
4 #include "impl.h"
5
6 static SATBlock *
7 newblock(SATSolve *s, int learned)
8 {
9         SATBlock *b;
10         
11         b = calloc(1, SATBLOCKSZ);
12         if(b == nil)
13                 saterror(s, "malloc: %r");
14         b->prev = s->bl[learned].prev;
15         b->next = &s->bl[learned];
16         b->next->prev = b;
17         b->prev->next = b;
18         b->end = (void*) b->data;
19         return b;
20 }
21
22 SATClause *
23 satnewclause(SATSolve *s, int n, int learned)
24 {
25         SATBlock *b;
26         SATClause *c;
27         int f, sz;
28         
29         sz = sizeof(SATClause) + (n - 1) * sizeof(int);
30         assert(sz <= SATBLOCKSZ);
31         if(learned)
32                 b = s->lastbl;
33         else
34                 b = s->bl[0].prev;
35         for(;;){
36                 f = (uchar*)b + SATBLOCKSZ - (uchar*)b->end;
37                 if(f >= sz) break;
38                 b = b->next;
39                 if(b == &s->bl[learned])
40                         b = newblock(s, learned);
41         }
42         c = b->end;
43         memset(c, 0, sizeof(SATClause));
44         b->end = (void *)((uintptr)b->end + sz + CLAUSEALIGN - 1 & -CLAUSEALIGN);
45         b->last = c;
46         if(learned){
47                 if(s->lastp[1] == &s->learncl)
48                         *s->lastp[0] = c;
49                 s->lastbl = b;
50         }else
51                 c->next = s->learncl;
52         *s->lastp[learned] = c;
53         s->lastp[learned] = &c->next;
54         s->ncl++;
55         return c;
56 }
57
58 /* this is currently only used to subsume clauses, i.e. n is guaranteed to be less than the last n */
59 SATClause *
60 satreplclause(SATSolve *s, int n)
61 {
62         SATBlock *b;
63         SATClause *c, **wp;
64         int f, sz, i, l;
65         
66         assert(s->lastbl != nil && s->lastbl->last != nil);
67         b = s->lastbl;
68         c = b->last;
69         f = (uchar*)b + SATBLOCKSZ - (uchar*)c;
70         sz = sizeof(SATClause) + (n - 1) * sizeof(int);
71         assert(f >= sz);
72         b->end = (void *)((uintptr)c + sz + CLAUSEALIGN - 1 & -CLAUSEALIGN);
73         for(i = 0; i < 2; i++){
74                 l = c->l[i];
75                 for(wp = &s->lit[l].watch; *wp != nil && *wp != c; wp = &(*wp)->watch[(*wp)->l[1] == l])
76                         ;
77                 assert(*wp != nil);
78                 *wp = c->watch[i];
79         }
80         memset(c, 0, sizeof(SATClause));
81         return c;
82 }
83
84 static int
85 litconv(SATSolve *s, int l)
86 {
87         int v, m, n;
88         SATVar *vp;
89         SATLit *lp;
90         
91         m = l >> 31;
92         v = (l + m ^ m) - 1;
93         if(v >= s->nvaralloc){
94                 n = -(-(v+1) & -SATVARALLOC);
95                 s->var = vp = satrealloc(s, s->var, n * sizeof(SATVar));
96                 s->lit = lp = satrealloc(s, s->lit, 2 * n * sizeof(SATLit));
97                 memset(vp += s->nvaralloc, 0, (n - s->nvaralloc) * sizeof(SATVar));
98                 memset(lp += 2*s->nvaralloc, 0, 2 * (n - s->nvaralloc) * sizeof(SATLit));
99                 for(; vp < s->var + n; vp++){
100                         vp->lvl = -1;
101                         vp->flags = VARPHASE;
102                 }
103                 for(; lp < s->lit + 2 * n; lp++)
104                         lp->val = -1;
105                 s->nvaralloc = n;
106         }
107         if(v >= s->nvar)
108                 s->nvar = v + 1;
109         return v << 1 | m & 1;
110 }
111
112 static void
113 addbimp(SATSolve *s, int l0, int l1)
114 {
115         SATLit *lp;
116         
117         lp = &s->lit[NOT(l0)];
118         lp->bimp = satrealloc(s, lp->bimp, (lp->nbimp + 1) * sizeof(int));
119         lp->bimp[lp->nbimp++] = l1;
120 }
121
122 static SATSolve *
123 satadd1special(SATSolve *s, int *a, int n)
124 {
125         int i, l0, l1;
126         
127         if(n == 0){
128                 s->unsat = 1;
129                 return s;
130         }
131         l0 = a[0];
132         l1 = 0;
133         for(i = 1; i < n; i++)
134                 if(a[i] != l0){
135                         l1 = a[i];
136                         break;
137                 }
138         if(l1 == 0){
139                 l0 = litconv(s, l0);
140                 assert(s->lvl == 0);
141                 switch(s->lit[l0].val){
142                 case 0:
143                         s->unsat = 1;
144                         return s;
145                 case -1:
146                         s->trail = satrealloc(s, s->trail, sizeof(int) * s->nvar);
147                         memmove(&s->trail[1], s->trail, sizeof(int) * s->ntrail);
148                         s->trail[0] = l0;
149                         s->ntrail++;
150                         s->var[VAR(l0)].flags |= VARUSER;
151                         s->var[VAR(l0)].lvl = 0;
152                         s->lit[l0].val = 1;
153                         s->lit[NOT(l0)].val = 0;
154                 }
155                 return s;
156         }
157         if(l0 + l1 == 0) return s;
158         l0 = litconv(s, l0);
159         l1 = litconv(s, l1);
160         addbimp(s, l0, l1);
161         addbimp(s, l1, l0);
162         return s;
163 }
164
165 SATSolve *
166 satadd1(SATSolve *s, int *a, int n)
167 {
168         SATClause *c;
169         int i, j, l, u;
170         SATVar *v;
171
172         if(s == nil){
173                 s = satnew();
174                 if(s == nil)
175                         saterror(nil, "satnew: %r");
176         }
177         if(n < 0)
178                 for(n = 0; a[n] != 0; n++)
179                         ;
180         for(i = 0; i < n; i++)
181                 if(a[i] == 0)
182                         saterror(s, "satadd1(%p, %p, %d): a[%d]==0, callerpc=%p", s, a, n, i, getcallerpc(&s));
183         satbackjump(s, 0);
184         if(n <= 2)
185                 return satadd1special(s, a, n);
186         /* use stamps to detect repeated literals and tautological clauses */
187         if(s->stamp >= (uint)-6){
188                 for(i = 0; i < s->nvar; i++)
189                         s->var[i].stamp = 0;
190                 s->stamp = 1;
191         }else
192                 s->stamp += 3;
193         u = 0;
194         for(i = 0; i < n; i++){
195                 l = litconv(s, a[i]);
196                 v = &s->var[VAR(l)];
197                 if(v->stamp < s->stamp) u++;
198                 if(v->stamp == s->stamp + (~l & 1))
199                         return s; /* tautological */
200                 v->stamp = s->stamp + (l & 1);
201         }
202         if(u <= 2)
203                 return satadd1special(s, a, n);
204         s->stamp += 3;
205         c = satnewclause(s, u, 0);
206         c->n = u;
207         for(i = 0, j = 0; i < n; i++){
208                 l = litconv(s, a[i]);
209                 v = &s->var[VAR(l)];
210                 if(v->stamp < s->stamp){
211                         c->l[j++] = l;
212                         v->stamp = s->stamp;
213                 }
214         }
215         assert(j == u);
216         s->ncl0++;
217         return s;
218 }
219
220 void
221 satvafix(va_list va)
222 {
223         int *d;
224         uintptr *s;
225
226         if(sizeof(int)==sizeof(uintptr)) return;
227         d = (int *) va;
228         s = (uintptr *) va;
229         do
230                 *d++ = *s;
231         while((int)*s++ != 0);
232                 
233 }
234
235 SATSolve *
236 sataddv(SATSolve *s, ...)
237 {
238         va_list va;
239         
240         va_start(va, s);
241         /* horrible hack */
242         satvafix(va);
243         s = satadd1(s, (int*)va, -1);
244         va_end(va);
245         return s;
246 }