Running without AiiDA
In [24]:
Copied!
import yaml
import json
from xtalpaint.inpainting.config_schema import InpaintingWorkGraphConfig
from xtalpaint.data import BatchedStructures
from xtalpaint.utils.relaxation_utils import relax_structures
from ase.io import read
from pymatgen.io.ase import AseAtomsAdaptor
from IPython.display import clear_output
import yaml
import json
from xtalpaint.inpainting.config_schema import InpaintingWorkGraphConfig
from xtalpaint.data import BatchedStructures
from xtalpaint.utils.relaxation_utils import relax_structures
from ase.io import read
from pymatgen.io.ase import AseAtomsAdaptor
from IPython.display import clear_output
In [2]:
Copied!
# Read every frame of the extxyz file, convert each one to a pymatgen
# structure and key it by the frame's UUID with "-" replaced by "_".
ase_frames = read("test-structures.extxyz", index=':')
input_structures = dict(
    (frame.info['uuid'].replace("-", "_"), AseAtomsAdaptor.get_structure(frame))
    for frame in ase_frames
)
# Read every frame of the extxyz file, convert each one to a pymatgen
# structure and key it by the frame's UUID with "-" replaced by "_".
ase_frames = read("test-structures.extxyz", index=':')
input_structures = dict(
    (frame.info['uuid'].replace("-", "_"), AseAtomsAdaptor.get_structure(frame))
    for frame in ase_frames
)
In [3]:
Copied!
# Quick check of how many structures were read in.
len(input_structures)
# Quick check of how many structures were read in.
len(input_structures)
Out[3]:
5
In [4]:
Copied!
# Settings handed to the pipeline configuration as "inpainting_model_params".
param_grid = dict(
    N_steps=5,
    coordinates_snr=0.2,
    n_corrector_steps=1,
    batch_size=1000,
)
# Settings handed to the pipeline configuration as "inpainting_model_params".
param_grid = dict(
    N_steps=5,
    coordinates_snr=0.2,
    n_corrector_steps=1,
    batch_size=1000,
)
In [16]:
Copied!
# Shell command that activates the virtual environment.
ENV_ACTIVATION_CMD = "source ~/.aiida_venvs/test-xtalpaint/bin/activate"
# Shell preamble shared by both option blocks below; PYTHONBREAKPOINT=0 turns
# any breakpoint() call into a no-op.
JOB_PREAMBLE = f"{ENV_ACTIVATION_CMD}\nexport PYTHONBREAKPOINT=0"

# Machine-specific locations of the model checkpoint directory and the
# sampling configuration -- adjust these before running on another system.
MODEL_PATH = '/home/reents_t/project/test-xtalpaint/new-td-20-perc'
SAMPLING_CONFIG_PATH = "/home/reents_t/project/test-xtalpaint/git/mattergen/sampling_conf"

inputs = InpaintingWorkGraphConfig(
    inpainting_pipeline_params={
        "record_trajectories": False,
        "predictor_corrector": "baseline",
        "inpainting_model_params": param_grid,
        # Alternative (disabled): select a named pretrained model --
        # presumably used instead of "model_path"; confirm before enabling.
        # "pretrained_name": "mattergen_base",
        'model_path': MODEL_PATH,
        "sampling_config_path": SAMPLING_CONFIG_PATH,
    },
    structures=BatchedStructures(
        {k.replace("-", "_"): s for k, s in input_structures.items()}
    ),
    gen_inpainting_candidates_params={
        # Number of sites to inpaint per structure: the H count of the
        # corresponding input structure.
        "n_inp": {
            k.replace("-", "_"): int(s.composition["H"])
            for k, s in input_structures.items()
        },
        "element": "H",
        "num_samples": 1,
    },
    relax=True,
    full_relax=True,
    full_relax_wo_pre_relax=False,
    relax_kwargs={
        "elements_to_relax": ["H"],
        "fmax": 0.01,
        "max_natoms_per_batch": 5000,
        "load_path": "MatterSim-v1.0.0-5M.pth",
        "max_n_steps": 50,
        "device": "cuda",
        "mlip": "mattersim",
        "optimizer": "BFGS",
        "return_initial_energies": False,
        "return_initial_forces": False,
        "return_final_forces": False,
    },
    gen_inpainting_candidates_options={
        "custom_scheduler_commands": JOB_PREAMBLE,
    },
    options={
        "prepend_text": JOB_PREAMBLE,
    },
    evaluate_params={"max_workers": 5, "metrics": ["match", "rmsd"]},
    evaluate=True,
)
# Shell command that activates the virtual environment.
ENV_ACTIVATION_CMD = "source ~/.aiida_venvs/test-xtalpaint/bin/activate"
# Shell preamble shared by both option blocks below; PYTHONBREAKPOINT=0 turns
# any breakpoint() call into a no-op.
JOB_PREAMBLE = f"{ENV_ACTIVATION_CMD}\nexport PYTHONBREAKPOINT=0"

# Machine-specific locations of the model checkpoint directory and the
# sampling configuration -- adjust these before running on another system.
MODEL_PATH = '/home/reents_t/project/test-xtalpaint/new-td-20-perc'
SAMPLING_CONFIG_PATH = "/home/reents_t/project/test-xtalpaint/git/mattergen/sampling_conf"

inputs = InpaintingWorkGraphConfig(
    inpainting_pipeline_params={
        "record_trajectories": False,
        "predictor_corrector": "baseline",
        "inpainting_model_params": param_grid,
        # Alternative (disabled): select a named pretrained model --
        # presumably used instead of "model_path"; confirm before enabling.
        # "pretrained_name": "mattergen_base",
        'model_path': MODEL_PATH,
        "sampling_config_path": SAMPLING_CONFIG_PATH,
    },
    structures=BatchedStructures(
        {k.replace("-", "_"): s for k, s in input_structures.items()}
    ),
    gen_inpainting_candidates_params={
        # Number of sites to inpaint per structure: the H count of the
        # corresponding input structure.
        "n_inp": {
            k.replace("-", "_"): int(s.composition["H"])
            for k, s in input_structures.items()
        },
        "element": "H",
        "num_samples": 1,
    },
    relax=True,
    full_relax=True,
    full_relax_wo_pre_relax=False,
    relax_kwargs={
        "elements_to_relax": ["H"],
        "fmax": 0.01,
        "max_natoms_per_batch": 5000,
        "load_path": "MatterSim-v1.0.0-5M.pth",
        "max_n_steps": 50,
        "device": "cuda",
        "mlip": "mattersim",
        "optimizer": "BFGS",
        "return_initial_energies": False,
        "return_initial_forces": False,
        "return_final_forces": False,
    },
    gen_inpainting_candidates_options={
        "custom_scheduler_commands": JOB_PREAMBLE,
    },
    options={
        "prepend_text": JOB_PREAMBLE,
    },
    evaluate_params={"max_workers": 5, "metrics": ["match", "rmsd"]},
    evaluate=True,
)
Running the inpainting workflow without AiiDA¶
In [17]:
Copied!
# Report how many input structures go into the workflow.
print(f"Processing {len(input_structures)} structures")
# Report how many input structures go into the workflow.
print(f"Processing {len(input_structures)} structures")
Processing 5 structures
Generate inpainting candidates¶
In [18]:
Copied!
from xtalpaint.inpainting.generate_candidates import (
generate_inpainting_candidates,
)
from xtalpaint.inpainting.generate_candidates import (
generate_inpainting_candidates,
)
In [19]:
Copied!
# Step 1: Generate inpainting candidates
# This step only builds the candidate structures (sites of the requested
# element are added without coordinates yet); the diffusion-based inpainting
# itself runs in a later cell, so the messages refer to candidates.
print("Generating inpainting candidates...")

