Add approxidate test calls.
[git.git] / gitMergeCommon.py
1 #
2 # Copyright (C) 2005 Fredrik Kuivinen
3 #
4
5 import sys, re, os, traceback
6 from sets import Set
7
8 def die(*args):
9     printList(args, sys.stderr)
10     sys.exit(2)
11
12 def printList(list, file=sys.stdout):
13     for x in list:
14         file.write(str(x))
15         file.write(' ')
16     file.write('\n')
17
18 import subprocess
19
20 # Debugging machinery
21 # -------------------
22
23 DEBUG = 0
24 functionsToDebug = Set()
25
26 def addDebug(func):
27     if type(func) == str:
28         functionsToDebug.add(func)
29     else:
30         functionsToDebug.add(func.func_name)
31
32 def debug(*args):
33     if DEBUG:
34         funcName = traceback.extract_stack()[-2][2]
35         if funcName in functionsToDebug:
36             printList(args)
37
38 # Program execution
39 # -----------------
40
41 class ProgramError(Exception):
42     def __init__(self, progStr, error):
43         self.progStr = progStr
44         self.error = error
45
46     def __str__(self):
47         return self.progStr + ': ' + self.error
48
49 addDebug('runProgram')
50 def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
51     debug('runProgram prog:', str(prog), 'input:', str(input))
52     if type(prog) is str:
53         progStr = prog
54     else:
55         progStr = ' '.join(prog)
56     
57     try:
58         if pipeOutput:
59             stderr = subprocess.STDOUT
60             stdout = subprocess.PIPE
61         else:
62             stderr = None
63             stdout = None
64         pop = subprocess.Popen(prog,
65                                shell = type(prog) is str,
66                                stderr=stderr,
67                                stdout=stdout,
68                                stdin=subprocess.PIPE,
69                                env=env)
70     except OSError, e:
71         debug('strerror:', e.strerror)
72         raise ProgramError(progStr, e.strerror)
73
74     if input != None:
75         pop.stdin.write(input)
76     pop.stdin.close()
77
78     if pipeOutput:
79         out = pop.stdout.read()
80     else:
81         out = ''
82
83     code = pop.wait()
84     if returnCode:
85         ret = [out, code]
86     else:
87         ret = out
88     if code != 0 and not returnCode:
89         debug('error output:', out)
90         debug('prog:', prog)
91         raise ProgramError(progStr, out)
92 #    debug('output:', out.replace('\0', '\n'))
93     return ret
94
95 # Code for computing common ancestors
96 # -----------------------------------
97
98 currentId = 0
99 def getUniqueId():
100     global currentId
101     currentId += 1
102     return currentId
103
104 # The 'virtual' commit objects have SHAs which are integers
105 shaRE = re.compile('^[0-9a-f]{40}$')
106 def isSha(obj):
107     return (type(obj) is str and bool(shaRE.match(obj))) or \
108            (type(obj) is int and obj >= 1)
109
110 class Commit:
111     def __init__(self, sha, parents, tree=None):
112         self.parents = parents
113         self.firstLineMsg = None
114         self.children = []
115
116         if tree:
117             tree = tree.rstrip()
118             assert(isSha(tree))
119         self._tree = tree
120
121         if not sha:
122             self.sha = getUniqueId()
123             self.virtual = True
124             self.firstLineMsg = 'virtual commit'
125             assert(isSha(tree))
126         else:
127             self.virtual = False
128             self.sha = sha.rstrip()
129         assert(isSha(self.sha))
130
131     def tree(self):
132         self.getInfo()
133         assert(self._tree != None)
134         return self._tree
135
136     def shortInfo(self):
137         self.getInfo()
138         return str(self.sha) + ' ' + self.firstLineMsg
139
140     def __str__(self):
141         return self.shortInfo()
142
143     def getInfo(self):
144         if self.virtual or self.firstLineMsg != None:
145             return
146         else:
147             info = runProgram(['git-cat-file', 'commit', self.sha])
148             info = info.split('\n')
149             msg = False
150             for l in info:
151                 if msg:
152                     self.firstLineMsg = l
153                     break
154                 else:
155                     if l.startswith('tree'):
156                         self._tree = l[5:].rstrip()
157                     elif l == '':
158                         msg = True
159
160 class Graph:
161     def __init__(self):
162         self.commits = []
163         self.shaMap = {}
164
165     def addNode(self, node):
166         assert(isinstance(node, Commit))
167         self.shaMap[node.sha] = node
168         self.commits.append(node)
169         for p in node.parents:
170             p.children.append(node)
171         return node
172
173     def reachableNodes(self, n1, n2):
174         res = {}
175         def traverse(n):
176             res[n] = True
177             for p in n.parents:
178                 traverse(p)
179
180         traverse(n1)
181         traverse(n2)
182         return res
183
184     def fixParents(self, node):
185         for x in range(0, len(node.parents)):
186             node.parents[x] = self.shaMap[node.parents[x]]
187
188 # addDebug('buildGraph')
189 def buildGraph(heads):
190     debug('buildGraph heads:', heads)
191     for h in heads:
192         assert(isSha(h))
193
194     g = Graph()
195
196     out = runProgram(['git-rev-list', '--parents'] + heads)
197     for l in out.split('\n'):
198         if l == '':
199             continue
200         shas = l.split(' ')
201
202         # This is a hack, we temporarily use the 'parents' attribute
203         # to contain a list of SHA1:s. They are later replaced by proper
204         # Commit objects.
205         c = Commit(shas[0], shas[1:])
206
207         g.commits.append(c)
208         g.shaMap[c.sha] = c
209
210     for c in g.commits:
211         g.fixParents(c)
212
213     for c in g.commits:
214         for p in c.parents:
215             p.children.append(c)
216     return g
217
218 # Write the empty tree to the object database and return its SHA1
219 def writeEmptyTree():
220     tmpIndex = os.environ.get('GIT_DIR', '.git') + '/merge-tmp-index'
221     def delTmpIndex():
222         try:
223             os.unlink(tmpIndex)
224         except OSError:
225             pass
226     delTmpIndex()
227     newEnv = os.environ.copy()
228     newEnv['GIT_INDEX_FILE'] = tmpIndex
229     res = runProgram(['git-write-tree'], env=newEnv).rstrip()
230     delTmpIndex()
231     return res
232
233 def addCommonRoot(graph):
234     roots = []
235     for c in graph.commits:
236         if len(c.parents) == 0:
237             roots.append(c)
238
239     superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree())
240     graph.addNode(superRoot)
241     for r in roots:
242         r.parents = [superRoot]
243     superRoot.children = roots
244     return superRoot
245
246 def getCommonAncestors(graph, commit1, commit2):
247     '''Find the common ancestors for commit1 and commit2'''
248     assert(isinstance(commit1, Commit) and isinstance(commit2, Commit))
249
250     def traverse(start, set):
251         stack = [start]
252         while len(stack) > 0:
253             el = stack.pop()
254             set.add(el)
255             for p in el.parents:
256                 if p not in set:
257                     stack.append(p)
258     h1Set = Set()
259     h2Set = Set()
260     traverse(commit1, h1Set)
261     traverse(commit2, h2Set)
262     shared = h1Set.intersection(h2Set)
263
264     if len(shared) == 0:
265         shared = [addCommonRoot(graph)]
266         
267     res = Set()
268
269     for s in shared:
270         if len([c for c in s.children if c in shared]) == 0:
271             res.add(s)
272     return list(res)