Coverage for /builds/hweiske/ase/ase/calculators/genericfileio.py: 88.16%
152 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-22 11:22 +0000
1from abc import ABC, abstractmethod
2from contextlib import ExitStack
3from os import PathLike
4from pathlib import Path
5from typing import Any, Iterable, List, Mapping, Optional
7from ase.calculators.abc import GetOutputsMixin
8from ase.calculators.calculator import BaseCalculator, EnvironmentError
9from ase.config import cfg as _cfg
class BaseProfile(ABC):
    """Describe how to launch an external calculator program.

    Subclasses provide the calculator-specific part of the command line
    (:meth:`get_calculator_command`) and a way to query the program
    version (:meth:`version`).  When ``parallel`` is true, this base class
    prefixes that command with launcher arguments built from
    ``parallel_info``.
    """

    def __init__(self, parallel=True, parallel_info=None):
        """
        Parameters
        ----------
        parallel : bool
            If the calculator should be run in parallel.
        parallel_info : dict
            Additional settings for parallel execution, e.g. arguments
            for the binary for parallelization (mpiexec, srun, mpirun).
        """
        self.parallel_info = parallel_info or {}
        self.parallel = parallel

    def get_translation_keys(self):
        """
        Get the translation keys for the parallel_info dictionary.

        A translation key is specified in a config file with the syntax
        ``<key>_kwarg_trans = <option>``, e.g. if ``nprocs_kwarg_trans = -np``
        is specified in the config file, then the key ``nprocs`` will be
        translated to ``-np``.  Then ``nprocs`` can be specified in
        parallel_info and will be translated to ``-np`` when the command is
        built.

        Returns
        -------
        dict
            Dictionary with translation keys where the keys are the keys in
            parallel_info that will be translated, the value is what the key
            will be translated into.
        """
        suffix = '_kwarg_trans'
        # endswith() already guarantees len(key) >= len(suffix), so no
        # separate length check is needed before slicing the suffix off.
        return {
            key[:-len(suffix)]: value
            for key, value in self.parallel_info.items()
            if key.endswith(suffix)
        }

    def get_command(self, inputfile, calc_command=None) -> List[str]:
        """
        Get the command to run. This should be a list of strings.

        When ``self.parallel`` is true, the command starts with
        ``parallel_info['binary']`` (if present) followed by one argument
        (or argument/value pair) per remaining ``parallel_info`` entry.
        Boolean entries are emitted as bare flags, and only when True.

        Parameters
        ----------
        inputfile : str
            Input file name passed to :meth:`get_calculator_command`.
        calc_command : list of str, optional
            Calculator command to use instead of
            :meth:`get_calculator_command` (used for sockets).

        Returns
        -------
        list of str
            The command to run.
        """
        command = []
        if self.parallel:
            if 'binary' in self.parallel_info:
                command.append(self.parallel_info['binary'])

            translation_keys = self.get_translation_keys()

            for key, value in self.parallel_info.items():
                # 'binary' was emitted above; translation entries describe
                # how to rename other keys and are not launcher arguments.
                if key == 'binary' or '_kwarg_trans' in key:
                    continue

                command_key = translation_keys.get(key, key)

                if isinstance(value, bool):
                    if value:
                        command.append(f'{command_key}')
                else:
                    command.append(f'{command_key}')
                    command.append(f'{value}')

        if calc_command is None:
            command.extend(self.get_calculator_command(inputfile))
        else:
            command.extend(calc_command)
        return command

    @abstractmethod
    def get_calculator_command(self, inputfile):
        """
        The calculator specific command as a list of strings.

        Parameters
        ----------
        inputfile : str

        Returns
        -------
        list of str
            The command to run.
        """
        ...

    def run(
        self, directory: Path, inputfile: Optional[str],
        outputfile: str, errorfile: Optional[str] = None,
        append: bool = False
    ) -> None:
        """
        Run the command in the given directory.

        Parameters
        ----------
        directory : pathlib.Path
            The directory to run the command in.
        inputfile : Optional[str]
            The name of the input file.
        outputfile : str
            The name of the file (inside ``directory``) receiving stdout.
        errorfile: Optional[str]
            The name of the file (inside ``directory``) receiving stderr.
            If None, stderr is not redirected.
        append: bool
            If True, append to existing output/error files instead of
            truncating them.

        Raises
        ------
        subprocess.CalledProcessError
            If the command exits with a nonzero status.
        """
        import os
        from subprocess import check_call

        argv_command = self.get_command(inputfile)
        mode = 'wb' if not append else 'ab'

        with ExitStack() as stack:
            output_path = directory / outputfile
            fd_out = stack.enter_context(open(output_path, mode))
            if errorfile is not None:
                error_path = directory / errorfile
                fd_err = stack.enter_context(open(error_path, mode))
            else:
                fd_err = None
            check_call(
                argv_command,
                cwd=directory,
                stdout=fd_out,
                stderr=fd_err,
                env=os.environ,
            )

    @abstractmethod
    def version(self):
        """
        Get the version of the code.

        Returns
        -------
        str
            The version of the code.
        """
        ...

    @classmethod
    def from_config(cls, cfg, section_name, parallel_info=None, parallel=True):
        """
        Create a profile from a configuration file.

        Parameters
        ----------
        cfg : ase.config.Config
            The configuration object.
        section_name : str
            The name of the section in the configuration file. E.g. the name
            of the template that this profile is for.
        parallel_info : dict, optional
            Settings that override those of the ``[parallel]`` section.
        parallel : bool
            Passed on to the profile constructor.

        Returns
        -------
        BaseProfile
            The profile object.

        Raises
        ------
        BadConfiguration
            If the section's keys do not match the profile constructor.
        """
        parser = cfg.parser
        # A configuration without a [parallel] section simply has no
        # launcher settings; indexing it directly would raise KeyError.
        if 'parallel' in parser:
            parallel_config = dict(parser['parallel'])
        else:
            parallel_config = {}
        if parallel_info is not None:
            parallel_config.update(parallel_info)

        try:
            return cls(
                **parser[section_name],
                parallel_info=parallel_config,
                parallel=parallel,
            )
        except TypeError as err:
            raise BadConfiguration(*err.args) from err