candidate_params = inputs.gen_inpainting_candidates_params
n_inp_dict = candidate_params.n_inp
element = candidate_params.element
num_samples = candidate_params.num_samples

inpainting_candidates = generate_inpainting_candidates(
    structures=input_structures,
    n_inp=n_inp_dict,
    element=element,
    num_samples=num_samples,
)

print(f"Generated {len(inpainting_candidates)} inpainting candidates")
# Step 1: Generate inpainting candidates
# This step only builds the candidate structures (sites of the requested
# element are added without coordinates yet); the diffusion-based inpainting
# itself runs in a later cell, so the messages refer to candidates.
print("Generating inpainting candidates...")

candidate_params = inputs.gen_inpainting_candidates_params
n_inp_dict = candidate_params.n_inp
element = candidate_params.element
num_samples = candidate_params.num_samples

inpainting_candidates = generate_inpainting_candidates(
    structures=input_structures,
    n_inp=n_inp_dict,
    element=element,
    num_samples=num_samples,
)

print(f"Generated {len(inpainting_candidates)} inpainting candidates")
Running inpainting pipeline... Generated 5 inpainted structures
In [20]:
Copied!
# Display the candidates; in the recorded output the H sites still carry NaN
# coordinates, i.e. they are placeholders to be filled in by the inpainting.
inpainting_candidates
# Display the candidates; in the recorded output the H sites still carry NaN
# coordinates, i.e. they are placeholders to be filled in by the inpainting.
inpainting_candidates
Out[20]:
{'20acc66e_8e38_4e5e_9e7a_c2400262cdc8': Structure Summary
Lattice
abc : 5.169783747 5.169783747 5.169783747
angles : 90.0 90.0 90.0
volume : 138.17107311088554
A : 5.169783747 0.0 0.0
B : 0.0 5.169783747 0.0
C : 0.0 0.0 5.169783747
pbc : True True True
PeriodicSite: N (2.309, 4.894, 2.86) [0.4467, 0.9467, 0.5533]
PeriodicSite: N (4.894, 2.86, 2.309) [0.9467, 0.5533, 0.4467]
PeriodicSite: N (2.86, 2.309, 4.894) [0.5533, 0.4467, 0.9467]
PeriodicSite: N (0.2756, 0.2756, 0.2756) [0.0533, 0.0533, 0.0533]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan],
'47b9a869_9b1e_438b_8c93_f5ac654bfdd8': Structure Summary
Lattice
abc : 4.401448665821795 4.401448665821795 7.1759779644947
angles : 90.0 90.0 120.19000916247398
volume : 120.16231769034825
A : 2.1944008576909 -3.8154102313683 0.0
B : 2.1944008576909 3.8154102313683 0.0
C : 0.0 0.0 7.1759779644947
pbc : True True True
PeriodicSite: O (2.194, 2.549, 3.886) [0.1659, 0.8341, 0.5415]
PeriodicSite: O (7.744e-18, 1.266, 0.2978) [-0.1659, 0.1659, 0.0415]
PeriodicSite: O (2.194, 2.544, 6.579) [0.1666, 0.8334, 0.9168]
PeriodicSite: O (1.254e-17, 1.272, 2.991) [-0.1666, 0.1666, 0.4168]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan],
'662c7351_ee76_48ea_bab7_b733e1fdf607': Structure Summary
Lattice
abc : 8.414296245395509 8.414296245395509 8.414296245395509
angles : 93.51670444494472 93.51670444494472 151.3208570539183
volume : 276.98877799736266
A : -5.7644391936362 5.7644391936362 2.0839536633569
B : 5.7644391936362 -5.7644391936362 2.0839536633569
C : 5.7644391936362 5.7644391936362 -2.0839536633569
pbc : True True True
PeriodicSite: O (4.413e-16, 8.82, -0.1928) [0.7188, -0.04627, 0.7651]
PeriodicSite: O (7.129e-17, 3.056, 1.235) [0.5613, 0.2963, 0.2651]
PeriodicSite: O (7.664e-17, 2.709, -0.1928) [0.1887, -0.04627, 0.2349]
PeriodicSite: O (7.102e-16, 8.473, 1.235) [1.031, 0.2963, 0.7349]
PeriodicSite: O (3.056, 5.764, 2.277) [1.046, 0.8113, 0.7651]
PeriodicSite: O (8.82, 5.764, 0.8491) [0.7037, 0.9688, 1.265]
PeriodicSite: O (8.473, 5.764, 2.277) [1.046, 1.281, 1.235]
PeriodicSite: O (2.709, 5.764, 0.8491) [0.7037, 0.4387, 0.7349]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan],
'7fa282c5_4971_46f4_8b3b_776595a0fa06': Structure Summary
Lattice
abc : 6.3321144796 6.332114479595774 6.603633127
angles : 90.0 90.0 120.00000000002208
volume : 229.3037119499634
A : 6.3321144796 0.0 0.0
B : -3.1660572398 5.483771999 0.0
C : 0.0 0.0 6.603633127
pbc : True True True
PeriodicSite: Y (2.099, 3.636, 1.651) [0.6631, 0.6631, 0.25]
PeriodicSite: Y (4.199, 0.0, 4.953) [0.6631, 0.0, 0.75]
PeriodicSite: Y (-2.099, 3.636, 4.953) [-6.806e-10, 0.6631, 0.75]
PeriodicSite: Y (-1.067, 1.848, 1.651) [7.402e-10, 0.3369, 0.25]
PeriodicSite: Y (1.067, 1.848, 4.953) [0.3369, 0.3369, 0.75]
PeriodicSite: Y (2.134, 0.0, 1.651) [0.3369, 0.0, 0.25]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan],
'c436bbf4_9aef_44a8_8960_00227f79a32f': Structure Summary
Lattice
abc : 4.9620535714985 5.689428317103 4.4034011624103
angles : 90.0 90.0 90.0
volume : 124.31351070276466
A : 4.9620535714985 0.0 0.0
B : 0.0 5.689428317103 0.0
C : 0.0 0.0 4.4034011624103
pbc : True True True
PeriodicSite: V (2.481, 1.451, 2.02) [0.5, 0.2551, 0.4587]
PeriodicSite: V (-2.13e-06, 4.238, 2.02) [-4.293e-07, 0.7449, 0.4587]
PeriodicSite: V (2.481, 4.238, 2.02) [0.5, 0.7449, 0.4587]
PeriodicSite: V (2.13e-06, 1.451, 2.02) [4.293e-07, 0.2551, 0.4587]
PeriodicSite: V (2.481, 2.845, -0.1521) [0.5, 0.5, -0.03455]
PeriodicSite: V (0.0, 2.845, -0.1521) [0.0, 0.5, -0.03455]
PeriodicSite: V (2.481, 0.0, -0.1605) [0.5, 0.0, -0.03644]
PeriodicSite: V (0.0, 0.0, -0.1605) [0.0, 0.0, -0.03644]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]
PeriodicSite: H (nan, nan, nan) [nan, nan, nan]}
Run inpainting¶
In [21]:
Copied!
from xtalpaint.inpainting.inpainting_process import (
run_inpainting_pipeline,
run_mpi_parallel_inpainting_pipeline,
)
from xtalpaint.inpainting.inpainting_process import (
run_inpainting_pipeline,
run_mpi_parallel_inpainting_pipeline,
)
In [22]:
Copied!
# Switch that decides which pipeline entry point is bound to
# `inpainting_method`: the MPI-parallel variant or the default one.
USE_MPI_FOR_PARALLEL_INPAINTING = False

