Coverage for /builds/hweiske/ase/ase/calculators/genericfileio.py: 88.16%

152 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-04-22 11:22 +0000

from abc import ABC, abstractmethod
from contextlib import ExitStack
from os import PathLike
from pathlib import Path
from typing import Any, Iterable, List, Mapping, Optional

from ase.calculators.abc import GetOutputsMixin
from ase.calculators.calculator import BaseCalculator, EnvironmentError
from ase.config import cfg as _cfg

10 

11 

class BaseProfile(ABC):
    """Abstract description of how to launch an external code.

    A profile stores the settings for (optional) parallel execution and
    assembles/runs the complete command line.  Subclasses provide the
    code-specific part through :meth:`get_calculator_command` and
    :meth:`version`.
    """

    def __init__(self, parallel=True, parallel_info=None):
        """
        Parameters
        ----------
        parallel : bool
            If the calculator should be run in parallel.
        parallel_info : dict
            Additional settings for parallel execution, e.g. arguments
            for the binary for parallelization (mpiexec, srun, mpirun).
        """
        self.parallel_info = parallel_info or {}
        self.parallel = parallel

    def get_translation_keys(self):
        """
        Get the translation keys for the parallel_info dictionary.

        A translation key is specified in a config file with the syntax
        `key_kwarg_trans = command, type`, e.g. if `nprocs_kwarg_trans = -np`
        is specified in the config file, then the key `nprocs` will be
        translated to `-np`. Then `nprocs` can be specified in parallel_info
        and will be translated to `-np` when the command is built.

        Returns
        -------
        dict of iterable
            Dictionary with translation keys where the keys are the keys in
            parallel_info that will be translated, the value is what the key
            will be translated into.
        """
        suffix = '_kwarg_trans'
        # ``endswith`` already rejects keys shorter than the suffix, so no
        # separate length check is needed before stripping it.
        return {
            key[:-len(suffix)]: value
            for key, value in self.parallel_info.items()
            if key.endswith(suffix)
        }

    def get_command(self, inputfile, calc_command=None) -> List[str]:
        """
        Get the command to run. This should be a list of strings.

        When ``self.parallel`` is true, the parallel launcher
        (``parallel_info['binary']``, if present) and its arguments come
        first; the calculator-specific command is appended afterwards.

        Parameters
        ----------
        inputfile : str
            Passed on to :meth:`get_calculator_command` when
            ``calc_command`` is not given.
        calc_command : list of str, optional
            Calculator command to use instead of the one returned by
            :meth:`get_calculator_command` (used for sockets).

        Returns
        -------
        list of str
            The command to run.
        """
        command = []
        if self.parallel:
            if 'binary' in self.parallel_info:
                command.append(self.parallel_info['binary'])

            translation_keys = self.get_translation_keys()

            for key, value in self.parallel_info.items():
                # The launcher itself and the translation definitions are
                # not command-line arguments.  Note that this is a
                # substring test on the key.
                if key == 'binary' or '_kwarg_trans' in key:
                    continue

                command_key = translation_keys.get(key, key)

                if type(value) is not bool:
                    # Ordinary option: emit "<key> <value>".
                    command.append(f'{command_key}')
                    command.append(f'{value}')
                elif value:
                    # Boolean flag: emit the key only when it is True.
                    command.append(f'{command_key}')

        if calc_command is None:
            command.extend(self.get_calculator_command(inputfile))
        else:
            command.extend(calc_command)
        return command

    @abstractmethod
    def get_calculator_command(self, inputfile):
        """
        The calculator specific command as a list of strings.

        Parameters
        ----------
        inputfile : str

        Returns
        -------
        list of str
            The command to run.
        """
        ...

    def run(
        self, directory: Path, inputfile: Optional[str],
        outputfile: str, errorfile: Optional[str] = None,
        append: bool = False
    ) -> None:
        """
        Run the command in the given directory.

        Parameters
        ----------
        directory : pathlib.Path
            The directory to run the command in.
        inputfile : Optional[str]
            The name of the input file.
        outputfile : str
            The name of the output file (relative to ``directory``);
            receives the standard output of the command.
        errorfile : Optional[str]
            The name of the file receiving standard error.  If None,
            stderr is not redirected.
        append : bool
            If True, open the output/error files in append mode instead
            of truncating them.

        Raises
        ------
        subprocess.CalledProcessError
            If the command exits with a non-zero status (``check_call``).
        """

        import os
        from subprocess import check_call

        argv_command = self.get_command(inputfile)
        mode = 'ab' if append else 'wb'

        # ExitStack closes whichever of the (one or two) files were opened.
        with ExitStack() as stack:
            output_path = directory / outputfile
            fd_out = stack.enter_context(open(output_path, mode))
            if errorfile is not None:
                error_path = directory / errorfile
                fd_err = stack.enter_context(open(error_path, mode))
            else:
                fd_err = None
            check_call(
                argv_command,
                cwd=directory,
                stdout=fd_out,
                stderr=fd_err,
                env=os.environ,
            )

    @abstractmethod
    def version(self):
        """
        Get the version of the code.

        Returns
        -------
        str
            The version of the code.
        """
        ...

    @classmethod
    def from_config(cls, cfg, section_name, parallel_info=None, parallel=True):
        """
        Create a profile from a configuration file.

        Parameters
        ----------
        cfg : ase.config.Config
            The configuration object.
        section_name : str
            The name of the section in the configuration file. E.g. the name
            of the template that this profile is for.
        parallel_info : dict, optional
            Extra parallel settings.  They are merged on top of the
            ``parallel`` section of the configuration, so entries given
            here take precedence.
        parallel : bool
            Forwarded to the profile constructor.

        Returns
        -------
        BaseProfile
            The profile object.

        Raises
        ------
        BadConfiguration
            If the constructor rejects the configured keyword arguments
            (i.e. raises TypeError).
        """
        parallel_config = dict(cfg.parser['parallel'])
        parallel_info = parallel_info if parallel_info is not None else {}
        parallel_config.update(parallel_info)

        try:
            return cls(
                **cfg.parser[section_name],
                parallel_info=parallel_config,
                parallel=parallel,
            )
        except TypeError as err:
            # Keep the original TypeError attached as the explicit cause.
            raise BadConfiguration(*err.args) from err