class BadConfiguration(Exception):
    """Raised when a configuration section cannot be turned into a profile.

    ``BaseProfile.from_config`` raises this in place of the ``TypeError``
    produced when the section's keys do not fit the profile constructor.
    """
def read_stdout(args, createfile=None, encoding='utf-8'):
    """Run command in tempdir and return standard output.

    Helper function for getting version numbers of DFT codes.
    Most DFT codes don't implement a --version flag, so in order to
    determine the code version, we just run the code until it prints
    a version number.

    Parameters
    ----------
    args : list of str
        Command to run, passed to ``subprocess.Popen``.
    createfile : str, optional
        Name of an empty file to create in the temporary working
        directory before the command runs.
    encoding : str
        Encoding used to decode the program's standard output.

    Returns
    -------
    str
        Everything the program wrote to standard output.  Standard error
        is captured and discarded, and the exit status is not checked.
    """
    import tempfile
    from subprocess import PIPE, Popen

    with tempfile.TemporaryDirectory() as directory:
        if createfile is not None:
            path = Path(directory) / createfile
            path.touch()
        # Using Popen as a context manager closes the pipes and reaps the
        # child even if communicate() raises, before the directory is removed.
        with Popen(
            args,
            stdout=PIPE,
            stderr=PIPE,
            stdin=PIPE,
            cwd=directory,
            encoding=encoding,
        ) as proc:
            stdout, _ = proc.communicate()
    # Exit code will be != 0 because there isn't an input file
    return stdout