if USE_MPI_FOR_PARALLEL_INPAINTING:
    inpainting_method = run_mpi_parallel_inpainting_pipeline
else:
    inpainting_method = run_inpainting_pipeline
# Switch that decides which pipeline entry point is bound to
# `inpainting_method`: the MPI-parallel variant or the default one.
USE_MPI_FOR_PARALLEL_INPAINTING = False

if USE_MPI_FOR_PARALLEL_INPAINTING:
    inpainting_method = run_mpi_parallel_inpainting_pipeline
else:
    inpainting_method = run_inpainting_pipeline
In [23]:
Copied!
# Turn the pipeline settings into a plain dict, leaving out fields that are None.
config = inputs.inpainting_pipeline_params.model_dump(exclude_none=True)

# Call the entry point chosen via USE_MPI_FOR_PARALLEL_INPAINTING instead of
# hard-coding run_inpainting_pipeline, so that the switch actually takes effect.
# NOTE(review): assumes the MPI variant accepts the same `structures`/`config`
# keywords -- confirm before enabling the switch.
# NOTE(review): the output recorded for this cell ends in a ValueError raised
# in TDGemNetT.forward (latent vector z vs. number of atoms); presumably the
# sampler settings do not match the time-dependent checkpoint -- confirm and
# fix before relying on the cells below.
inpainting_outputs = inpainting_method(
    structures=inpainting_candidates, config=config
)

# Drop the verbose model/sampling log printed during the run.
clear_output()
# Turn the pipeline settings into a plain dict, leaving out fields that are None.
config = inputs.inpainting_pipeline_params.model_dump(exclude_none=True)

# Call the entry point chosen via USE_MPI_FOR_PARALLEL_INPAINTING instead of
# hard-coding run_inpainting_pipeline, so that the switch actually takes effect.
# NOTE(review): assumes the MPI variant accepts the same `structures`/`config`
# keywords -- confirm before enabling the switch.
# NOTE(review): the output recorded for this cell ends in a ValueError raised
# in TDGemNetT.forward (latent vector z vs. number of atoms); presumably the
# sampler settings do not match the time-dependent checkpoint -- confirm and
# fix before relying on the cells below.
inpainting_outputs = inpainting_method(
    structures=inpainting_candidates, config=config
)

