Coverage for /builds/hweiske/ase/ase/calculators/socketio.py: 91.88%
394 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-22 11:22 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-22 11:22 +0000
1import os
2import socket
3from contextlib import contextmanager
4from subprocess import PIPE, Popen
6import numpy as np
8import ase.units as units
9from ase.calculators.calculator import (ArgvProfile, Calculator,
10 OldShellProfile,
11 PropertyNotImplementedError,
12 all_changes)
13from ase.calculators.genericfileio import GenericFileIOCalculator
14from ase.parallel import world
15from ase.stress import full_3x3_to_voigt_6_stress
16from ase.utils import IOContext
def actualunixsocketname(name):
    """Return the filesystem path of the UNIX socket called *name*.

    The path is *name* prefixed with ``/tmp/ipi_``.
    """
    return '/tmp/ipi_{}'.format(name)
class SocketClosed(OSError):
    """Signal that a socket read returned no data, i.e. the peer closed."""
class IPIProtocol:
    """Communication using IPI protocol.

    Wraps a connected socket.  Messages consist of a 12-byte,
    space-padded ASCII header (see ``sendmsg``/``recvmsg``) optionally
    followed by raw NumPy buffers (see ``send``/``recv``).  Energies,
    lengths, forces and virials are converted between ASE units and the
    wire representation with ``units.Ha`` and ``units.Bohr``.
    """

    def __init__(self, socket, txt=None):
        """Store the socket and set up logging.

        Parameters:

        socket: connected socket object
            Must provide ``sendall()`` and ``recv()``.
        txt: file object or None
            If given, one line per protocol event is printed to it and
            flushed immediately.  If None, logging is a no-op.
        """
        self.socket = socket

        if txt is None:
            def log(*args):
                pass
        else:
            def log(*args):
                print('Driver:', *args, file=txt)
                txt.flush()
        self.log = log

    def sendmsg(self, msg):
        """Send header string *msg* as ASCII, space-padded to 12 bytes."""
        self.log(' sendmsg', repr(msg))
        # assert msg in self.statements, msg
        msg = msg.encode('ascii').ljust(12)
        self.socket.sendall(msg)

    def _recvall(self, nbytes):
        """Repeatedly read chunks until we have nbytes.

        Normally we get all bytes in one read, but that is not guaranteed.

        Raises SocketClosed if a read returns no data before all
        nbytes have arrived."""
        remaining = nbytes
        chunks = []
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if len(chunk) == 0:
                # (If socket is still open, recv returns at least one byte)
                raise SocketClosed
            chunks.append(chunk)
            remaining -= len(chunk)
        msg = b''.join(chunks)
        assert len(msg) == nbytes and remaining == 0
        return msg

    def recvmsg(self):
        """Receive one 12-byte header and return it as a stripped string.

        Raises SocketClosed (from ``_recvall``) if the peer closed the
        connection."""
        # _recvall(12) either returns exactly 12 bytes or raises
        # SocketClosed, so no separate empty-message check is needed.
        msg = self._recvall(12)
        assert len(msg) == 12, msg
        msg = msg.rstrip().decode('ascii')
        # assert msg in self.responses, msg
        self.log(' recvmsg', repr(msg))
        return msg

    def send(self, a, dtype):
        """Send *a*, converted to an array of *dtype*, as raw bytes."""
        buf = np.asarray(a, dtype).tobytes()
        # self.log('  send {}'.format(np.array(a).ravel().tolist()))
        self.log(f' send {len(buf)} bytes of {dtype}')
        self.socket.sendall(buf)

    def recv(self, shape, dtype):
        """Receive an array of the given *shape* and *dtype*.

        Returns a newly allocated array.  All received values are
        asserted to be finite."""
        a = np.empty(shape, dtype)
        # Ask the allocated array for its size: this is always a plain
        # Python int, whereas itemsize * np.prod(shape) is a float for
        # shape=() and would be rejected by socket.recv().
        nbytes = a.nbytes
        buf = self._recvall(nbytes)
        assert len(buf) == nbytes, (len(buf), nbytes)
        self.log(f' recv {len(buf)} bytes of {dtype}')
        # print(np.frombuffer(buf, dtype=dtype))
        a.flat[:] = np.frombuffer(buf, dtype=dtype)
        # self.log('  recv {}'.format(a.ravel().tolist()))
        assert np.isfinite(a).all()
        return a

    def sendposdata(self, cell, icell, positions):
        """Send a POSDATA message.

        Sends, in this order: the transposed cell divided by Bohr, the
        transposed inverse cell multiplied by Bohr, the number of atoms
        (int32) and the positions divided by Bohr."""
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(' sendposdata')
        self.sendmsg('POSDATA')
        self.send(cell.T / units.Bohr, np.float64)
        self.send(icell.T * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def recvposdata(self):
        """Receive the payload of a POSDATA message.

        Returns (cell, icell, positions) after undoing the transposition
        and Bohr scaling applied by ``sendposdata``."""
        cell = self.recv((3, 3), np.float64).T.copy()
        icell = self.recv((3, 3), np.float64).T.copy()
        natoms = self.recv(1, np.int32)[0]
        positions = self.recv((natoms, 3), np.float64)
        return cell * units.Bohr, icell / units.Bohr, positions * units.Bohr

    def sendrecv_force(self):
        """Send GETFORCE and read the FORCEREADY reply.

        Returns (energy, forces, virial, morebytes), with the numeric
        quantities scaled by ``units.Ha`` and ``units.Bohr``."""
        self.log(' sendrecv_force')
        self.sendmsg('GETFORCE')
        msg = self.recvmsg()
        assert msg == 'FORCEREADY', msg
        e = self.recv(1, np.float64)[0]
        natoms = self.recv(1, np.int32)[0]
        assert natoms >= 0
        forces = self.recv((int(natoms), 3), np.float64)
        virial = self.recv((3, 3), np.float64).T.copy()
        nmorebytes = self.recv(1, np.int32)[0]
        morebytes = self.recv(nmorebytes, np.byte)
        return (e * units.Ha, (units.Ha / units.Bohr) * forces,
                units.Ha * virial, morebytes)

    def sendforce(self, energy, forces, virial, morebytes=None):
        """Send a FORCEREADY message with energy, forces and virial.

        Parameters:

        energy: scalar
        forces: array of shape (natoms, 3)
        virial: array of shape (3, 3)
        morebytes: byte array or None
            Extra payload.  Defaults to a single zero byte."""
        if morebytes is None:
            # Fresh array per call rather than a shared default object.
            morebytes = np.zeros(1, dtype=np.byte)

        assert np.array([energy]).size == 1
        assert forces.shape[1] == 3
        assert virial.shape == (3, 3)

        self.log(' sendforce')
        self.sendmsg('FORCEREADY')  # mind the units
        self.send(np.array([energy / units.Ha]), np.float64)
        natoms = len(forces)
        self.send(np.array([natoms]), np.int32)
        self.send(units.Bohr / units.Ha * forces, np.float64)
        self.send(1.0 / units.Ha * virial.T, np.float64)
        # We prefer to always send at least one byte due to trouble with
        # empty messages.  Reading a closed socket yields 0 bytes
        # and thus can be confused with a 0-length bytestring.
        self.send(np.array([len(morebytes)]), np.int32)
        self.send(morebytes, np.byte)

    def status(self):
        """Send STATUS and return the peer's reply header."""
        self.log(' status')
        self.sendmsg('STATUS')
        msg = self.recvmsg()
        return msg

    def end(self):
        """Send the EXIT message."""
        self.log(' end')
        self.sendmsg('EXIT')

    def recvinit(self):
        """Receive the payload of an INIT message.

        Returns (bead_index, initbytes).  Note that bead_index is
        returned as a length-1 int32 array, not as a scalar."""
        self.log(' recvinit')
        bead_index = self.recv(1, np.int32)
        nbytes = self.recv(1, np.int32)
        initbytes = self.recv(nbytes, np.byte)
        return bead_index, initbytes

    def sendinit(self):
        """Send an INIT message with bead index 0 and one zero byte."""
        # XXX Not sure what this function is supposed to send.
        # It 'works' with QE, but for now we try not to call it.
        self.log(' sendinit')
        self.sendmsg('INIT')
        self.send(0, np.int32)  # 'bead index' always zero for now
        # We send one byte, which is zero, since things may not work
        # with 0 bytes.  Apparently implementations ignore the
        # initialization string anyway.
        self.send(1, np.int32)
        self.send(np.zeros(1), np.byte)  # initialization string

    def calculate(self, positions, cell):
        """Run one full request cycle against the peer.

        Queries the status (sending INIT first if the peer reports
        NEEDINIT), sends the geometry, then fetches the results.

        Returns a dict with keys 'energy', 'forces', 'virial' and
        'morebytes'."""
        self.log('calculate')
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == 'NEEDINIT':
            self.sendinit()
            msg = self.status()
        assert msg == 'READY', msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == 'HAVEDATA', msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e,
                 forces=forces,
                 virial=virial,
                 morebytes=morebytes)
        return r
