git-clone: Keep remote names when cloning unless explicitly told not to.
[git.git] / gitMergeCommon.py
1 import sys, re, os, traceback
2 from sets import Set
3
4 def die(*args):
5     printList(args, sys.stderr)
6     sys.exit(2)
7
8 def printList(list, file=sys.stdout):
9     for x in list:
10         file.write(str(x))
11         file.write(' ')
12     file.write('\n')
13
14 import subprocess
15
16 # Debugging machinery
17 # -------------------
18
19 DEBUG = 0
20 functionsToDebug = Set()
21
22 def addDebug(func):
23     if type(func) == str:
24         functionsToDebug.add(func)
25     else:
26         functionsToDebug.add(func.func_name)
27
28 def debug(*args):
29     if DEBUG:
30         funcName = traceback.extract_stack()[-2][2]
31         if funcName in functionsToDebug:
32             printList(args)
33
34 # Program execution
35 # -----------------
36
37 class ProgramError(Exception):
38     def __init__(self, progStr, error):
39         self.progStr = progStr
40         self.error = error
41
42     def __str__(self):
43         return self.progStr + ': ' + self.error
44
45 addDebug('runProgram')
46 def runProgram(prog, input=None, returnCode=False, env=None, pipeOutput=True):
47     debug('runProgram prog:', str(prog), 'input:', str(input))
48     if type(prog) is str:
49         progStr = prog
50     else:
51         progStr = ' '.join(prog)
52     
53     try:
54         if pipeOutput:
55             stderr = subprocess.STDOUT
56             stdout = subprocess.PIPE
57         else:
58             stderr = None
59             stdout = None
60         pop = subprocess.Popen(prog,
61                                shell = type(prog) is str,
62                                stderr=stderr,
63                                stdout=stdout,
64                                stdin=subprocess.PIPE,
65                                env=env)
66     except OSError, e:
67         debug('strerror:', e.strerror)
68         raise ProgramError(progStr, e.strerror)
69
70     if input != None:
71         pop.stdin.write(input)
72     pop.stdin.close()
73
74     if pipeOutput:
75         out = pop.stdout.read()
76     else:
77         out = ''
78
79     code = pop.wait()
80     if returnCode:
81         ret = [out, code]
82     else:
83         ret = out
84     if code != 0 and not returnCode:
85         debug('error output:', out)
86         debug('prog:', prog)
87         raise ProgramError(progStr, out)
88 #    debug('output:', out.replace('\0', '\n'))
89     return ret
90
91 # Code for computing common ancestors
92 # -----------------------------------
93
94 currentId = 0
95 def getUniqueId():
96     global currentId
97     currentId += 1
98     return currentId
99
100 # The 'virtual' commit objects have SHAs which are integers
101 shaRE = re.compile('^[0-9a-f]{40}$')
102 def isSha(obj):
103     return (type(obj) is str and bool(shaRE.match(obj))) or \
104            (type(obj) is int and obj >= 1)
105
106 class Commit:
107     def __init__(self, sha, parents, tree=None):
108         self.parents = parents
109         self.firstLineMsg = None
110         self.children = []
111
112         if tree:
113             tree = tree.rstrip()
114             assert(isSha(tree))
115         self._tree = tree
116
117         if not sha:
118             self.sha = getUniqueId()
119             self.virtual = True
120             self.firstLineMsg = 'virtual commit'
121             assert(isSha(tree))
122         else:
123             self.virtual = False
124             self.sha = sha.rstrip()
125         assert(isSha(self.sha))
126
127     def tree(self):
128         self.getInfo()
129         assert(self._tree != None)
130         return self._tree
131
132     def shortInfo(self):
133         self.getInfo()
134         return str(self.sha) + ' ' + self.firstLineMsg
135
136     def __str__(self):
137         return self.shortInfo()
138
139     def getInfo(self):
140         if self.virtual or self.firstLineMsg != None:
141             return
142         else:
143             info = runProgram(['git-cat-file', 'commit', self.sha])
144             info = info.split('\n')
145             msg = False
146             for l in info:
147                 if msg:
148                     self.firstLineMsg = l
149                     break
150                 else:
151                     if l.startswith('tree'):
152                         self._tree = l[5:].rstrip()
153                     elif l == '':
154                         msg = True
155
156 class Graph:
157     def __init__(self):
158         self.commits = []
159         self.shaMap = {}
160
161     def addNode(self, node):
162         assert(isinstance(node, Commit))
163         self.shaMap[node.sha] = node
164         self.commits.append(node)
165         for p in node.parents:
166             p.children.append(node)
167         return node
168
169     def reachableNodes(self, n1, n2):
170         res = {}
171         def traverse(n):
172             res[n] = True
173             for p in n.parents:
174                 traverse(p)
175
176         traverse(n1)
177         traverse(n2)
178         return res
179
180     def fixParents(self, node):
181         for x in range(0, len(node.parents)):
182             node.parents[x] = self.shaMap[node.parents[x]]
183
184 # addDebug('buildGraph')
185 def buildGraph(heads):
186     debug('buildGraph heads:', heads)
187     for h in heads:
188         assert(isSha(h))
189
190     g = Graph()
191
192     out = runProgram(['git-rev-list', '--parents'] + heads)
193     for l in out.split('\n'):
194         if l == '':
195             continue
196         shas = l.split(' ')
197
198         # This is a hack, we temporarily use the 'parents' attribute
199         # to contain a list of SHA1:s. They are later replaced by proper
200         # Commit objects.
201         c = Commit(shas[0], shas[1:])
202
203         g.commits.append(c)
204         g.shaMap[c.sha] = c
205
206     for c in g.commits:
207         g.fixParents(c)
208
209     for c in g.commits:
210         for p in c.parents:
211             p.children.append(c)
212     return g
213
214 # Write the empty tree to the object database and return its SHA1
215 def writeEmptyTree():
216     tmpIndex = os.environ.get('GIT_DIR', '.git') + '/merge-tmp-index'
217     def delTmpIndex():
218         try:
219             os.unlink(tmpIndex)
220         except OSError:
221             pass
222     delTmpIndex()
223     newEnv = os.environ.copy()
224     newEnv['GIT_INDEX_FILE'] = tmpIndex
225     res = runProgram(['git-write-tree'], env=newEnv).rstrip()
226     delTmpIndex()
227     return res
228
229 def addCommonRoot(graph):
230     roots = []
231     for c in graph.commits:
232         if len(c.parents) == 0:
233             roots.append(c)
234
235     superRoot = Commit(sha=None, parents=[], tree=writeEmptyTree())
236     graph.addNode(superRoot)
237     for r in roots:
238         r.parents = [superRoot]
239     superRoot.children = roots
240     return superRoot
241
242 def getCommonAncestors(graph, commit1, commit2):
243     '''Find the common ancestors for commit1 and commit2'''
244     assert(isinstance(commit1, Commit) and isinstance(commit2, Commit))
245
246     def traverse(start, set):
247         stack = [start]
248         while len(stack) > 0:
249             el = stack.pop()
250             set.add(el)
251             for p in el.parents:
252                 if p not in set:
253                     stack.append(p)
254     h1Set = Set()
255     h2Set = Set()
256     traverse(commit1, h1Set)
257     traverse(commit2, h2Set)
258     shared = h1Set.intersection(h2Set)
259
260     if len(shared) == 0:
261         shared = [addCommonRoot(graph)]
262         
263     res = Set()
264
265     for s in shared:
266         if len([c for c in s.children if c in shared]) == 0:
267             res.add(s)
268     return list(res)