# Drop the verbose model/sampling log printed during the run.
clear_output()
Converting structures to numpy: 0%| | 0/5 [00:00<?, ?it/s]
INFO:mattergen.common.utils.eval_utils:Loading model from checkpoint: /home/reents_t/project/test-xtalpaint/new-td-20-perc/checkpoints/last.ckpt
Model config:
auto_resume: false
checkpoint_path: null
data_module:
_recursive_: true
_target_: mattergen.common.data.datamodule.CrystDataModule
average_density: 0.05771451654022283
batch_size:
train: 128
val: 128
dataset_transforms:
- _partial_: true
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties
max_epochs: 2200
num_workers:
train: 128
val: 128
properties: []
root_dir: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H
train_dataset:
_target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
cache_path: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H/train
dataset_transforms:
- _partial_: true
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties
properties: []
transforms:
- _partial_: true
_target_: mattergen.common.data.transform.symmetrize_lattice
- _partial_: true
_target_: mattergen.common.data.transform.set_chemical_system_string
transforms:
- _partial_: true
_target_: mattergen.common.data.transform.symmetrize_lattice
- _partial_: true
_target_: mattergen.common.data.transform.set_chemical_system_string
val_dataset:
_target_: mattergen.common.data.dataset.CrystalDataset.from_cache_path
cache_path: /data/user/reents_t/projects/mlip/git/mattergen/mattergen/../datasets/cache/alex_mp_20_wo_mc3d_H/val
dataset_transforms:
- _partial_: true
_target_: mattergen.common.data.dataset_transform.filter_sparse_properties
properties: []
transforms:
- _partial_: true
_target_: mattergen.common.data.transform.symmetrize_lattice
- _partial_: true
_target_: mattergen.common.data.transform.set_chemical_system_string
lightning_module:
_target_: mattergen.diffusion.lightning_module.DiffusionLightningModule
diffusion_module:
_target_: xtalpaint.time_dependent.diffusion_module.TDDiffusionModule
corruption:
_target_: mattergen.diffusion.corruption.multi_corruption.MultiCorruption
sdes:
pos:
_target_: xtalpaint.time_dependent.corruption.TDNumAtomsVarianceAdjustedWrappedVESDE
limit_info_key: num_atoms
sigma_max: 5.0
wrapping_boundary: 1.0
loss_fn:
_target_: xtalpaint.time_dependent.loss.TDMaterialsLoss
d3pm_hybrid_lambda: 0.01
include_atomic_numbers: false
include_cell: false
include_pos: true
reduce: sum
weights:
pos: 1
model:
_target_: mattergen.denoiser.GemNetTDenoiser
atom_type_diffusion: mask
denoise_atom_types: false
gemnet:
_target_: xtalpaint.time_dependent.gemnet.TDGemNetT
atom_embedding:
_target_: mattergen.common.gemnet.layers.embedding_block.AtomEmbedding
emb_size: 512
with_mask_type: false
cutoff: 7.0
emb_size_atom: 512
emb_size_edge: 512
latent_dim: 512
max_cell_images_per_dim: 5
max_neighbors: 50
num_blocks: 4
num_targets: 1
otf_graph: true
regress_stress: true
scale_file: /home/reents_t/project/test-xtalpaint/git/mattergen/mattergen/common/gemnet/gemnet-dT.json
hidden_dim: 512
property_embeddings: {}
property_embeddings_adapt: {}
p_replace: 0.2
pre_corruption_fn:
_target_: mattergen.property_embeddings.SetEmbeddingType
dropout_fields_iid: false
p_unconditional: 0.2
t_replace: 0.001
optimizer_partial:
_partial_: true
_target_: torch.optim.Adam
lr: 0.0001
scheduler_partials:
- frequency: 1
interval: epoch
monitor: loss_train
scheduler:
_partial_: true
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
factor: 0.6
min_lr: 1.0e-06
patience: 100
verbose: true
strict: true
load_original: false
params: {}
trainer:
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
callbacks:
- _target_: pytorch_lightning.callbacks.EarlyStopping
min_delta: 0.01
mode: min
monitor: loss_val
patience: 20
strict: true
verbose: true
check_val_every_n_epoch: 5
devices: 4
gradient_clip_algorithm: value
gradient_clip_val: 0.5
max_epochs: 2200
num_nodes: 1
precision: 32
strategy:
_target_: pytorch_lightning.strategies.ddp.DDPStrategy
find_unused_parameters: true
Sampling config:
sampler_partial:
_target_: mattergen.diffusion.sampling.classifier_free_guidance.GuidedPredictorCorrector.from_pl_module
'N': 5
eps_t: 0.2
_partial_: true
guidance_scale: 0.0
remove_conditioning_fn:
_target_: mattergen.property_embeddings.SetUnconditionalEmbeddingType
keep_conditioning_fn:
_target_: mattergen.property_embeddings.SetConditionalEmbeddingType
predictor_partials:
pos:
_target_: mattergen.diffusion.wrapped.wrapped_predictors_correctors.WrappedAncestralSamplingPredictor
_partial_: true
corrector_partials:
pos:
_target_: mattergen.diffusion.wrapped.wrapped_predictors_correctors.WrappedLangevinCorrector
_partial_: true
max_step_size: 1000000.0
snr: 0.2
n_steps_corrector: 1
condition_loader_partial:
_partial_: true
_target_: mattergen.common.data.condition_factory.get_number_of_atoms_condition_loader
num_atoms_distribution: ALEX_MP_20
batch_size: 10
num_samples: 10
{'pos': <xtalpaint.time_dependent.corruption.TDNumAtomsVarianceAdjustedWrappedVESDE object at 0x7da1cda556f0>}
Generating samples: 0%| | 0/1 [00:00<?, ?it/s]WARNING:root:Warning: batch shape is == x shape, are you trying to expand something that is already expanded?
0%| | 0/5 [00:00<?, ?it/s]
Generating samples: 0%| | 0/1 [00:00<?, ?it/s]
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[23], line 5 1 config = inputs.inpainting_pipeline_params.model_dump( 2 exclude_none=True 3 ) ----> 5 inpainting_outputs = run_inpainting_pipeline( 6 structures=inpainting_candidates, config=config 7 ) 9 clear_output() File ~/project/test-xtalpaint/git/XtalPaint/src/xtalpaint/inpainting/inpainting_process.py:329, in run_inpainting_pipeline(structures, config) 322 labels, structures = map(list, zip(*structures.items())) 324 prepared_structures = _prepare_structures( 325 structures, 326 batch_size=config["inpainting_model_params"].get("batch_size", 64), 327 ) --> 329 inpainted_structures, trajectories, mean_trajectories = _run_inpainting( 330 structures_dl=prepared_structures, **config 331 ) 333 return _extract_outputs( 334 inpainted_structures, 335 trajectories, (...) 338 config["record_trajectories"], 339 ) File ~/project/test-xtalpaint/git/XtalPaint/src/xtalpaint/inpainting/inpainting_process.py:204, in _run_inpainting(predictor_corrector, structures_dl, inpainting_model_params, fix_cell, record_trajectories, pretrained_name, model_path, sampling_config_path) 183 """Run the inpainting process using MatterGen. 184 185 Args: (...) 198 mean_trajectories is None if not recorded. 
199 """ 200 sampling_config_overrides, config_overrides = _get_overrides( 201 inpainting_model_params, predictor_corrector, fix_cell, pretrained_name 202 ) --> 204 reconstructed_structures = generate_reconstructed_structures( 205 structures_to_reconstruct=structures_dl, 206 sampling_config_overrides=sampling_config_overrides, 207 config_overrides=config_overrides, 208 model_path=model_path, 209 pretrained_name=pretrained_name, 210 record_trajectories=record_trajectories, 211 fix_cell=fix_cell, 212 sampling_config_path=sampling_config_path, 213 ) 215 if len(reconstructed_structures) == 2: 216 print("Not returning mean trajectories.") File ~/project/test-xtalpaint/git/XtalPaint/src/xtalpaint/generate_inpainting.py:298, in generate_reconstructed_structures(structures_to_reconstruct, pretrained_name, model_path, batch_size, num_batches, config_overrides, checkpoint_epoch, properties_to_condition_on, sampling_config_path, sampling_config_name, sampling_config_overrides, record_trajectories, diffusion_guidance_factor, strict_checkpoint_loading, target_compositions, fix_cell) 274 _sampling_config_path = ( 275 Path(sampling_config_path) 276 if sampling_config_path is not None 277 else None 278 ) 280 generator = CrystalInpaintingGenerator( 281 dataloader=structures_to_reconstruct, 282 checkpoint_info=checkpoint_info, (...) 
295 target_compositions_dict=target_compositions, 296 ) --> 298 return generator.generate(fix_cell=fix_cell) File ~/project/test-xtalpaint/git/XtalPaint/src/xtalpaint/generate_inpainting.py:178, in CrystalInpaintingGenerator.generate(self, batch_size, num_batches, target_compositions_dict, fix_cell) 174 sampler._multi_corruption.corruptions.pop("atomic_numbers", None) 176 print(sampler.diffusion_module.corruption.corruptions) --> 178 generated_structures = draw_samples_from_sampler( 179 sampler=sampler, 180 condition_loader=condition_loader, 181 properties_to_condition_on=self.properties_to_condition_on, 182 record_trajectories=self.record_trajectories, 183 fix_cell=fix_cell, 184 ) 186 return generated_structures File ~/project/test-xtalpaint/git/XtalPaint/src/xtalpaint/generate_inpainting.py:85, in draw_samples_from_sampler(sampler, condition_loader, properties_to_condition_on, record_trajectories, fix_cell) 79 all_trajs_list.extend( 80 list_of_time_steps_to_list_of_trajectories( 81 intermediate_samples 82 ) 83 ) 84 else: ---> 85 sample, mean = sampler.sample(conditioning_data, mask) 86 all_samples_list.extend(mean.to_data_list()) 87 all_samples = collate(all_samples_list) File ~/.aiida_venvs/test-xtalpaint/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs) 112 @functools.wraps(func) 113 def decorate_context(*args, **kwargs): 114 with ctx_factory(): --> 115 return func(*args, **kwargs) File ~/project/test-xtalpaint/mattergen-clean/mattergen/diffusion/sampling/pc_sampler.py:114, in PredictorCorrector.sample(self, conditioning_data, mask) 100 @torch.no_grad() 101 def sample( 102 self, conditioning_data: BatchedData, mask: Mapping[str, torch.Tensor] | None = None 103 ) -> SampleAndMean: 104 """Create one sample for each of a batch of conditions. 105 Args: 106 conditioning_data: batched conditioning data. 
Even if you think you don't want conditioning, you still need to pass a batch of conditions (...) 112 113 """ --> 114 return self._sample_maybe_record(conditioning_data, mask=mask, record=False)[:2] File ~/.aiida_venvs/test-xtalpaint/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs) 112 @functools.wraps(func) 113 def decorate_context(*args, **kwargs): 114 with ctx_factory(): --> 115 return func(*args, **kwargs) File ~/project/test-xtalpaint/mattergen-clean/mattergen/diffusion/sampling/pc_sampler.py:157, in PredictorCorrector._sample_maybe_record(self, conditioning_data, mask, record) 155 mask = {k: v.to(self._device) for k, v in mask.items()} 156 batch = _sample_prior(self._multi_corruption, conditioning_data, mask=mask) --> 157 return self._denoise(batch=batch, mask=mask, record=record) File ~/.aiida_venvs/test-xtalpaint/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs) 112 @functools.wraps(func) 113 def decorate_context(*args, **kwargs): 114 with ctx_factory(): --> 115 return func(*args, **kwargs) File ~/project/test-xtalpaint/mattergen-clean/mattergen/diffusion/sampling/pc_sampler.py:187, in PredictorCorrector._denoise(self, batch, mask, record) 185 if self._correctors: 186 for _ in range(self._n_steps_corrector): --> 187 score = self._score_fn(batch, t) 188 fns = { 189 k: corrector.step_given_score for k, corrector in self._correctors.items() 190 } 191 samples_means: dict[str, Tuple[torch.Tensor, torch.Tensor]] = apply( 192 fns=fns, 193 broadcast={"t": t, "dt": dt}, (...) 
196 batch_idx=self._multi_corruption._get_batch_indices(batch), 197 ) File ~/project/test-xtalpaint/mattergen-clean/mattergen/diffusion/sampling/classifier_free_guidance.py:72, in GuidedPredictorCorrector._score_fn(self, x, t) 70 return get_conditional_score() 71 elif abs(self._guidance_scale) < 1e-15: ---> 72 return get_unconditional_score() 73 else: 74 # guided_score = guidance_factor * conditional_score + (1-guidance_factor) * unconditional_score 75 batch_no_condition = self._remove_conditioning_fn(x) File ~/project/test-xtalpaint/mattergen-clean/mattergen/diffusion/sampling/classifier_free_guidance.py:60, in GuidedPredictorCorrector._score_fn.<locals>.get_unconditional_score() 59 def get_unconditional_score(): ---> 60 return super(GuidedPredictorCorrector, self)._score_fn( 61 x=self._remove_conditioning_fn(x), t=t 62 ) File ~/project/test-xtalpaint/mattergen-clean/mattergen/diffusion/sampling/pc_sampler.py:94, in PredictorCorrector._score_fn(self, x, t) 93 def _score_fn(self, x: Diffusable, t: torch.Tensor) -> Diffusable: ---> 94 return self._diffusion_module.score_fn(x, t) File ~/project/test-xtalpaint/mattergen-clean/mattergen/diffusion/diffusion_module.py:129, in DiffusionModule.score_fn(self, x, t) 119 def score_fn(self, x: T, t: torch.Tensor) -> T: 120 """Calculate the score of a batch of data at a given timestep 121 122 Args: (...) 127 score: score of the batch of data at the given timestep 128 """ --> 129 model_out: T = self.model(x, t) 130 fns = {k: convert_model_out_to_score for k in self.corruption.sdes.keys()} 132 scores = apply( 133 fns=fns, 134 model_out=model_out, (...) 
138 batch_idx=self.corruption._get_batch_indices(x), 139 ) File ~/.aiida_venvs/test-xtalpaint/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/.aiida_venvs/test-xtalpaint/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/project/test-xtalpaint/mattergen-clean/mattergen/denoiser.py:248, in GemNetTDenoiser.forward(self, x, t) 245 if len(property_embedding_values) > 0: 246 z_per_crystal = torch.cat([z_per_crystal, property_embedding_values], dim=-1) --> 248 output = self.gemnet( 249 z=z_per_crystal, 250 frac_coords=frac_coords, 251 atom_types=atom_types, 252 num_atoms=num_atoms, 253 batch=batch, 254 lengths=None, 255 angles=None, 256 lattice=lattice, 257 # we construct the graph on the fly, hence pass None for these: 258 edge_index=None, 259 to_jimages=None, 260 num_bonds=None, 261 ) 262 pred_atom_types = self.fc_atom(output.node_embeddings) 264 return get_chemgraph_from_denoiser_output( 265 pred_atom_types=pred_atom_types, 266 pred_lattice_eps=output.stress, (...) 
270 x_input=x, 271 ) File ~/.aiida_venvs/test-xtalpaint/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs) 1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1510 else: -> 1511 return self._call_impl(*args, **kwargs) File ~/.aiida_venvs/test-xtalpaint/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs) 1515 # If we don't have any hooks, we want to skip the rest of the logic in 1516 # this function, and just call forward. 1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1518 or _global_backward_pre_hooks or _global_backward_hooks 1519 or _global_forward_hooks or _global_forward_pre_hooks): -> 1520 return forward_call(*args, **kwargs) 1522 try: 1523 result = None File ~/project/test-xtalpaint/git/XtalPaint/src/xtalpaint/time_dependent/gemnet.py:101, in TDGemNetT.forward(self, z, frac_coords, atom_types, num_atoms, batch, lengths, angles, edge_index, to_jimages, num_bonds, lattice) 99 if z is not None: 100 if z.shape[0] != h.shape[0]: --> 101 raise ValueError( 102 'The TD-Paint compatible GemNetT model expects the latent ' 103 'vector z to have the same first dimension as the number ' 104 'of atoms.' 105 ) 106 # Keep this only to emphasize the difference to the original GemNetT 107 z_per_atom = z ValueError: The TD-Paint compatible GemNetT model expects the latent vector z to have the same first dimension as the number of atoms.
In [ ]:
Copied!
from xtalpaint.time_dependent.gemnet import TDGemNetT
from xtalpaint.time_dependent.gemnet import TDGemNetT
In [14]:
Copied!
# Display the inpainted structures as pymatgen objects, keyed by structure label.
inpainting_outputs['structures'].get_structures(strct_type='pymatgen')
# Display the inpainted structures as pymatgen objects, keyed by structure label.
inpainting_outputs['structures'].get_structures(strct_type='pymatgen')
Out[14]:
{'20acc66e_8e38_4e5e_9e7a_c2400262cdc8': Structure Summary
Lattice
abc : 5.1697835922241255 5.1697835922241255 5.169783592224121
angles : 90.00000250447799 90.00000250447799 90.00000250447816
volume : 138.171060700957
A : 5.169783592224121 0.0 -2.2597841109472938e-07
B : -2.2597843664143413e-07 5.169783592224116 -2.2597841109472938e-07
C : 0.0 0.0 5.169783592224121
pbc : True True True
PeriodicSite: N (2.309, 4.894, 2.86) [0.4467, 0.9467, 0.5533]
PeriodicSite: N (4.894, 2.86, 2.309) [0.9467, 0.5533, 0.4467]
PeriodicSite: N (2.86, 2.309, 4.894) [0.5533, 0.4467, 0.9467]
PeriodicSite: N (0.2756, 0.2756, 0.2756) [0.0533, 0.0533, 0.0533]
PeriodicSite: H (1.0, 3.885, 2.78) [0.1935, 0.7514, 0.5377]
PeriodicSite: H (3.033, 0.5016, 4.072) [0.5867, 0.09703, 0.7876]
PeriodicSite: H (0.5171, 2.514, 3.879) [0.1, 0.4863, 0.7503]
PeriodicSite: H (2.79, 0.9618, 2.302) [0.5397, 0.186, 0.4453]
PeriodicSite: H (0.08463, 3.354, 2.823) [0.01637, 0.6488, 0.546]
PeriodicSite: H (4.048, 4.275, 2.766) [0.7831, 0.8269, 0.535]
PeriodicSite: H (3.573, 0.8306, 2.443) [0.6912, 0.1607, 0.4725]
PeriodicSite: H (1.074, 2.255, 2.883) [0.2078, 0.4361, 0.5577]
PeriodicSite: H (3.691, 2.022, 0.1575) [0.7139, 0.3911, 0.03046]
PeriodicSite: H (1.858, 1.785, 5.169) [0.3594, 0.3452, 0.9998]
PeriodicSite: H (0.7801, 0.07337, 0.9159) [0.1509, 0.01419, 0.1772]
PeriodicSite: H (3.279, 3.464, 5.155) [0.6342, 0.67, 0.9971],
'47b9a869_9b1e_438b_8c93_f5ac654bfdd8': Structure Summary
Lattice
abc : 4.38880300521851 4.401448726654057 7.175978183746338
angles : 90.000002504478 90.00000250447802 119.90499541135148
volume : 120.16235834626013
A : 4.388803005218506 0.0 -1.9184066957222967e-07
B : -2.1944008875261014 3.8154102843847926 -1.9239342918808688e-07
C : 0.0 0.0 7.175978183746338
pbc : True True True
PeriodicSite: O (2.194, 2.549, 3.886) [0.8341, 0.6681, 0.5415]
PeriodicSite: O (2.041e-07, 1.266, 0.2978) [0.1659, 0.3319, 0.0415]
PeriodicSite: O (2.194, 2.544, 6.579) [0.8334, 0.6667, 0.9168]
PeriodicSite: O (2.05e-07, 1.272, 2.991) [0.1666, 0.3333, 0.4168]
PeriodicSite: H (-0.1279, 3.659, 6.497) [0.4504, 0.959, 0.9053]
PeriodicSite: H (1.932, 1.894, 6.591) [0.6884, 0.4964, 0.9184]
PeriodicSite: H (2.693, 0.4954, 5.701) [0.6785, 0.1299, 0.7945]
PeriodicSite: H (0.9313, 1.986, 7.078) [0.4724, 0.5205, 0.9864]
PeriodicSite: H (3.617, 0.776, 0.1137) [0.9259, 0.2034, 0.01585]
PeriodicSite: H (-1.225, 2.952, 6.751) [0.1077, 0.7737, 0.9407]
PeriodicSite: H (1.877, 2.83, 4.877) [0.7984, 0.7417, 0.6797]
PeriodicSite: H (0.1163, 1.131, 6.242) [0.1747, 0.2965, 0.8698],
'662c7351_ee76_48ea_bab7_b733e1fdf607': Structure Summary
Lattice
abc : 4.167909392982183 8.414297301182946 8.41429615020752
angles : 86.48328861312336 75.66042875469425 75.66042198607414
volume : 276.98893913632514
A : 4.038057804107666 0.0 1.0322586297988892
B : 2.0190288528503957 8.152148872105393 0.5161301493644714
C : 0.0 0.0 8.41429615020752
pbc : True True True
PeriodicSite: O (0.6612, 1.915, 6.606) [0.04627, 0.2349, 0.7651]
PeriodicSite: O (4.326, 5.991, 3.336) [0.7037, 0.7349, 0.2651]
PeriodicSite: O (1.732, 6.237, 2.419) [0.04627, 0.7651, 0.2349]
PeriodicSite: O (3.377, 2.161, 7.047) [0.7037, 0.2651, 0.7349]
PeriodicSite: O (3.377, 1.915, 2.84) [0.7188, 0.2349, 0.2349]
PeriodicSite: O (3.751, 5.991, 7.143) [0.5613, 0.7349, 0.7349]
PeriodicSite: O (2.307, 6.237, 7.027) [0.1887, 0.7651, 0.7651]
PeriodicSite: O (0.6612, 2.161, 2.399) [0.03121, 0.2651, 0.2651]
PeriodicSite: H (1.069, 2.064, 4.164) [0.1381, 0.2532, 0.4623]
PeriodicSite: H (4.283, 6.41, 8.256) [0.6675, 0.7863, 0.8511]
PeriodicSite: H (3.153, 6.595, 6.226) [0.3762, 0.809, 0.6441]
PeriodicSite: H (1.228, 3.129, 7.099) [0.1121, 0.3838, 0.8064]
PeriodicSite: H (0.9039, 0.5538, 2.695) [0.1899, 0.06794, 0.2928]
PeriodicSite: H (1.622, 3.405, 1.922) [0.193, 0.4177, 0.1791]
PeriodicSite: H (3.576, 3.679, 2.863) [0.66, 0.4513, 0.2316]
PeriodicSite: H (3.334, 2.291, 5.648) [0.6851, 0.2811, 0.5699],
'7fa282c5_4971_46f4_8b3b_776595a0fa06': Structure Summary
Lattice
abc : 6.332112789154059 6.332114696502691 6.603632926940918
angles : 90.00000250447809 90.00000250447808 119.99999014147143
volume : 229.303674421453
A : 6.332112789154053 0.0 -2.767854425655969e-07
B : -3.1660564046927258 5.483772731611829 -2.767855278307252e-07
C : 0.0 0.0 6.603632926940918
pbc : True True True
PeriodicSite: Y (2.134, 0.0, 4.953) [0.3369, 0.0, 0.75]
PeriodicSite: Y (1.067, 1.848, 1.651) [0.3369, 0.3369, 0.25]
PeriodicSite: Y (-2.099, 3.636, 1.651) [6.806e-10, 0.6631, 0.25]
PeriodicSite: Y (-1.067, 1.848, 4.953) [0.0, 0.3369, 0.75]
PeriodicSite: Y (4.199, 4.059e-09, 1.651) [0.6631, 7.402e-10, 0.25]
PeriodicSite: Y (2.099, 3.636, 4.953) [0.6631, 0.6631, 0.75]
PeriodicSite: H (3.15, 1.752, 4.848) [0.6572, 0.3195, 0.7341]
PeriodicSite: H (0.4701, 2.696, 6.053) [0.3201, 0.4917, 0.9166]
PeriodicSite: H (-0.03274, 0.1909, 1.806) [0.01223, 0.0348, 0.2735]
PeriodicSite: H (-1.976, 3.817, 5.611) [0.03597, 0.696, 0.8497]
PeriodicSite: H (4.78, 0.1132, 6.239) [0.7653, 0.02065, 0.9448]
PeriodicSite: H (2.249, 3.828, 2.821) [0.7041, 0.698, 0.4272]
PeriodicSite: H (3.885, 1.989, 2.709) [0.7948, 0.3627, 0.4102]
PeriodicSite: H (0.893, 1.913, 4.326) [0.3155, 0.3489, 0.6551]
PeriodicSite: H (1.422, 2.137, 3.569) [0.4194, 0.3897, 0.5404]
PeriodicSite: H (6.255, 0.03513, 4.606) [0.991, 0.006407, 0.6975]
PeriodicSite: H (0.8428, 5.394, 3.953) [0.6249, 0.9835, 0.5987]
PeriodicSite: H (-1.031, 3.605, 3.669) [0.1658, 0.6573, 0.5556]
PeriodicSite: H (-0.01093, 3.839, 1.495) [0.3483, 0.7001, 0.2264]
PeriodicSite: H (2.11, 0.172, 2.377) [0.349, 0.03137, 0.36]
PeriodicSite: H (3.165, 2.702, 6.593) [0.7462, 0.4928, 0.9983]
PeriodicSite: H (2.485, 0.8155, 0.3001) [0.4668, 0.1487, 0.04545]
PeriodicSite: H (5.153, 2.037, 0.7442) [0.9995, 0.3715, 0.1127]
PeriodicSite: H (2.063, 3.722, 0.8681) [0.6652, 0.6787, 0.1315],
'c436bbf4_9aef_44a8_8960_00227f79a32f': Structure Summary
Lattice
abc : 4.403401374816899 4.962053775787357 5.689428329467773
angles : 90.00000250447805 90.00000250447815 90.00000250447816
volume : 124.31352208745227
A : 4.4034013748168945 0.0 -1.924787937923611e-07
B : -2.1689827722570944e-07 4.962053775787348 -2.1689825757675862e-07
C : 0.0 0.0 5.689428329467773
pbc : True True True
PeriodicSite: V (2.02, 2.481, 1.451) [0.4587, 0.5, 0.2551]
PeriodicSite: V (2.02, 4.962, 4.238) [0.4587, 1.0, 0.7449]
PeriodicSite: V (2.02, 2.481, 4.238) [0.4587, 0.5, 0.7449]
PeriodicSite: V (2.02, 2.13e-06, 1.451) [0.4587, 4.293e-07, 0.2551]
PeriodicSite: V (4.251, 2.481, 2.845) [0.9654, 0.5, 0.5]
PeriodicSite: V (4.251, 0.0, 2.845) [0.9654, 0.0, 0.5]
PeriodicSite: V (4.243, 2.481, -2.939e-07) [0.9636, 0.5, 0.0]
PeriodicSite: V (4.243, 0.0, -1.855e-07) [0.9636, 0.0, 0.0]
PeriodicSite: H (3.983, 0.839, 0.9776) [0.9046, 0.1691, 0.1718]
PeriodicSite: H (0.5064, 2.343, 1.471) [0.115, 0.4722, 0.2585]
PeriodicSite: H (1.565, 4.333, 3.258) [0.3554, 0.8732, 0.5727]
PeriodicSite: H (3.724, 4.335, 4.633) [0.8457, 0.8736, 0.8142]
PeriodicSite: H (2.436, 1.522, 3.406) [0.5532, 0.3067, 0.5986]
PeriodicSite: H (1.776, 3.419, 2.246) [0.4032, 0.6889, 0.3947]}
Relax structures¶
In [15]:
Copied!
from xtalpaint.utils.relaxation_utils import relax_structures
from xtalpaint.utils.relaxation_utils import relax_structures
In [16]:
Copied!
# Relax the inpainted structures.
#
# Reads `inputs` (run configuration) and `inpainting_outputs` from earlier
# cells and produces two {label: structure} dicts:
#   - constrained_relaxation_structures (run when `inputs.relax` is set)
#   - full_relaxation_structures        (run when `inputs.full_relax` is set)
# Both names are always defined; a disabled stage yields an empty dict instead
# of a NameError (or stale values from a previous run) further down.
relax_kwargs = inputs.relax_kwargs.model_dump()
print(json.dumps(relax_kwargs, indent=4))