@contextmanager
def bind_unixsocket(socketfile):
    """Context manager yielding a UNIX socket bound to *socketfile*.

    On exit the socket is closed and the socket file is removed.

    Parameters:

    socketfile: str
        Path of the socket file; must start with ``/tmp/ipi_``.

    Raises OSError if binding fails.  The message contains the original
    error text and the path, and the original error is chained as the
    cause."""
    assert socketfile.startswith('/tmp/ipi_'), socketfile
    serversocket = socket.socket(socket.AF_UNIX)
    try:
        serversocket.bind(socketfile)
    except OSError as err:
        # Do not leak the file descriptor when binding fails.
        serversocket.close()
        raise OSError(f'{err}: {socketfile!r}') from err

    try:
        with serversocket:
            yield serversocket
    finally:
        try:
            os.unlink(socketfile)
        except FileNotFoundError:
            # The file is already gone, so there is nothing left to clean
            # up; do not mask an exception raised inside the with-block.
            pass
@contextmanager
def bind_inetsocket(port):
    """Context manager yielding an INET socket bound to *port*.

    The socket is bound on all interfaces (empty host string) with
    SO_REUSEADDR enabled.  It is closed on exit, including when
    configuring or binding it fails.

    Parameters:

    port: int
        Port number to bind to."""
    # Configure and bind inside the with-block so that the socket is
    # closed even if setsockopt() or bind() raises.
    with socket.socket(socket.AF_INET) as serversocket:
        serversocket.setsockopt(socket.SOL_SOCKET,
                                socket.SO_REUSEADDR, 1)
        serversocket.bind(('', port))
        yield serversocket