195 

196 

class BadConfiguration(Exception):
    """Raised when a profile cannot be constructed from its configuration."""

199 

200 

def read_stdout(args, createfile=None):
    """Run command in tempdir and return standard output.

    Helper function for getting version numbers of DFT codes.
    Most DFT codes don't implement a --version flag, so in order to
    determine the code version, we just run the code until it prints
    a version number.

    Parameters
    ----------
    args : list of str
        The command to execute (passed to ``subprocess.Popen``).
    createfile : str, optional
        Name of an empty file to create in the temporary working
        directory before the command is started.

    Returns
    -------
    str
        Everything the command wrote to standard output, decoded as
        UTF-8.  Standard error is captured and discarded, and the exit
        status is ignored.
    """
    import tempfile
    from subprocess import PIPE, Popen

    with tempfile.TemporaryDirectory() as directory:
        if createfile is not None:
            path = Path(directory) / createfile
            path.touch()
        # The context manager guarantees the pipes are closed and the
        # child is waited for even if communicate() is interrupted.
        with Popen(
            args,
            stdout=PIPE,
            stderr=PIPE,
            stdin=PIPE,
            cwd=directory,
            # Make the encoding a parameter if any code turns out to
            # emit non-utf8/ascii output.
            encoding='utf-8',
        ) as proc:
            stdout, _ = proc.communicate()
        # Exit code will be != 0 because there isn't an input file
    return stdout

226 

227 

class CalculatorTemplate(ABC):
    """Abstract recipe for driving one external code through files.

    A template knows how to write the input, execute the code with a
    given profile, and read back the results.  Templates that also
    define ``socketio_argv()`` and ``socketio_parameters()`` can be run
    through :meth:`socketio_calculator`.
    """

    def __init__(self, name: str, implemented_properties: Iterable[str]):
        """
        Parameters
        ----------
        name : str
            Name of the template.
        implemented_properties : iterable of str
            Names of the properties this template can provide; stored
            as a frozenset.
        """
        self.name = name
        self.implemented_properties = frozenset(implemented_properties)

    @abstractmethod
    def write_input(self, profile, directory, atoms, parameters, properties):
        """Write the input for a calculation on ``atoms`` to ``directory``."""
        ...

    @abstractmethod
    def execute(self, directory, profile):
        """Run the code in ``directory`` using ``profile``."""
        ...

    @abstractmethod
    def read_results(self, directory: PathLike) -> Mapping[str, Any]:
        """Read the results from ``directory`` and return them as a mapping."""
        ...

    @abstractmethod
    def load_profile(self, cfg, parallel_info=None, parallel=True):
        """Build and return the profile for this template from ``cfg``."""
        ...

    def socketio_calculator(
        self,
        profile,
        parameters,
        directory,
        # We may need quite a few socket kwargs here
        # if we want to expose all the timeout etc. from
        # SocketIOCalculator.
        unixsocket=None,
        port=None,
    ):
        """Return a SocketIOCalculator that launches this code as client.

        Parameters
        ----------
        profile
            Profile used to build the client command line.
        parameters : mapping
            Calculator parameters; they override the entries returned
            by ``self.socketio_parameters(unixsocket, port)``.
        directory : pathlib.Path
            Directory in which input is written and the client is run.
        unixsocket, port
            Exactly one of these must be given (UNIX or INET socket).

        Raises
        ------
        TypeError
            If both or neither of ``unixsocket``/``port`` are given, or
            if the template lacks ``socketio_argv`` or
            ``socketio_parameters``.
        """
        import os
        from subprocess import Popen

        # Validate the arguments first so that mistakes are reported
        # before the socketio machinery is imported.
        if port and unixsocket:
            raise TypeError(
                'For the socketio_calculator only a UNIX '
                '(unixsocket) or INET (port) socket can be used'
                ' not both.'
            )

        if not port and not unixsocket:
            raise TypeError(
                'For the socketio_calculator either a '
                'UNIX (unixsocket) or INET (port) socket '
                'must be used'
            )

        if not (
            hasattr(self, 'socketio_argv')
            and hasattr(self, 'socketio_parameters')
        ):
            raise TypeError(
                f'Template {self} does not implement mandatory '
                'socketio_argv() and socketio_parameters()'
            )

        from ase.calculators.socketio import SocketIOCalculator

        # XXX need socketio ABC or something
        argv = profile.get_command(
            inputfile=None,
            calc_command=self.socketio_argv(profile, unixsocket, port)
        )
        # Later entries win: user parameters override the socket defaults.
        parameters = {
            **self.socketio_parameters(unixsocket, port),
            **parameters,
        }

        # Not so elegant that socket args are passed to this function
        # via socketiocalculator when we could make a closure right here.
        def launch(atoms, properties, port, unixsocket):
            directory.mkdir(exist_ok=True, parents=True)

            self.write_input(
                atoms=atoms,
                profile=profile,
                parameters=parameters,
                properties=properties,
                directory=directory,
            )

            # Only the parent's handle is closed on leaving the
            # with-block; the started process keeps writing through
            # the handle it inherited.
            with open(directory / self.outputname, 'w') as out_fd:
                return Popen(argv, stdout=out_fd, cwd=directory, env=os.environ)

        return SocketIOCalculator(
            launch_client=launch, unixsocket=unixsocket, port=port
        )