# Keep labels and structures as two aligned lists so the relaxed structures can
# be re-associated with their labels afterwards. Unlike `zip(*d.items())`, this
# also works when the mapping is empty.
inpainted_by_label = inpainting_outputs["structures"].get_structures(
    strct_type="pymatgen"
)
structure_labels = list(inpainted_by_label.keys())
inpainted_structures = list(inpainted_by_label.values())

constrained_relaxation_structures = {}
full_relaxation_structures = {}

if inputs.relax:
    constrained_relaxation_outputs = relax_structures(
        structures=inpainted_structures,
        **relax_kwargs,
    )
    # Element 0 of the outputs is used as the relaxed structures.
    # NOTE(review): assumes they come back in input order -- confirm.
    constrained_relaxation_structures = dict(
        zip(structure_labels, constrained_relaxation_outputs[0])
    )

if inputs.full_relax:
    # Same settings minus `elements_to_relax` (presumably the restriction to
    # specific elements -- confirm). Built as a copy so `relax_kwargs` still
    # matches the configuration printed above.
    full_relax_kwargs = {
        key: value
        for key, value in relax_kwargs.items()
        if key != "elements_to_relax"
    }
    full_relaxation_outputs = relax_structures(
        structures=inpainted_structures,
        **full_relax_kwargs,
    )
    full_relaxation_structures = dict(
        zip(structure_labels, full_relaxation_outputs[0])
    )