class FileIOSocketClientLauncher:
    """Callable that writes a calculator's input and starts its process.

    The started process is expected to connect back as a socket client.
    """

    def __init__(self, calc):
        """Remember the file-I/O calculator *calc* to be launched."""
        self.calc = calc

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Write input files for *atoms* and launch the client process.

        Parameters:

        atoms: Atoms object passed to the calculator's input writer
        properties: passed on to the calculator's input writer
        port: integer or None
            INET port the client should connect to.
        unixsocket: string or None
            UNIX socket name the client should connect to.

        Returns the launched process object.

        Raises TypeError if the calculator has a profile of a kind this
        launcher does not know how to execute."""
        assert self.calc is not None
        cwd = self.calc.directory

        profile = getattr(self.calc, 'profile', None)
        if isinstance(self.calc, GenericFileIOCalculator):
            # New GenericFileIOCalculator:
            self.calc.write_inputfiles(atoms, properties)
            if unixsocket is not None:
                argv = profile.socketio_argv_unix(socket=unixsocket)
            else:
                argv = profile.socketio_argv_inet(port=port)
            return Popen(argv, cwd=cwd, env=os.environ)

        # Old FileIOCalculator:
        self.calc.write_input(atoms, properties=properties,
                              system_changes=all_changes)

        if profile is None:
            cmd = self.calc.command.replace('PREFIX', self.calc.prefix)
            cmd = cmd.format(port=port, unixsocket=unixsocket)
            # Actually start the command.  Without this the method would
            # return None and the server would wait for a client that was
            # never launched.
            return Popen(cmd, shell=True, cwd=cwd)
        if isinstance(profile, OldShellProfile):
            cmd = profile.command.replace("PREFIX", self.calc.prefix)
            return Popen(cmd, shell=True, cwd=cwd)
        if isinstance(profile, ArgvProfile):
            return profile.execute_nonblocking(self.calc)

        # Fail loudly instead of silently returning None.
        raise TypeError('Cannot launch socket client for profile of type '
                        f'{type(profile).__name__}')
class SocketServer(IOContext):
    """Server end of the socket connection.

    Binds a listening socket on construction, accepts a single client
    lazily on the first ``calculate()`` call and then talks to it
    through an ``IPIProtocol`` instance.
    """
    default_port = 31415

    def __init__(self,  # launch_client=None,
                 port=None, unixsocket=None, timeout=None,
                 log=None):
        """Create server and listen for connections.

        The client is not awaited here.  Once calculate() is called,
        the server will block to wait for the client.

        Parameters:

        port: integer or None
            Port on which to listen for INET connections.  Defaults
            to 31415 if neither this nor unixsocket is specified.
        unixsocket: string or None
            Name of the unix socket; the actual filename is derived
            with actualunixsocketname().
        timeout: float or None
            Timeout in seconds, or unlimited by default.
            This parameter is passed to the Python socket object; see
            documentation thereof.
        log: file object or None
            Useful debug messages are written to this.

        Raises ValueError if both port and unixsocket are given."""

        if unixsocket is not None and port is not None:
            raise ValueError('Specify only one of unixsocket and port')
        if unixsocket is None and port is None:
            port = self.default_port

        self.port = port
        self.unixsocket = unixsocket
        self.timeout = timeout
        self._closed = False

        if unixsocket is None:
            conn_name = f'INET port {port}'
            socket_context = bind_inetsocket(port)
        else:
            actualsocket = actualunixsocketname(unixsocket)
            conn_name = f'UNIX-socket {actualsocket}'
            socket_context = bind_unixsocket(actualsocket)

        # closelater() is inherited from IOContext; its return value is
        # used as the bound socket below.
        self.serversocket = self.closelater(socket_context)

        if log:
            print(f'Accepting clients on {conn_name}', file=log)

        self.serversocket.settimeout(timeout)
        self.serversocket.listen(1)

        self.log = log

        # Subprocess handle; assigned from outside when a client process
        # has been launched on behalf of this server.
        self.proc = None

        # Connection state, filled in by _accept().
        self.protocol = None
        self.clientsocket = None
        self.address = None

        # if launch_client is not None:
        #     self.proc = launch_client(port=port, unixsocket=unixsocket)

    def _accept(self):
        """Wait for client and establish connection."""
        # It should perhaps be possible for process to be launched by user
        log = self.log
        if log:
            print('Awaiting client', file=self.log)

        # If we launched the subprocess, the process may crash.
        # We want to detect this, using loop with timeouts, and
        # raise an error rather than blocking forever.
        if self.proc is not None:
            self.serversocket.settimeout(1.0)

        # NOTE(review): when self.proc is None a socket.timeout from
        # accept() is swallowed and the loop simply retries, so a
        # user-supplied timeout does not bound the wait -- confirm this
        # is intended.
        connected = False
        while not connected:
            try:
                self.clientsocket, self.address = self.serversocket.accept()
                self.closelater(self.clientsocket)
            except socket.timeout:
                if self.proc is None:
                    continue
                status = self.proc.poll()
                if status is not None:
                    raise OSError('Subprocess terminated unexpectedly'
                                  f' with status {status}')
            else:
                connected = True

        # Restore the user-requested timeout on both sockets.
        self.serversocket.settimeout(self.timeout)
        self.clientsocket.settimeout(self.timeout)

        if log:
            # For unix sockets, address is b''.
            if self.address == b'':
                source = 'client'
            else:
                source = self.address
            print(f'Accepted connection from {source}', file=log)

        self.protocol = IPIProtocol(self.clientsocket, txt=log)

    def close(self):
        """Close the server and wait for the subprocess, if any.

        Calling close() again after it has completed does nothing.  A
        warning is issued if the subprocess exits with a nonzero
        status."""
        if self._closed:
            return

        super().close()

        if self.log:
            print('Close socket server', file=self.log)
        self._closed = True

        # Proper way to close sockets?
        # And indeed i-pi connections...
        # if self.protocol is not None:
        #     self.protocol.end()  # Send end-of-communication string
        self.protocol = None
        if self.proc is None:
            return

        exitcode = self.proc.wait()
        if exitcode != 0:
            import warnings

            # Quantum Espresso seems to always exit with status 128,
            # even if successful.
            # Should investigate at some point
            warnings.warn(f'Subprocess exited with status {exitcode}')
        # self.log('IPI server closed')

    def calculate(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        protocol = self.protocol
        return protocol.calculate(atoms.positions, atoms.cell)
class SocketClient:
    """Client end of the socket connection.

    Connects to a server, receives geometries, evaluates them with the
    calculator attached to the atoms object and sends back energy,
    forces and virial.
    """

    def __init__(self, host='localhost', port=None,
                 unixsocket=None, timeout=None, log=None, comm=world):
        """Create client and connect to server.

        Parameters:

        host: string
            Hostname of server.  Defaults to localhost
        port: integer or None
            Port to which to connect.  By default 31415.
        unixsocket: string or None
            If specified, use corresponding UNIX socket.
            See documentation of unixsocket for SocketIOCalculator.
        timeout: float or None
            See documentation of timeout for SocketIOCalculator.
        log: file object or None
            Log events to this file
        comm: communicator or None
            MPI communicator object.  Defaults to ase.parallel.world.
            When ASE runs in parallel, only the process with world.rank == 0
            will communicate over the socket.  The received information
            will then be broadcast on the communicator.  The SocketClient
            must be created on all ranks of world, and will see the same
            Atoms objects."""
        # Only rank0 actually does the socket work.
        # The other ranks only need to follow.
        #
        # Note: We actually refrain from assigning all the
        # socket-related things except on master
        self.comm = comm

        if self.comm.rank == 0:
            if unixsocket is not None:
                sock = socket.socket(socket.AF_UNIX)
                address = actualunixsocketname(unixsocket)
            else:
                if port is None:
                    port = SocketServer.default_port
                sock = socket.socket(socket.AF_INET)
                address = (host, port)

            try:
                sock.connect(address)
                sock.settimeout(timeout)
            except BaseException:
                # Do not leak the file descriptor if connecting fails.
                sock.close()
                raise

            self.host = host
            self.port = port
            self.unixsocket = unixsocket

            self.protocol = IPIProtocol(sock, txt=log)
            self.log = self.protocol.log
            self.closed = False

            self.bead_index = 0
            self.bead_initbytes = b''
            # Protocol state reported to the server on STATUS requests.
            self.state = 'READY'

    def close(self):
        """Close the socket.  Does nothing on ranks other than 0."""
        if self.comm.rank != 0:
            # The socket-related attributes only exist on rank 0.
            return
        if not self.closed:
            self.log('Close SocketClient')
            self.closed = True
            self.protocol.socket.close()

    def calculate(self, atoms, use_stress):
        """Broadcast the geometry from rank 0 and evaluate it.

        Returns (energy, forces, virial).  The virial is minus the
        volume times the 3x3 stress, or zeros if use_stress is false."""
        # We should also broadcast the bead index, once we support doing
        # multiple beads.
        self.comm.broadcast(atoms.positions, 0)
        self.comm.broadcast(np.ascontiguousarray(atoms.cell), 0)

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()
        if use_stress:
            stress = atoms.get_stress(voigt=False)
            virial = -atoms.get_volume() * stress
        else:
            virial = np.zeros((3, 3))
        return energy, forces, virial

    def irun(self, atoms, use_stress=None):
        """Return a generator that yields after every calculation.

        If use_stress is None, stress is computed when any of the
        atoms' pbc flags is set."""
        if use_stress is None:
            use_stress = any(atoms.pbc)

        my_irun = self.irun_rank0 if self.comm.rank == 0 else self.irun_rankN
        return my_irun(atoms, use_stress)

    def irun_rankN(self, atoms, use_stress=True):
        """Follower loop for ranks other than 0.

        Waits for the stop flag broadcast by rank 0 before every step
        and returns once that flag is set."""
        stop_criterion = np.zeros(1, bool)
        while True:
            self.comm.broadcast(stop_criterion, 0)
            if stop_criterion[0]:
                return

            self.calculate(atoms, use_stress)
            yield

    def irun_rank0(self, atoms, use_stress=True):
        """Message loop on rank 0; talks to the server over the socket.

        The client is closed when the loop ends, for whatever reason."""
        # For every step we either calculate or quit.  We need to
        # tell other MPI processes (if this is MPI-parallel) whether they
        # should calculate or quit.
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = 'EXIT'

                if msg == 'EXIT':
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == 'STATUS':
                    self.protocol.sendmsg(self.state)
                elif msg == 'POSDATA':
                    assert self.state == 'READY'
                    cell, icell, positions = self.protocol.recvposdata()
                    atoms.cell[:] = cell
                    atoms.positions[:] = positions

                    # User may wish to do something with the atoms object now.
                    # Should we provide option to yield here?
                    #
                    # (In that case we should MPI-synchronize *before*
                    #  whereas now we do it after.)

                    # Send signal for other ranks to proceed with calculation:
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = 'HAVEDATA'
                    yield
                elif msg == 'GETFORCE':
                    # The HAVEDATA state guarantees that energy, forces
                    # and virial were set by a preceding POSDATA step.
                    assert self.state == 'HAVEDATA', self.state
                    self.protocol.sendforce(energy, forces, virial)
                    self.state = 'NEEDINIT'
                elif msg == 'INIT':
                    assert self.state == 'NEEDINIT'
                    bead_index, initbytes = self.protocol.recvinit()
                    self.bead_index = bead_index
                    self.bead_initbytes = initbytes
                    self.state = 'READY'
                else:
                    raise KeyError('Bad message', msg)
        finally:
            self.close()

    def run(self, atoms, use_stress=False):
        """Run the message loop to completion."""
        for _ in self.irun(atoms, use_stress=use_stress):
            pass