317 

318 

class GenericFileIOCalculator(BaseCalculator, GetOutputsMixin):
    """Calculator that delegates file I/O and execution to a template.

    The calculator combines a template (how to write input, execute and
    read results), a profile (how to launch the code) and a working
    directory.
    """

    # Configuration object used to look up a profile when none is given.
    cfg = _cfg

    def __init__(
        self,
        *,
        template,
        profile,
        directory,
        parameters=None,
        parallel_info=None,
        parallel=True,
    ):
        """
        Parameters
        ----------
        template
            The calculator template.
        profile
            The profile to use.  If None, it is loaded from ``self.cfg``
            through ``template.load_profile``.
        directory
            Working directory; converted to ``pathlib.Path``.
        parameters
            Passed on to the base-class constructor.
        parallel_info, parallel
            Only used when ``profile`` is None; forwarded to
            ``template.load_profile``.

        Raises
        ------
        EnvironmentError
            If ``profile`` is None and the configuration has no section
            named ``template.name``, or loading the profile fails.
        """
        self.template = template
        if profile is None:
            if template.name not in self.cfg.parser:
                raise EnvironmentError(f'No configuration of {template.name}')
            try:
                profile = template.load_profile(
                    self.cfg, parallel_info=parallel_info, parallel=parallel
                )
            except Exception as err:
                # Any failure is reported as a configuration problem,
                # with the original exception attached as the cause.
                configvars = self.cfg.as_dict()
                raise EnvironmentError(
                    f'Failed to load section [{template.name}] '
                    f'from configuration: {configvars}'
                ) from err

        self.profile = profile

        # Maybe we should allow directory to be a factory, so
        # calculators e.g. produce new directories on demand.
        self.directory = Path(directory)
        super().__init__(parameters)

    def set(self, *args, **kwargs):
        """Not supported; always raises RuntimeError."""
        raise RuntimeError(
            'No setting parameters for now, please. '
            'Just create new calculators.'
        )

    def __repr__(self):
        return f'{type(self).__name__}({self.template.name})'

    @property
    def implemented_properties(self):
        """The properties implemented by the template."""
        return self.template.implemented_properties

    @property
    def name(self):
        """The name of the template."""
        return self.template.name

    def write_inputfiles(self, atoms, properties):
        """Create the working directory and write the input files."""
        # SocketIOCalculators like to write inputfiles
        # without calculating.
        self.directory.mkdir(exist_ok=True, parents=True)
        self.template.write_input(
            profile=self.profile,
            atoms=atoms,
            parameters=self.parameters,
            properties=properties,
            directory=self.directory,
        )

    def calculate(self, atoms, properties, system_changes):
        """Write input, execute the code and store the results.

        ``system_changes`` is accepted but not used here.
        """
        self.write_inputfiles(atoms, properties)
        self.template.execute(self.directory, self.profile)
        self.results = self.template.read_results(self.directory)
        # XXX Return something useful?

    def _outputmixin_get_results(self):
        return self.results

    def socketio(self, **socketkwargs):
        """Return the template's socketio calculator for this setup.

        ``socketkwargs`` are forwarded to
        ``template.socketio_calculator`` together with this calculator's
        directory, parameters and profile.
        """
        return self.template.socketio_calculator(
            directory=self.directory,
            parameters=self.parameters,
            profile=self.profile,
            **socketkwargs,
        )
398 )