# Relax the inpainted structures.
#
# Reads `inputs` (run configuration) and `inpainting_outputs` from earlier
# cells and produces two {label: structure} dicts:
#   - constrained_relaxation_structures (run when `inputs.relax` is set)
#   - full_relaxation_structures        (run when `inputs.full_relax` is set)
# Both names are always defined; a disabled stage yields an empty dict instead
# of a NameError (or stale values from a previous run) further down.
relax_kwargs = inputs.relax_kwargs.model_dump()
print(json.dumps(relax_kwargs, indent=4))

# Keep labels and structures as two aligned lists so the relaxed structures can
# be re-associated with their labels afterwards. Unlike `zip(*d.items())`, this
# also works when the mapping is empty.
inpainted_by_label = inpainting_outputs["structures"].get_structures(
    strct_type="pymatgen"
)
structure_labels = list(inpainted_by_label.keys())
inpainted_structures = list(inpainted_by_label.values())

constrained_relaxation_structures = {}
full_relaxation_structures = {}

if inputs.relax:
    constrained_relaxation_outputs = relax_structures(
        structures=inpainted_structures,
        **relax_kwargs,
    )
    # Element 0 of the outputs is used as the relaxed structures.
    # NOTE(review): assumes they come back in input order -- confirm.
    constrained_relaxation_structures = dict(
        zip(structure_labels, constrained_relaxation_outputs[0])
    )