class SocketIOCalculator(Calculator, IOContext):
    """ASE calculator that obtains its results from a socket client."""
    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    supported_changes = {'positions', 'cell'}

    def __init__(self, calc=None, port=None,
                 unixsocket=None, timeout=None, log=None, *,
                 launch_client=None, comm=world):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025
            and 65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore host and port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object or None (default)

            logfile for communication over socket.  For debugging or
            the curious.

        launch_client: callable or None

            alternative to calc; cannot be combined with it.  It is
            called as ``launch_client(atoms, properties, port=...,
            unixsocket=...)`` on the first calculation and must return
            the launched process.

        comm: communicator

            passed on when opening the log file.

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> from ase.calculators.socketio import SocketIOCalculator

        >>> with SocketIOCalculator(...) as calc:  # doctest:+SKIP
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        Calculator.__init__(self)

        if calc is not None and launch_client is not None:
            raise ValueError('Cannot pass both calc and launch_client')
        if calc is not None:
            launch_client = FileIOSocketClientLauncher(calc)

        self.launch_client = launch_client
        self.timeout = timeout
        self.server = None

        self.log = self.openfile(file=log, comm=comm)

        # We only hold these so we can pass them on to the server.
        # They may both be None as stored here.
        self._port = port
        self._unixsocket = unixsocket

        # If there is a calculator, we will launch in calculate() because
        # we are responsible for executing the external process, too, and
        # should do so before blocking.  Without a calculator we want to
        # block immediately:
        if self.launch_client is None:
            self.server = self.launch_server()

    def todict(self):
        """Return a minimal dictionary description of this calculator."""
        # if self.calc is not None:
        #     d['calc'] = self.calc.todict()
        return {'type': 'calculator',
                'name': 'socket-driver'}

    def launch_server(self):
        """Create a SocketServer, register it for closing, and return it."""
        server = SocketServer(
            # launch_client=launch_client,
            port=self._port,
            unixsocket=self._unixsocket,
            timeout=self.timeout,
            log=self.log,
        )
        return self.closelater(server)

    def calculate(self, atoms=None, properties=['energy'],
                  system_changes=all_changes):
        """Send *atoms* to the client and store the returned results.

        On the first call the server is created and the client process
        launched, if that has not already happened.

        Raises PropertyNotImplementedError if, after the first
        calculation, anything other than positions or cell changed."""
        unsupported = [change for change in system_changes
                       if change not in self.supported_changes]

        # First time calculate() is called, system_changes will be
        # all_changes.  After that, only positions and cell may change.
        if self.atoms is not None and any(unsupported):
            if len(unsupported) > 1:
                what = unsupported
            else:
                what = unsupported[0]
            raise PropertyNotImplementedError(
                f'Cannot change {what} through IPI protocol.  '
                'Please create new socket calculator.')

        self.atoms = atoms.copy()

        if self.server is None:
            self.server = self.launch_server()
            proc = self.launch_client(atoms, properties,
                                      port=self._port,
                                      unixsocket=self._unixsocket)
            self.server.proc = proc  # XXX nasty hack

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        # Stress is only reported for fully three-dimensional cells with
        # at least one periodic direction.
        if self.atoms.cell.rank == 3 and any(self.atoms.pbc):
            vol = atoms.get_volume()
            results['stress'] = -full_3x3_to_voigt_6_stress(virial) / vol
        self.results.update(results)

    def close(self):
        """Drop the server reference and close registered resources."""
        self.server = None
        super().close()