class CalculatorTemplate(ABC):
    """Abstract recipe for driving one file-based calculator program.

    A template writes the input files, runs the program through a profile
    and returns the results.  ``GenericFileIOCalculator`` pairs a template
    with a profile to form an ASE calculator.
    """

    def __init__(self, name: str, implemented_properties: Iterable[str]):
        # Frozen so the advertised properties cannot change afterwards.
        self.implemented_properties = frozenset(implemented_properties)
        self.name = name

    @abstractmethod
    def write_input(self, profile, directory, atoms, parameters, properties):
        """Write the input file(s) for ``atoms`` into ``directory``."""

    @abstractmethod
    def execute(self, directory, profile):
        """Run the calculator in ``directory`` using ``profile``."""

    @abstractmethod
    def read_results(self, directory: PathLike) -> Mapping[str, Any]:
        """Return the calculation results found in ``directory``."""

    @abstractmethod
    def load_profile(self, cfg, parallel_info=None, parallel=True):
        """Create the profile for this template from configuration ``cfg``."""

    def socketio_calculator(
        self,
        profile,
        parameters,
        directory,
        # We may need quite a few socket kwargs here
        # if we want to expose all the timeout etc. from
        # SocketIOCalculator.
        unixsocket=None,
        port=None,
    ):
        """Build a ``SocketIOCalculator`` that launches this program as client.

        Exactly one of ``unixsocket`` (UNIX socket) or ``port`` (INET
        socket) must be given.  The template must provide
        ``socketio_argv()`` and ``socketio_parameters()``; the launch
        callback also writes the client's stdout to ``self.outputname``
        inside ``directory``.

        Raises
        ------
        TypeError
            If both or neither socket kinds are given, or if the template
            lacks the socket-I/O hooks.
        """
        import os
        from subprocess import Popen

        from ase.calculators.socketio import SocketIOCalculator

        if port and unixsocket:
            raise TypeError(
                'For the socketio_calculator only a UNIX '
                '(unixsocket) or INET (port) socket can be used'
                ' not both.'
            )

        if not (port or unixsocket):
            raise TypeError(
                'For the socketio_calculator either a '
                'UNIX (unixsocket) or INET (port) socket '
                'must be used'
            )

        hooks = ('socketio_argv', 'socketio_parameters')
        if not all(hasattr(self, hook) for hook in hooks):
            raise TypeError(
                f'Template {self} does not implement mandatory '
                'socketio_argv() and socketio_parameters()'
            )

        # XXX need socketio ABC or something
        client_command = self.socketio_argv(profile, unixsocket, port)
        argv = profile.get_command(inputfile=None, calc_command=client_command)
        # Explicit parameters take precedence over the socket defaults.
        merged_parameters = {
            **self.socketio_parameters(unixsocket, port),
            **parameters,
        }

        # The socket arguments SocketIOCalculator passes to this callback
        # are not used: argv above was already built from them.
        def launch_client(atoms, properties, port, unixsocket):
            directory.mkdir(exist_ok=True, parents=True)

            self.write_input(
                atoms=atoms,
                profile=profile,
                parameters=merged_parameters,
                properties=properties,
                directory=directory,
            )

            output_path = directory / self.outputname
            with open(output_path, 'w') as out_fd:
                return Popen(argv, stdout=out_fd, cwd=directory,
                             env=os.environ)

        return SocketIOCalculator(
            launch_client=launch_client, unixsocket=unixsocket, port=port
        )
class GenericFileIOCalculator(BaseCalculator, GetOutputsMixin):
    """ASE calculator driving an external program through files.

    The ``template`` writes inputs, executes the program and reads results;
    the ``profile`` describes how the program is launched.  When ``profile``
    is None, one is loaded from the configuration via
    ``template.load_profile``.
    """

    cfg = _cfg

    def __init__(
        self,
        *,
        template,
        profile,
        directory,
        parameters=None,
        parallel_info=None,
        parallel=True,
    ):
        self.template = template

        if profile is None:
            section = template.name
            if section not in self.cfg.parser:
                raise EnvironmentError(f'No configuration of {section}')
            try:
                profile = template.load_profile(
                    self.cfg, parallel_info=parallel_info, parallel=parallel
                )
            except Exception as err:
                # Report the whole configuration to help diagnose the failure.
                raise EnvironmentError(
                    f'Failed to load section [{section}] '
                    f'from configuration: {self.cfg.as_dict()}'
                ) from err

        self.profile = profile

        # Maybe we should allow directory to be a factory, so
        # calculators e.g. produce new directories on demand.
        self.directory = Path(directory)
        super().__init__(parameters)

    def set(self, *args, **kwargs):
        """Refuse to change parameters; create a new calculator instead."""
        raise RuntimeError(
            'No setting parameters for now, please. '
            'Just create new calculators.'
        )

    def __repr__(self):
        return '{}({})'.format(type(self).__name__, self.template.name)

    @property
    def implemented_properties(self):
        """Properties supported by the underlying template."""
        return self.template.implemented_properties

    @property
    def name(self):
        """Name of the underlying template."""
        return self.template.name

    def write_inputfiles(self, atoms, properties):
        """Create the working directory and write the input files into it.

        Kept separate from :meth:`calculate` because socket-based
        calculators write inputs without running a calculation.
        """
        target = self.directory
        target.mkdir(exist_ok=True, parents=True)
        self.template.write_input(
            profile=self.profile,
            atoms=atoms,
            parameters=self.parameters,
            properties=properties,
            directory=target,
        )

    def calculate(self, atoms, properties, system_changes):
        """Write inputs, run the program and store the parsed results."""
        self.write_inputfiles(atoms, properties)
        template = self.template
        template.execute(self.directory, self.profile)
        self.results = template.read_results(self.directory)
        # XXX Return something useful?

    def _outputmixin_get_results(self):
        return self.results

    def socketio(self, **socketkwargs):
        """Return a socket-I/O calculator built from this calculator's setup."""
        template = self.template
        return template.socketio_calculator(
            directory=self.directory,
            parameters=self.parameters,
            profile=self.profile,
            **socketkwargs,
        )