if inputs.full_relax:
    # Same settings minus `elements_to_relax` (presumably the restriction to
    # specific elements -- confirm). Built as a copy so `relax_kwargs` still
    # matches the configuration printed above.
    full_relax_kwargs = {
        key: value
        for key, value in relax_kwargs.items()
        if key != "elements_to_relax"
    }
    full_relaxation_outputs = relax_structures(
        structures=inpainted_structures,
        **full_relax_kwargs,
    )
    full_relaxation_structures = dict(
        zip(structure_labels, full_relaxation_outputs[0])
    )
{
"load_path": "MatterSim-v1.0.0-5M.pth",
"fmax": 0.01,
"elements_to_relax": [
"H"
],
"max_natoms_per_batch": 5000,
"max_n_steps": 50,
"device": "cuda",
"filter": null,
"optimizer": "BFGS",
"mlip": "mattersim",
"return_initial_energies": false,
"return_initial_forces": false,
"return_final_forces": false
}
2026-01-07 15:38:28.324 | INFO | mattersim.forcefield.potential:from_checkpoint:891 - Loading the pre-trained mattersim-v1.0.0-5M.pth model
0%| | 0/5 [00:00<?, ?it/s]
/home/reents_t/.aiida_venvs/dev-mattergen-inpainting/lib/python3.10/site-packages/mattersim/applications/batch_relax.py:80: FutureWarning: Please use atoms.calc = calc atoms.set_calculator(DummyBatchCalculator())
100%|██████████| 5/5 [00:33<00:00, 6.73s/it] 2026-01-07 15:39:02.189 | INFO | mattersim.forcefield.potential:from_checkpoint:891 - Loading the pre-trained mattersim-v1.0.0-5M.pth model 100%|██████████| 5/5 [00:49<00:00, 9.90s/it]
Evaluate the inpainted structures with respect to the initial reference¶
In [17]:
Copied!
from xtalpaint.eval import evaluate_inpainting
import pandas as pd
from xtalpaint.eval import evaluate_inpainting
import pandas as pd
In [18]:
Copied!
def _evaluate_against_reference(
    candidate_structures,
    reference_structures,
    max_workers=3,
    normalization_element="H",
):
    """Score candidate structures against references with both metrics.

    Runs `evaluate_inpainting` once with ``metric="rmsd"`` and once with
    ``metric="match"``, then joins the two results on their index.

    Parameters
    ----------
    candidate_structures
        Structures to evaluate; forwarded as ``inpainted_structures``.
    reference_structures
        Reference structures; forwarded as ``reference_structures``.
    max_workers
        Forwarded to both `evaluate_inpainting` calls.
    normalization_element
        Forwarded to the RMSD evaluation only.

    Returns
    -------
    tuple
        ``(rmsd_table, match_table, merged_table)`` where ``merged_table`` is
        ``pd.merge`` of the first two on their indices.
    """
    rmsd_table = evaluate_inpainting(
        inpainted_structures=candidate_structures,
        reference_structures=reference_structures,
        metric="rmsd",
        max_workers=max_workers,
        normalization_element=normalization_element,
    )
    match_table = evaluate_inpainting(
        inpainted_structures=candidate_structures,
        reference_structures=reference_structures,
        metric="match",
        max_workers=max_workers,
    )
    merged_table = pd.merge(
        rmsd_table, match_table, left_index=True, right_index=True
    )
    return rmsd_table, match_table, merged_table