class PySocketIOClient:
    """Client launcher that runs a Python calculator in a subprocess.

    Calling an instance starts ``python -m ase.calculators.socketio``
    and sends it the pickled connection info, atoms and calculator
    factory on stdin; ``main()`` is the subprocess side.
    """

    def __init__(self, calculator_factory):
        """Store *calculator_factory*.

        It is pickled in ``__call__`` and called without arguments in
        ``main()`` to create the calculator."""
        self._calculator_factory = calculator_factory

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Launch the client subprocess and return its Popen object.

        The properties argument is accepted but not used."""
        import pickle
        import sys

        # We pickle everything first, so we won't need to bother with the
        # process as long as it succeeds.
        transferbytes = pickle.dumps([
            dict(unixsocket=unixsocket, port=port),
            atoms.copy(),
            self._calculator_factory,
        ])

        proc = Popen([sys.executable, '-m', 'ase.calculators.socketio'],
                     stdin=PIPE)

        # Close our end of the pipe even if writing fails.
        with proc.stdin:
            proc.stdin.write(transferbytes)
        return proc

    @staticmethod
    def main():
        """Subprocess entry point: read the job from stdin and run it."""
        import pickle
        import sys

        # NOTE: unpickling executes arbitrary code.  The data read here is
        # expected to come from the parent process via __call__(); never
        # feed this entry point data from an untrusted source.
        socketinfo, atoms, get_calculator = pickle.load(sys.stdin.buffer)
        atoms.calc = get_calculator()
        client = SocketClient(host='localhost',
                              unixsocket=socketinfo.get('unixsocket'),
                              port=socketinfo.get('port'))
        # XXX In principle we could avoid calculating stress until
        # someone requests the stress, could we not?
        # Which would make use_stress boolean unnecessary.
        client.run(atoms, use_stress=True)
# Script entry point: PySocketIOClient.__call__ starts this module with
# ``python -m ase.calculators.socketio``; run the client side in that case.
if __name__ == '__main__':
    PySocketIOClient.main()