# Structures straight out of the inpainting step.
(
    rmsd_inpainted_structures,
    matches_inpainted_structures,
    inpainted_evaluation,
) = _evaluate_against_reference(inpainting_outputs["structures"], input_structures)

# Structures after the constrained relaxation.
(
    rmsd_constrained_relaxation,
    matches_constrained_relaxation,
    constrained_relaxation_evaluation,
) = _evaluate_against_reference(constrained_relaxation_structures, input_structures)
def _evaluate_against_reference(
    candidate_structures,
    reference_structures,
    max_workers=3,
    normalization_element="H",
):
    """Score candidate structures against references with both metrics.

    Runs `evaluate_inpainting` once with ``metric="rmsd"`` and once with
    ``metric="match"``, then joins the two results on their index.

    Parameters
    ----------
    candidate_structures
        Structures to evaluate; forwarded as ``inpainted_structures``.
    reference_structures
        Reference structures; forwarded as ``reference_structures``.
    max_workers
        Forwarded to both `evaluate_inpainting` calls.
    normalization_element
        Forwarded to the RMSD evaluation only.

    Returns
    -------
    tuple
        ``(rmsd_table, match_table, merged_table)`` where ``merged_table`` is
        ``pd.merge`` of the first two on their indices.
    """
    rmsd_table = evaluate_inpainting(
        inpainted_structures=candidate_structures,
        reference_structures=reference_structures,
        metric="rmsd",
        max_workers=max_workers,
        normalization_element=normalization_element,
    )
    match_table = evaluate_inpainting(
        inpainted_structures=candidate_structures,
        reference_structures=reference_structures,
        metric="match",
        max_workers=max_workers,
    )
    merged_table = pd.merge(
        rmsd_table, match_table, left_index=True, right_index=True
    )
    return rmsd_table, match_table, merged_table


# Structures straight out of the inpainting step.
(
    rmsd_inpainted_structures,
    matches_inpainted_structures,
    inpainted_evaluation,
) = _evaluate_against_reference(inpainting_outputs["structures"], input_structures)

# Structures after the constrained relaxation.
(
    rmsd_constrained_relaxation,
    matches_constrained_relaxation,
    constrained_relaxation_evaluation,
) = _evaluate_against_reference(constrained_relaxation_structures, input_structures)
80%|████████ | 4/5 [00:00<00:00, 13.49it/s] 80%|████████ | 4/5 [00:00<00:00, 27.33it/s] 80%|████████ | 4/5 [00:00<00:00, 15.04it/s] 80%|████████ | 4/5 [00:00<00:00, 31.99it/s]
In [19]:
Copied!
# Display the merged RMSD / match table for the structures as produced by inpainting.
inpainted_evaluation
# Display the merged RMSD / match table for the structures as produced by inpainting.
inpainted_evaluation
Out[19]:
| rmsd | match | |
|---|---|---|
| keys | ||
| 20acc66e_8e38_4e5e_9e7a_c2400262cdc8 | 1.184509 | False |
| 47b9a869_9b1e_438b_8c93_f5ac654bfdd8 | 1.319913 | False |
| 662c7351_ee76_48ea_bab7_b733e1fdf607 | 1.804719 | False |
| 7fa282c5_4971_46f4_8b3b_776595a0fa06 | 0.813573 | False |
| c436bbf4_9aef_44a8_8960_00227f79a32f | 1.317161 | False |
In [20]:
Copied!
# Display the merged RMSD / match table for the structures after constrained relaxation.
constrained_relaxation_evaluation
# Display the merged RMSD / match table for the structures after constrained relaxation.
constrained_relaxation_evaluation
Out[20]:
| rmsd | match | |
|---|---|---|
| keys | ||
| 20acc66e_8e38_4e5e_9e7a_c2400262cdc8 | 0.783026 | False |
| 47b9a869_9b1e_438b_8c93_f5ac654bfdd8 | 0.326994 | True |
| 662c7351_ee76_48ea_bab7_b733e1fdf607 | 2.046668 | False |
| 7fa282c5_4971_46f4_8b3b_776595a0fa06 | 0.017497 | True |
| c436bbf4_9aef_44a8_8960_00227f79a32f | 1.001336 | False |
In [ ]:
Copied!